wrapper¶
- class model_compression_toolkit.wrapper.mct_wrapper.MCTWrapper¶
Wrapper class for Model Compression Toolkit (MCT) quantization and export.
This class provides a unified interface for neural network quantization methods, including Post-Training Quantization (PTQ) and Gradient-Based Post-Training Quantization (GPTQ). It supports both the TensorFlow and PyTorch frameworks, with optional mixed-precision quantization.
The wrapper manages the complete quantization pipeline from model input to quantized model export, handling framework-specific configurations and Target Platform Capabilities (TPC) setup.
Initialize MCTWrapper with default parameters.
Users can override the following parameters via param_items; an illustrative example follows each table.
PTQ

| Parameter Key | Default Value | Description |
|---|---|---|
| target_platform_version | 'v1' | Target platform version (use_internal_tpc=True) |
| tpc_version | '5.0' | TPC version (use_internal_tpc=False) |
| activation_error_method | mct.core.QuantizationErrorMethod.MSE | Activation quantization error method |
| weights_bias_correction | True | Enable weights bias correction |
| z_threshold | float('inf') | Z-threshold for quantization |
| linear_collapsing | True | Enable linear layer collapsing |
| residual_collapsing | True | Enable residual connection collapsing |
| save_model_path | './qmodel.keras' / './qmodel.onnx' | Path to save the quantized model (Keras/PyTorch) |
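For example, a PTQ run that tightens the outlier threshold and changes the export path could pass param_items like this (a minimal sketch; the keys come from the table above, the values are illustrative):

>>> param_items = [
...     ['z_threshold', 16.0],                  # clip activation outliers beyond this z-score
...     ['linear_collapsing', False],           # disable linear layer collapsing
...     ['save_model_path', './my_qmodel.keras'],  # Keras export path for the tensorflow framework
... ]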
PTQ, mixed_precision

| Parameter Key | Default Value | Description |
|---|---|---|
| target_platform_version | 'v1' | Target platform version (use_internal_tpc=True) |
| tpc_version | '5.0' | TPC version (use_internal_tpc=False) |
| num_of_images | 5 | Number of images for mixed precision |
| use_hessian_based_scores | False | Use Hessian-based scores for mixed precision |
| weights_compression_ratio | None | Weights compression ratio for resource utilization |
| save_model_path | './qmodel.keras' / './qmodel.onnx' | Path to save the quantized model (Keras/PyTorch) |
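In the mixed-precision variant, the extra keys control the bit-width search. A sketch with illustrative values (0.75 as the target weights compression ratio is an assumption, not a recommended setting):

>>> param_items = [
...     ['num_of_images', 16],                # more calibration images for the sensitivity evaluation
...     ['use_hessian_based_scores', True],   # weight layer sensitivity by Hessian-based scores
...     ['weights_compression_ratio', 0.75],  # illustrative target ratio for resource utilization
... ]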
GPTQ

| Parameter Key | Default Value | Description |
|---|---|---|
| target_platform_version | 'v1' | Target platform version (use_internal_tpc=True) |
| tpc_version | '5.0' | TPC version (use_internal_tpc=False) |
| n_epochs | 5 | Number of training epochs for GPTQ |
| optimizer | None | Optimizer for GPTQ training |
| save_model_path | './qmodel.keras' / './qmodel.onnx' | Path to save the quantized model (Keras/PyTorch) |
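For GPTQ, the training-related keys can be overridden the same way. A sketch for the pytorch framework; whether the wrapper expects an optimizer class or a configured instance is not specified on this page, so the optimizer line is an assumption:

>>> import torch
>>> param_items = [
...     ['n_epochs', 15],                 # longer fine-tuning than the default 5 epochs
...     ['optimizer', torch.optim.Adam],  # assumption: a framework-native optimizer is accepted here
... ]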
GPTQ, mixed_precision

| Parameter Key | Default Value | Description |
|---|---|---|
| target_platform_version | 'v1' | Target platform version (use_internal_tpc=True) |
| tpc_version | '5.0' | TPC version (use_internal_tpc=False) |
| n_epochs | 5 | Number of training epochs for GPTQ |
| optimizer | None | Optimizer for GPTQ training |
| num_of_images | 5 | Number of images for mixed precision |
| use_hessian_based_scores | False | Use Hessian-based scores for mixed precision |
| weights_compression_ratio | None | Weights compression ratio for resource utilization |
| save_model_path | './qmodel.keras' / './qmodel.onnx' | Path to save the quantized model (Keras/PyTorch) |
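The combined mode simply accepts keys from both the GPTQ and mixed-precision tables in one list, for example (values illustrative):

>>> param_items = [
...     ['n_epochs', 10],
...     ['num_of_images', 16],
...     ['weights_compression_ratio', 0.5],
...     ['save_model_path', './qmodel.onnx'],
... ]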
- quantize_and_export(float_model, representative_dataset, method='PTQ', framework='pytorch', use_internal_tpc=True, use_mixed_precision=False, param_items=None)¶
Main function to perform model quantization and export.
- Parameters:
  - float_model – The float model to be quantized.
  - representative_dataset (Callable, np.array, or tf.Tensor) – Representative dataset for calibration.
  - method (str) – Quantization method, 'PTQ' or 'GPTQ'. Default: 'PTQ'
  - framework (str) – 'tensorflow' or 'pytorch'. Default: 'pytorch'
  - use_internal_tpc (bool) – Whether to use the internal TPC. Default: True
  - use_mixed_precision (bool) – Whether to use mixed-precision quantization. Default: False
  - param_items (list) – List of parameter settings, [[key, value], …]. Default: None
- Returns:
  Tuple of (quantization success flag, quantized model).
- Return type:
  Tuple[bool, Any]
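Since representative_dataset may be passed as a Callable, a common pattern is a generator that yields calibration batches. A minimal sketch, assuming a NumPy array calib_images holds the calibration set and following MCT's usual list-of-arrays batch convention (batch size and shapes are illustrative):

>>> import numpy as np
>>> calib_images = np.random.rand(64, 224, 224, 3).astype(np.float32)  # stand-in calibration data
>>> def representative_dataset():
...     # Yield one batch per iteration; calibration loops over this generator.
...     for i in range(0, len(calib_images), 8):
...         yield [calib_images[i:i + 8]]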
Examples
Import MCT:
>>> import model_compression_toolkit as mct
Prepare the float model and representative dataset:
>>> float_model = ...
>>> representative_dataset = ...
Create an instance of MCTWrapper:
>>> wrapper = mct.MCTWrapper()
Set the method, framework, and other options:
>>> method = 'PTQ'
>>> framework = 'tensorflow'
>>> use_internal_tpc = True
>>> use_mixed_precision = False
Set additional parameters if needed:
>>> param_items = [[key, value]...]
Quantize and export the model:
>>> flag, quantized_model = wrapper.quantize_and_export(
...     float_model=float_model,
...     representative_dataset=representative_dataset,
...     method=method,
...     framework=framework,
...     use_internal_tpc=use_internal_tpc,
...     use_mixed_precision=use_mixed_precision,
...     param_items=param_items
... )
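The returned flag reports whether quantization succeeded; quantized_model is the in-memory quantized model, and the exported file lands at save_model_path ('./qmodel.keras' by default for the tensorflow branch). A minimal follow-up check:

>>> if flag:
...     print('Quantization succeeded; model exported to ./qmodel.keras')
... else:
...     print('Quantization failed')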