wrapper

class model_compression_toolkit.wrapper.mct_wrapper.MCTWrapper

Wrapper class for Model Compression Toolkit (MCT) quantization and export.

This class provides a unified interface for neural network quantization methods, including Post-Training Quantization (PTQ) and Gradient-based Post-Training Quantization (GPTQ). It supports both TensorFlow and PyTorch frameworks, with optional mixed-precision quantization.

The wrapper manages the complete quantization pipeline from model input to quantized model export, handling framework-specific configurations and Target Platform Capabilities (TPC) setup.

Initialize MCTWrapper with default parameters.

Users can override the following parameters via param_items:

PTQ

=======================  ====================================  ================================================
Parameter Key            Default Value                         Description
=======================  ====================================  ================================================
target_platform_version  'v1'                                  Target platform version (use_internal_tpc=True)
tpc_version              '5.0'                                 TPC version (use_internal_tpc=False)
activation_error_method  mct.core.QuantizationErrorMethod.MSE  Activation quantization error method
weights_bias_correction  True                                  Enable weights bias correction
z_threshold              float('inf')                          Z-threshold for quantization
linear_collapsing        True                                  Enable linear layer collapsing
residual_collapsing      True                                  Enable residual connection collapsing
save_model_path          './qmodel.keras' / './qmodel.onnx'    Path to save quantized model (Keras/PyTorch)
=======================  ====================================  ================================================
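For example, a hedged sketch of a PTQ override list; the NOCLIPPING error method and the 16.0 threshold are illustrative choices, not recommended settings:

>>> import model_compression_toolkit as mct
>>> param_items = [
...     ['activation_error_method', mct.core.QuantizationErrorMethod.NOCLIPPING],  # switch from the MSE default
...     ['z_threshold', 16.0],  # finite z-threshold instead of the float('inf') default
... ]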

PTQ, mixed_precision

=========================  ==================================  ==================================================
Parameter Key              Default Value                       Description
=========================  ==================================  ==================================================
target_platform_version    'v1'                                Target platform version (use_internal_tpc=True)
tpc_version                '5.0'                               TPC version (use_internal_tpc=False)
num_of_images              5                                   Number of images for mixed precision
use_hessian_based_scores   False                               Use Hessian-based scores for mixed precision
weights_compression_ratio  None                                Weights compression ratio for resource utilization
save_model_path            './qmodel.keras' / './qmodel.onnx'  Path to save quantized model (Keras/PyTorch)
=========================  ==================================  ==================================================
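As a sketch of a mixed-precision PTQ override, assuming weights_compression_ratio is interpreted as the target fraction of the float model's weights memory (the 0.75 and 16 values are illustrative):

>>> param_items = [
...     ['weights_compression_ratio', 0.75],  # target ~75% of float weights memory (assumed semantics)
...     ['num_of_images', 16],                # more images for the mixed-precision search
...     ['use_hessian_based_scores', True],   # enable Hessian-based scoring
... ]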

GPTQ

=======================  ==================================  ================================================
Parameter Key            Default Value                       Description
=======================  ==================================  ================================================
target_platform_version  'v1'                                Target platform version (use_internal_tpc=True)
tpc_version              '5.0'                               TPC version (use_internal_tpc=False)
n_epochs                 5                                   Number of training epochs for GPTQ
optimizer                None                                Optimizer for GPTQ training
save_model_path          './qmodel.keras' / './qmodel.onnx'  Path to save quantized model (Keras/PyTorch)
=======================  ==================================  ================================================
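A minimal GPTQ override sketch; the epoch count is illustrative, and optimizer is left at None on the assumption that the wrapper then falls back to a framework-specific default:

>>> param_items = [
...     ['n_epochs', 10],     # illustrative; default is 5
...     ['optimizer', None],  # keep the wrapper's default optimizer
... ]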

GPTQ, mixed_precision

=========================  ==================================  ==================================================
Parameter Key              Default Value                       Description
=========================  ==================================  ==================================================
target_platform_version    'v1'                                Target platform version (use_internal_tpc=True)
tpc_version                '5.0'                               TPC version (use_internal_tpc=False)
n_epochs                   5                                   Number of training epochs for GPTQ
optimizer                  None                                Optimizer for GPTQ training
num_of_images              5                                   Number of images for mixed precision
use_hessian_based_scores   False                               Use Hessian-based scores for mixed precision
weights_compression_ratio  None                                Weights compression ratio for resource utilization
save_model_path            './qmodel.keras' / './qmodel.onnx'  Path to save quantized model (Keras/PyTorch)
=========================  ==================================  ==================================================
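Combining both modes, a GPTQ plus mixed-precision override list can set the fine-tuning and the precision-search parameters together; all values below are illustrative:

>>> param_items = [
...     ['n_epochs', 10],                     # GPTQ fine-tuning epochs
...     ['num_of_images', 16],                # images for the mixed-precision search
...     ['weights_compression_ratio', 0.75],  # assumed target fraction of float weights memory
... ]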

quantize_and_export(float_model, representative_dataset, method='PTQ', framework='pytorch', use_internal_tpc=True, use_mixed_precision=False, param_items=None)

Main function to perform model quantization and export.

Return type:

Tuple[bool, Any]

Parameters:
  • float_model – The float model to be quantized.

  • representative_dataset (Callable, np.ndarray, tf.Tensor) – Representative dataset for calibration.

  • method (str) – Quantization method: 'PTQ' or 'GPTQ'. Default: 'PTQ'

  • framework (str) – 'tensorflow' or 'pytorch'. Default: 'pytorch'

  • use_internal_tpc (bool) – Whether to use the internal Target Platform Capabilities (TPC). Default: True

  • use_mixed_precision (bool) – Whether to use mixed-precision quantization. Default: False

  • param_items (list) – List of parameter settings, as [[key, value], ...]. Default: None

Returns:

Tuple of (quantization success flag, quantized model)

Examples

Import MCT:

>>> import model_compression_toolkit as mct

Prepare the float model and representative dataset:

>>> float_model = ...
>>> representative_dataset = ...
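In MCT, representative datasets are typically generators that yield batches of model inputs. A minimal sketch using random data, assuming a single image input of shape (1, 224, 224, 3); substitute real calibration samples in practice:

>>> import numpy as np
>>> def representative_dataset():
...     # yield a list of input batches per call; 10 iterations of random data (illustrative)
...     for _ in range(10):
...         yield [np.random.randn(1, 224, 224, 3).astype(np.float32)]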

Create an instance of the MCTWrapper:

>>> wrapper = mct.MCTWrapper()

Set the method, framework, and other options:

>>> method = 'PTQ'
>>> framework = 'tensorflow'
>>> use_internal_tpc = True
>>> use_mixed_precision = False

Set parameters if needed; each entry is a [key, value] pair using the keys from the tables above (the values here are illustrative):

>>> param_items = [['weights_bias_correction', True], ['linear_collapsing', False]]

Quantize and export the model:

>>> flag, quantized_model = wrapper.quantize_and_export(
...     float_model=float_model,
...     representative_dataset=representative_dataset,
...     method=method,
...     framework=framework,
...     use_internal_tpc=use_internal_tpc,
...     use_mixed_precision=use_mixed_precision,
...     param_items=param_items
... )
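The returned flag reports whether quantization succeeded; the quantized model is also written to save_model_path ('./qmodel.keras' by default for TensorFlow). A minimal check, assuming the TensorFlow path returns a Keras model:

>>> if flag:
...     quantized_model.summary()  # inspect the quantized Keras model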