KD_Lib.Quantization.static

KD_Lib.Quantization.static.static_quantization module

class KD_Lib.Quantization.static.static_quantization.Static_Quantizer(model, train_loader, test_loader, qconfig=QConfig(activation=functools.partial(torch.quantization.observer.MinMaxObserver, reduce_range=True), weight=functools.partial(torch.quantization.observer.MinMaxObserver, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)), criterion=CrossEntropyLoss(), device=device(type='cpu'))

Bases: KD_Lib.Quantization.common.base_class.Quantizer

Implementation of Static Quantization for PyTorch models.

Parameters:
  • model (torch.nn.Module) – Model that needs to be quantized
  • qconfig (torch.quantization.QConfig) – Configuration used for quantization (observers for activations and weights)
  • train_loader (torch.utils.data.DataLoader) – DataLoader used for training (calibration)
  • test_loader (torch.utils.data.DataLoader) – DataLoader used for testing
  • criterion (Loss_fn) – Loss function used for calibration
  • device (torch.device) – Device used for training (“cpu” or “cuda”)
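
The default qconfig in the signature above is easier to read when written out as code. The following is a minimal sketch that rebuilds it with plain PyTorch calls; it assumes a PyTorch version that still exposes the torch.quantization namespace, and the variable name default_like_qconfig is only illustrative, not part of KD_Lib.

```python
import functools

import torch
from torch.quantization import MinMaxObserver, QConfig

# Mirrors the default shown in the Static_Quantizer signature:
# - activations: min/max observer with a reduced quantized range
# - weights: symmetric per-tensor min/max observer producing qint8
default_like_qconfig = QConfig(
    activation=functools.partial(MinMaxObserver, reduce_range=True),
    weight=functools.partial(
        MinMaxObserver,
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric,
    ),
)

# Each field is a factory; calling it builds a fresh observer instance.
activation_observer = default_like_qconfig.activation()
weight_observer = default_like_qconfig.weight()
```

An object built this way could presumably be passed as the qconfig argument in place of the default, for example to swap in a different observer class.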

quantize(num_calibration_batches=10)

Quantizes the model, using batches from train_loader for calibration.

Parameters:
  • num_calibration_batches (int) – Number of batches used for calibration
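
To give a feel for what calibration over a fixed number of batches means, here is a pure-Python sketch of min/max calibration. It is not KD_Lib's or PyTorch's implementation: the helper names (calibrate_min_max, affine_qparams, quantize_value) and the toy data are hypothetical, and the 0..127 range is an assumption chosen to mimic an 8-bit unsigned observer with reduce_range=True.

```python
from itertools import islice


def calibrate_min_max(batches, num_calibration_batches=10):
    """Track the running min and max over the first N batches only."""
    lo, hi = float("inf"), float("-inf")
    for batch in islice(batches, num_calibration_batches):
        lo = min(lo, min(batch))
        hi = max(hi, max(batch))
    return lo, hi


def affine_qparams(lo, hi, qmin=0, qmax=127):
    """Derive scale and zero point for an affine integer mapping."""
    lo, hi = min(lo, 0.0), max(hi, 0.0)  # the range must contain zero
    scale = (hi - lo) / (qmax - qmin) or 1.0  # avoid a zero scale
    zero_point = qmin - round(lo / scale)
    return scale, max(qmin, min(qmax, zero_point))


def quantize_value(x, scale, zero_point, qmin=0, qmax=127):
    """Map a float to its clamped integer representation."""
    return max(qmin, min(qmax, round(x / scale) + zero_point))


# Toy "activations": the third batch is never seen during calibration.
batches = [[-1.0, 0.5], [0.25, 3.0], [100.0, -50.0]]
lo, hi = calibrate_min_max(batches, num_calibration_batches=2)
scale, zero_point = affine_qparams(lo, hi)
```

In this sketch only the first two batches shape the observed range, so the outliers in the third batch would be clamped at quantization time. That is the trade-off a calibration batch count controls: more batches give a more representative range at the cost of a longer calibration pass.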