KD_Lib.Quantization.qat

KD_Lib.Quantization.qat.qat module

class KD_Lib.Quantization.qat.qat.QAT_Quantizer(model, train_loader, test_loader, optimizer, qconfig=QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, reduce_range=True), weight=FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver, quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_channel_symmetric, reduce_range=False, ch_axis=0)), criterion=CrossEntropyLoss(), device=device(type='cpu'))

Bases: KD_Lib.Quantization.common.base_class.Quantizer

Implementation of Quantization-Aware Training (QAT) for PyTorch models.

Parameters:
  • model (torch.nn.Module) – Quantizable model to be quantized
  • train_loader (torch.utils.data.DataLoader) – DataLoader used for training
  • test_loader (torch.utils.data.DataLoader) – DataLoader used for testing
  • optimizer (torch.optim.*) – Optimizer used for training
  • qconfig (torch.quantization.QConfig) – Quantization configuration specifying the fake-quantize settings for activations and weights
  • criterion (torch.nn.Module) – Loss function used for training
  • device (torch.device) – Device used for training ("cpu" or "cuda")
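
Example (a minimal sketch; TinyNet, the random tensors, and the hyperparameters below are illustrative placeholders rather than part of KD_Lib, and eager-mode QAT requires the model to mark its float/quantized boundaries with torch.quantization.QuantStub and DeQuantStub):

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    from KD_Lib.Quantization.qat.qat import QAT_Quantizer

    # Toy quantizable model: QuantStub/DeQuantStub mark where tensors enter
    # and leave the quantized representation.
    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = torch.quantization.QuantStub()
            self.fc = nn.Linear(16, 2)
            self.dequant = torch.quantization.DeQuantStub()

        def forward(self, x):
            return self.dequant(self.fc(self.quant(x)))

    model = TinyNet()

    # Random stand-in data; replace with real datasets.
    dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
    train_loader = DataLoader(dataset, batch_size=8)
    test_loader = DataLoader(dataset, batch_size=8)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # qconfig, criterion, and device fall back to the defaults shown above.
    quantizer = QAT_Quantizer(model, train_loader, test_loader, optimizer)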
quantize(num_train_epochs=10, num_train_batches=10, param_freeze_epoch=3, bn_freeze_epoch=2)

Performs quantization-aware training: the model is trained with fake quantization, the quantizer parameters and batch norm statistics are frozen after the specified epochs, and the trained model is then converted to a quantized model.

Parameters:
  • num_train_epochs (int) – Number of epochs used for training
  • num_train_batches (int) – Number of training batches used in each epoch
  • param_freeze_epoch (int) – Epoch after which the quantizer (observer) parameters are frozen
  • bn_freeze_epoch (int) – Epoch after which the batch norm running mean and variance statistics are frozen
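
Example (continuing the sketch above; that quantize() returns the converted model, rather than only storing it on the quantizer object, is an assumption here, so check the implementation if in doubt):

    # Train with fake quantization, then convert. With the default values,
    # batch norm statistics freeze after epoch 2 and the quantizer
    # parameters after epoch 3.
    quantized_model = quantizer.quantize(
        num_train_epochs=10,
        num_train_batches=10,
        param_freeze_epoch=3,
        bn_freeze_epoch=2,
    )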