KD_Lib.Quantization.qat¶
KD_Lib.Quantization.qat.qat module¶
class KD_Lib.Quantization.qat.qat.QAT_Quantizer(model, train_loader, test_loader, optimizer, qconfig=QConfig(activation=functools.partial(<class 'torch.quantization.fake_quantize.FakeQuantize'>, observer=<class 'torch.quantization.observer.MovingAverageMinMaxObserver'>, quant_min=0, quant_max=255, reduce_range=True), weight=functools.partial(<class 'torch.quantization.fake_quantize.FakeQuantize'>, observer=<class 'torch.quantization.observer.MovingAveragePerChannelMinMaxObserver'>, quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_channel_symmetric, reduce_range=False, ch_axis=0)), criterion=CrossEntropyLoss(), device=device(type='cpu'))[source]¶

Bases: KD_Lib.Quantization.common.base_class.Quantizer
Implementation of Quantization-Aware Training (QAT) for PyTorch models.
Parameters:
- model (torch.nn.Module) – (Quantizable) model that needs to be quantized
- train_loader (torch.utils.data.DataLoader) – DataLoader used for training
- test_loader (torch.utils.data.DataLoader) – DataLoader used for testing
- optimizer (torch.optim.*) – Optimizer used for training
- qconfig (QConfig) – Configuration used for quantization
- criterion (Loss_fn) – Loss function used for training
- device (torch.device) – Device used for training (“cpu” or “cuda”)
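A minimal usage sketch follows. The QuantizableNet class, the toy data, and the optimizer settings below are illustrative assumptions, not part of the documented API; the only requirement the signature above implies is a quantizable model (i.e., one whose quantized regions are marked with QuantStub/DeQuantStub) plus train/test loaders and an optimizer.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from KD_Lib.Quantization.qat.qat import QAT_Quantizer

# Hypothetical quantizable model: QuantStub/DeQuantStub mark the region
# that will be fake-quantized during QAT and converted afterwards.
class QuantizableNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(784, 10)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        return self.dequant(x)

# Toy loaders for illustration only; substitute real datasets.
data = TensorDataset(torch.randn(64, 784), torch.randint(0, 10, (64,)))
train_loader = DataLoader(data, batch_size=16)
test_loader = DataLoader(data, batch_size=16)

model = QuantizableNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Uses the default qconfig, criterion, and device shown in the signature.
quantizer = QAT_Quantizer(model, train_loader, test_loader, optimizer)
```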
quantize(num_train_epochs=10, num_train_batches=10, param_freeze_epoch=3, bn_freeze_epoch=2)[source]¶

Function used for quantization.
Parameters:
- num_train_epochs (int) – Number of epochs used for training
- num_train_batches (int) – Number of batches used for training
- param_freeze_epoch (int) – Epoch after which quantizer parameters are frozen
- bn_freeze_epoch (int) – Epoch after which batch norm mean and variance statistics are frozen
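Continuing the sketch above, a call with the default schedule would look like the following. Assigning the return value to quantized_model is an assumption made for illustration; check the KD_Lib source for whether quantize() returns the converted model or stores it on the quantizer.

```python
# Run quantization-aware training: observer (quantizer) parameters are
# frozen after epoch 3 and batch-norm statistics after epoch 2, per the
# documented defaults.
quantized_model = quantizer.quantize(
    num_train_epochs=10,
    num_train_batches=10,
    param_freeze_epoch=3,
    bn_freeze_epoch=2,
)
```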