Welcome to KD_Lib’s documentation!¶
KD-Lib¶
A PyTorch model compression library containing easy-to-use methods for knowledge distillation, pruning, and quantization.
Installation¶
Building from source (recommended)
If you intend to install the latest unreleased version of the library (i.e., from source), you can simply do:
$ git clone https://github.com/SforAiDl/KD_Lib.git
$ cd KD_Lib
$ python setup.py install
Stable release
KD_Lib is compatible with Python 3.6 or later and also depends on PyTorch. The easiest way to install KD_Lib is with pip, Python’s preferred package installer.
$ pip install KD-Lib
Note that KD_Lib is an active project and routinely publishes new releases. In order to upgrade KD_Lib to the latest version, use pip as follows.
$ pip install -U KD-Lib
Usage¶
To implement the most basic version of knowledge distillation from Distilling the Knowledge in a Neural Network and plot losses
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import VanillaKD
# This part is where you define your datasets, dataloaders, models and optimizers
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)
teacher_model = <your model>
student_model = <your model>
teacher_optimizer = optim.SGD(teacher_model.parameters(), 0.01)
student_optimizer = optim.SGD(student_model.parameters(), 0.01)
# Now, this is where KD_Lib comes into the picture
distiller = VanillaKD(teacher_model, student_model, train_loader, test_loader,
teacher_optimizer, student_optimizer)
distiller.train_teacher(epochs=5, plot_losses=True, save_model=True) # Train the teacher network
distiller.train_student(epochs=5, plot_losses=True, save_model=True) # Train the student network
distiller.evaluate(teacher=False) # Evaluate the student network
distiller.get_parameters() # A utility function to get the number of parameters in the teacher and the student network
To train a cohort of two models in an online fashion using the framework in Deep Mutual Learning and log training details to TensorBoard
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import DML
from KD_Lib.models import ResNet18, ResNet50 # To use models packaged in KD_Lib
# This part is where you define your datasets, dataloaders, models and optimizers
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)
student_params = [4, 4, 4, 4, 4]
student_model_1 = ResNet50(student_params, 1, 10)
student_model_2 = ResNet18(student_params, 1, 10)
student_cohort = [student_model_1, student_model_2]
student_optimizer_1 = optim.SGD(student_model_1.parameters(), 0.01)
student_optimizer_2 = optim.SGD(student_model_2.parameters(), 0.01)
student_optimizers = [student_optimizer_1, student_optimizer_2]
# Now, this is where KD_Lib comes into the picture
distiller = DML(student_cohort, train_loader, test_loader, student_optimizers, log=True, logdir="./Logs")
distiller.train_students(epochs=5)
distiller.evaluate()
distiller.get_parameters()
Implemented works¶
Some benchmark results can be found in the logs file.
Paper | Link | Repository (KD_Lib/)
---|---|---
Distilling the Knowledge in a Neural Network | https://arxiv.org/abs/1503.02531 | KD/vision/vanilla |
Improved Knowledge Distillation via Teacher Assistant | https://arxiv.org/abs/1902.03393 | KD/vision/TAKD |
Relational Knowledge Distillation | https://arxiv.org/abs/1904.05068 | KD/vision/RKD |
Distilling Knowledge from Noisy Teachers | https://arxiv.org/abs/1610.09650 | KD/vision/noisy |
Paying More Attention To The Attention | https://arxiv.org/abs/1612.03928 | KD/vision/attention |
Revisit Knowledge Distillation: a Teacher-free Framework | https://arxiv.org/abs/1909.11723 | KD/vision/teacher_free |
Mean Teachers are Better Role Models | https://arxiv.org/abs/1703.01780 | KD/vision/mean_teacher |
Knowledge Distillation via Route Constrained Optimization | https://arxiv.org/abs/1904.09149 | KD/vision/RCO |
Born Again Neural Networks | https://arxiv.org/abs/1805.04770 | KD/vision/BANN |
Preparing Lessons: Improve Knowledge Distillation with Better Supervision | https://arxiv.org/abs/1911.07471 | KD/vision/KA |
Improving Generalization Robustness with Noisy Collaboration in Knowledge Distillation | https://arxiv.org/abs/1910.05057 | KD/vision/noisy |
Distilling Task-Specific Knowledge from BERT into Simple Neural Networks | https://arxiv.org/abs/1903.12136 | KD/text/BERT2LSTM |
Deep Mutual Learning | https://arxiv.org/abs/1706.00384 | KD/vision/DML |
The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks | https://arxiv.org/abs/1803.03635 | Pruning/lottery_tickets
Regularizing Class-wise Predictions via Self-knowledge Distillation | https://arxiv.org/abs/2003.13964 | KD/vision/CSDK
Please cite our pre-print if you find KD_Lib useful in any way :)
@misc{shah2020kdlib,
title={KD-Lib: A PyTorch library for Knowledge Distillation, Pruning and Quantization},
author={Het Shah and Avishree Khare and Neelay Shah and Khizir Siddiqui},
year={2020},
eprint={2011.14691},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
Installation¶
From source (recommended)¶
If you intend to install the latest unreleased version of the library (i.e., from source), you can simply do:
$ git clone https://github.com/SforAiDl/KD_Lib.git
$ cd KD_Lib
$ python setup.py install
Stable release¶
KD_Lib is compatible with Python 3.6 or later and also depends on PyTorch. KD-Lib can be installed from PyPI via pip,
$ pip install KD-Lib
Note that KD_Lib is an active project and routinely publishes new releases. In order to upgrade KD_Lib to the latest version, use pip as follows.
$ pip install -U KD-Lib
Tutorials¶
VanillaKD using KD_Lib¶
To implement the most basic version of knowledge distillation from Distilling the Knowledge in a Neural Network and plot losses
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import VanillaKD
# Define datasets, dataloaders, models and optimizers
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)
teacher_model = <your model>
student_model = <your model>
teacher_optimizer = optim.SGD(teacher_model.parameters(), 0.01)
student_optimizer = optim.SGD(student_model.parameters(), 0.01)
# Now, this is where KD_Lib comes into the picture
distiller = VanillaKD(teacher_model, student_model, train_loader, test_loader,
teacher_optimizer, student_optimizer)
distiller.train_teacher(epochs=5, plot_losses=True, save_model=True) # Train the teacher network
distiller.train_student(epochs=5, plot_losses=True, save_model=True) # Train the student network
distiller.evaluate(teacher=False) # Evaluate the student network
distiller.get_parameters() # A utility function to get the number of parameters in the teacher and the student network
Deep Mutual Learning using KD_Lib¶
- Deep Mutual Learning is an online algorithm wherein an ensemble of students learns collaboratively, with the students teaching each other throughout the training process.
- Rather than performing a one-way transfer from a large, powerful, pre-trained teacher network, DML uses a pool of untrained students that learn simultaneously to solve the task together.
- Each student is trained with two losses: a conventional supervised learning loss, and a mimicry loss that aligns each student’s class posterior with the class probabilities of the other students (sketched after the figure below).
Snippet from the paper illustrating the DML algorithm -

To use DML with KD_Lib, create a list of student models (the student cohort) to be used for collective training, as well as a list of optimizers for them. The student models may have different architectures. Remember to match the order of the students with that of their optimizers in the list.
To use DML with 3 students on MNIST -
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import DML
# Define datasets, dataloaders, models and optimizers
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)
# Set device to be trained on
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Define a cohort of student models
student_model_1 = <your model>
student_model_2 = <your model>
student_model_3 = <your model>
student_cohort = (student_model_1, student_model_2, student_model_3)
# Make a list of optimizers for the models keeping in mind the order
student_optimizer_1 = optim.SGD(student_model_1.parameters(), 0.01)
student_optimizer_2 = optim.SGD(student_model_2.parameters(), 0.01)
student_optimizer_3 = optim.SGD(student_model_3.parameters(), 0.01)
optimizers = [student_optimizer_1, student_optimizer_2, student_optimizer_3]
# Train using KD_Lib
distiller = DML(student_cohort, train_loader, test_loader, optimizers,
device=device)
distiller.train_students(epochs=5, plot_losses=True, save_model=True) # Train the student cohort
distiller.evaluate() # Evaluate the student models
Label Smooth Regularization using KD_Lib¶
- Considering a sample x of class k with ground-truth label distribution l = δ(k), where δ(·) is the impulse signal, the LSR label is given as -

where K is the number of classes
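In its standard form (with smoothing parameter ε), the LSR label can be written as

\tilde{l}(i) = (1 - \epsilon)\,\delta(i, k) + \frac{\epsilon}{K}

so the correct class k receives probability 1 - ε + ε/K and every other class receives ε/K. In the snippet below, correct_prob=0.9 corresponds to the correct class being assigned probability 0.9, with the remaining mass spread over the other classes.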
To use the label smooth regularization with incorrect teacher predictions replaced with labels where the correct classes have a probability of 0.9 -
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import LabelSmoothReg
# Define datasets, dataloaders, models and optimizers
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)
# Set device to be trained on
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Define student and teacher models
teacher_model = <your model>
student_model = <your model>
# Define optimizers
teacher_optimizer = optim.SGD(teacher_model.parameters(), lr=0.01)
student_optimizer = optim.SGD(student_model.parameters(), lr=0.01)
# Train using KD_Lib
distiller = LabelSmoothReg(teacher_model, student_model, train_loader, test_loader, teacher_optimizer,
student_optimizer, correct_prob=0.9, device=device)
distiller.train_teacher(epochs=5) # Train the teacher model
distiller.train_students(epochs=5) # Train the student model
distiller.evaluate(teacher=True) # Evaluate the teacher model
distiller.evaluate() # Evaluate the student model
Probability Shift using KD_Lib¶
- Given an incorrect soft target, the probability shift algorithm simply swaps the value of the ground-truth class (the theoretical maximum) with the value of the predicted class (the predicted maximum), to ensure that the maximum confidence is reached at the ground-truth label

To use the probability shift algorithm to train a student on MNIST for 5 epochs -
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import ProbShift
# Define datasets, dataloaders, models and optimizers
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)
# Set device to be trained on
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Define student and teacher models
teacher_model = <your model>
student_model = <your model>
# Define optimizers
teacher_optimizer = optim.SGD(teacher_model.parameters(), lr=0.01)
student_optimizer = optim.SGD(student_model.parameters(), lr=0.01)
# Train using KD_Lib
distiller = ProbShift(teacher_model, student_model, train_loader, test_loader, teacher_optimizer,
student_optimizer, device=device)
distiller.train_teacher(epochs=5) # Train the teacher model
distiller.train_students(epochs=5) # Train the student model
distiller.evaluate(teacher=True) # Evaluate the teacher model
distiller.evaluate() # Evaluate the student model
Route Constrained Optimization using KD_Lib¶
- The route constrained optimization algorithm considers knowledge distillation from the perspective of curriculum learning by routing
- Instead of supervising the student model with a converged teacher model, it is supervised with some anchor points selected from the route in parameter space that the teacher model passed by
- This has been demonstrated to greatly reduce the lower bound of congruence loss for knowledge distillation, hint and mimicking learning

To use RCO with the student mimicking the teacher’s trajectory at an interval of 5 epochs -
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import RCO
# Define datasets, dataloaders, models and optimizers
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)
# Set device to be trained on
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Define student and teacher models
teacher_model = <your model>
student_model = <your model>
# Define optimizers
teacher_optimizer = optim.SGD(teacher_model.parameters(), lr=0.01)
student_optimizer = optim.SGD(student_model.parameters(), lr=0.01)
# Train using KD_Lib
distiller = RCO(teacher_model, student_model, train_loader, test_loader, teacher_optimizer,
student_optimizer, epoch_interval=5, device=device)
distiller.train_teacher(epochs=20) # Train the teacher model
distiller.train_students(epochs=20) # Train the student model
distiller.evaluate(teacher=True) # Evaluate the teacher model
distiller.evaluate() # Evaluate the student model
Self Training using KD_Lib¶
- The student model is first trained in the normal way to obtain a pre-trained model, which is then used as the teacher to train itself by transferring soft targets
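Conceptually, the second stage optimizes a loss of the following form. This is a sketch, not KD_Lib’s internal implementation; it assumes student_model has already been trained on the task, and the temp and distil_weight defaults are illustrative.
import copy
import torch
import torch.nn.functional as F

# Stage 2 only: a frozen copy of the already-trained student acts as its own teacher
teacher_copy = copy.deepcopy(student_model).eval()

def self_training_loss(student_logits, inputs, targets, temp=10.0, distil_weight=0.5):
    # Soft targets come from the frozen, pre-trained copy of the student
    with torch.no_grad():
        soft_targets = F.softmax(teacher_copy(inputs) / temp, dim=1)
    hard_loss = F.cross_entropy(student_logits, targets)
    soft_loss = F.kl_div(F.log_softmax(student_logits / temp, dim=1),
                         soft_targets, reduction="batchmean")
    return (1 - distil_weight) * hard_loss + distil_weight * soft_loss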
To use the self training algorithm to train a student on MNIST for 5 epochs -
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import SelfTraining
# Define datasets, dataloaders, models and optimizers
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)
# Set device to be trained on
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Define student model
student_model = <your model>
# Define optimizer
student_optimizer = optim.SGD(student_model.parameters(), lr=0.01)
# Train using KD_Lib
distiller = SelfTraining(student_model, train_loader, test_loader, student_optimizer,
device=device)
distiller.train_student(epochs=5) # Train the student model
distiller.evaluate() # Evaluate the student model
Hyperparameter Tuning using Optuna¶
Hyperparameter optimization is one of the crucial steps in training machine learning models. It is often a tedious process, with many parameters to optimize and long training times. Optuna is an automatic hyperparameter optimization software framework, particularly designed for machine learning. You can find more about Optuna here.
Optuna can be installed using pip -
$ pip install optuna
or using conda -
$ conda install -c conda-forge optuna
To search for the best hyperparameters for the VanillaKD algorithm -
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import VanillaKD
import optuna
import joblib
# Define datasets, dataloaders, models and optimizers
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"mnist_data",
train=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=32,
shuffle=True,
)
# Set device to be trained on
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Optuna requires defining an objective function
# The hyperparameters are then optimized for maximizing/minimizing this objective function
def tune_VanillaKD(trial):
    teacher_model = <your model>
    student_model = <your model>

    # Define hyperparameters and the ranges over which they should be trialled
    lr = trial.suggest_float("lr", 1e-4, 1e-1)
    momentum = trial.suggest_float("momentum", 0.9, 0.99)

    # Categorical choices should be simple values (e.g. strings); map them back to objects
    optimizer_name = trial.suggest_categorical("optimizer", ["SGD", "Adam"])
    if optimizer_name == "SGD":
        teacher_optimizer = optim.SGD(teacher_model.parameters(), lr, momentum=momentum)
        student_optimizer = optim.SGD(student_model.parameters(), lr, momentum=momentum)
    else:
        teacher_optimizer = optim.Adam(teacher_model.parameters(), lr)
        student_optimizer = optim.Adam(student_model.parameters(), lr)

    temperature = trial.suggest_float("temperature", 5.0, 20.0)
    distil_weight = trial.suggest_float("distil_weight", 0.0, 1.0)

    loss_name = trial.suggest_categorical("loss_fn", ["KLDivLoss", "MSELoss"])
    loss_fn = nn.KLDivLoss() if loss_name == "KLDivLoss" else nn.MSELoss()

    # Instantiate a distiller object using KD_Lib and train
    distiller = VanillaKD(teacher_model, student_model, train_loader, test_loader,
                          teacher_optimizer, student_optimizer, loss_fn,
                          temperature, distil_weight, device)
    distiller.train_teacher(epochs=10)
    distiller.train_student(epochs=10)
    test_accuracy = distiller.evaluate()

    # The objective function must return the quantity we're trying to maximize/minimize
    return test_accuracy
# Create a study
study = optuna.create_study(study_name="Hyperparameter Optimization",
direction="maximize")
study.optimize(tune_VanillaKD, n_trials=10)
# Access results
results = study.trials_dataframe()
results.head()
# Get best values of hyperparameter
for key, value in study.best_trial.__dict__.items():
    print("{} : {}".format(key, value))
# Write results of the study
joblib.dump(study, <your path>)
# Access results at a later time
study = joblib.load(<your path>)
results = study.trials_dataframe()
results.head()
Knowledge Distillation¶
Vision¶
KD_Lib.KD.vision.attention¶
KD_Lib.KD.vision.attention.attention module¶
KD_Lib.KD.vision.attention.loss_metric module¶
KD_Lib.KD.vision.noisy¶
KD_Lib.KD.vision.noisy.messy_collab module¶
KD_Lib.KD.vision.noisy.noisy_teacher module¶
KD_Lib.KD.vision.noisy.soft_random module¶
KD_Lib.KD.vision.noisy.utils module¶
Text¶
KD_Lib.KD.text.BERT2LSTM package¶
Submodules¶
KD_Lib.KD.text.BERT2LSTM.bert2lstm module¶
-
class
KD_Lib.KD.text.BERT2LSTM.bert2lstm.
BERT2LSTM
(student_model, distill_train_loader, distill_val_loader, optimizer_student, train_df, val_df, num_classes=2, seed=42, distil_weight=0.5, device='cpu', log=False, logdir='./Experiments', max_seq_length=128)[source]¶ Bases:
KD_Lib.KD.common.base_class.BaseClass
Implementation of Knowledge distillation from the paper “Distilling Task-Specific Knowledge from BERT into Simple Neural Networks” https://arxiv.org/pdf/1903.12136.pdf
Parameters: - student_model (torch.nn.Module) – Student model
- distill_train_loader (torch.utils.data.DataLoader) – Student training DataLoader for distillation
- distill_val_loader (torch.utils.data.DataLoader) – Student testing/validation DataLoader
- train_df (pandas.DataFrame) – DataFrame for training the teacher model
- val_df (pandas.DataFrame) – DataFrame for validating the teacher model
- loss_fn (torch.nn.Module) – Loss function
- temp (float) – Temperature parameter for distillation
- distil_weight (float) – Weight parameter for the distillation loss
- device (str) – Device used for training; ‘cpu’ for CPU and ‘cuda’ for GPU
- log (bool) – True if logging is required
- logdir (str) – Directory for storing logs
-
calculate_kd_loss
(y_pred_student, y_pred_teacher, y_true)[source]¶ Function used for calculating the KD loss during distillation
Parameters: - y_pred_student (torch.FloatTensor) – Prediction made by the student model
- y_pred_teacher (torch.FloatTensor) – Prediction made by the teacher model
- y_true (torch.FloatTensor) – Original label
-
evaluate_student
(verbose=True)[source]¶ Function used for evaluating student
Parameters: verbose (bool) – True if the accuracy needs to be printed, else False
-
evaluate_teacher
(val_batch_size=16, verbose=True)[source]¶ Function used for evaluating the teacher
Parameters: - val_batch_size (int) – Batch size parameter for generating dataloaders
- verbose (bool) – True if the accuracy needs to be printed, else False
-
train_student
(epochs=10, plot_losses=True, save_model=True, save_model_pth='./models/student.pth')[source]¶ Function that will be training the student
Parameters: - epochs (int) – Number of epochs you want to train the student
- plot_losses (bool) – True if you want to plot the losses
- save_model (bool) – True if you want to save the student model
- save_model_pth (str) – Path where you want to save the student model
-
train_teacher
(epochs=1, plot_losses=True, save_model=True, save_model_pth='./models/teacher.pt', train_batch_size=16, batch_print_freq=40, val_batch_size=16)[source]¶ Function that will be training the teacher
Parameters: - epochs (int) – Number of epochs you want to train the teacher
- plot_losses (bool) – True if you want to plot the losses
- save_model (bool) – True if you want to save the teacher model
- save_model_pth (str) – Path where you want to store the teacher model
- train_batch_size (int) – Batch size parameter for generating dataloaders
- batch_print_freq (int) – Frequency at which the batch number needs to be printed per epoch
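A minimal usage sketch based on the constructor and methods documented above; student_model, the distillation dataloaders, and the train_df/val_df DataFrames are assumed to be defined elsewhere, and their expected format is not shown here.
import torch.optim as optim
from KD_Lib.KD.text.BERT2LSTM.bert2lstm import BERT2LSTM

# student_model, distill_train_loader, distill_val_loader, train_df and val_df
# are assumed to be defined elsewhere
optimizer_student = optim.Adam(student_model.parameters(), lr=1e-3)
distiller = BERT2LSTM(student_model, distill_train_loader, distill_val_loader,
                      optimizer_student, train_df, val_df, num_classes=2)
distiller.train_teacher(epochs=1, train_batch_size=16)
distiller.train_student(epochs=10)
distiller.evaluate_student()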
KD_Lib.KD.text.BERT2LSTM.utils module¶
Module contents¶
Common¶
KD_Lib.KD.common.base_class module¶
-
class
KD_Lib.KD.common.base_class.
BaseClass
(teacher_model, student_model, train_loader, val_loader, optimizer_teacher, optimizer_student, loss_fn=KLDivLoss(), temp=20.0, distil_weight=0.5, device='cpu', log=False, logdir='./Experiments')[source]¶ Bases:
object
Basic implementation of a general Knowledge Distillation framework
Parameters: - teacher_model (torch.nn.Module) – Teacher model
- student_model (torch.nn.Module) – Student model
- train_loader (torch.utils.data.DataLoader) – Dataloader for training
- val_loader (torch.utils.data.DataLoader) – Dataloader for validation/testing
- optimizer_teacher (torch.optim.*) – Optimizer used for training the teacher
- optimizer_student (torch.optim.*) – Optimizer used for training the student
- loss_fn (torch.nn.Module) – Loss function used for distillation
- temp (float) – Temperature parameter for distillation
- distil_weight (float) – Weight parameter for the distillation loss
- device (str) – Device used for training; ‘cpu’ for CPU and ‘cuda’ for GPU
- log (bool) – True if logging is required
- logdir (str) – Directory for storing logs
-
calculate_kd_loss
(y_pred_student, y_pred_teacher, y_true)[source]¶ Custom loss function to calculate the KD loss for various implementations
Parameters: - y_pred_student (Tensor) – Predicted outputs from the student network
- y_pred_teacher (Tensor) – Predicted outputs from the teacher network
- y_true (Tensor) – True labels
-
evaluate
(teacher=False)[source]¶ Evaluate method for printing accuracies of the trained network
Parameters: teacher (bool) – True if you want the accuracy of the teacher network
-
post_epoch_call
(epoch)[source]¶ Any changes to be made after an epoch is completed.
Parameters: epoch (int) – Current epoch number. Returns: nothing (void)
-
train_student
(epochs=10, plot_losses=True, save_model=True, save_model_pth='./models/student.pt')[source]¶ Function that will be training the student
Parameters: - epochs (int) – Number of epochs you want to train the student
- plot_losses (bool) – True if you want to plot the losses
- save_model (bool) – True if you want to save the student model
- save_model_pth (str) – Path where you want to save the student model
-
train_teacher
(epochs=20, plot_losses=True, save_model=True, save_model_pth='./models/teacher.pt')[source]¶ Function that will be training the teacher
Parameters: - epochs (int) – Number of epochs you want to train the teacher
- plot_losses (bool) – True if you want to plot the losses
- save_model (bool) – True if you want to save the teacher model
- save_model_pth (str) – Path where you want to store the teacher model
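Since calculate_kd_loss is the documented extension point, a custom distiller can be sketched by subclassing BaseClass and overriding it. The sketch below assumes the constructor stores its temp and distil_weight arguments as self.temp and self.distil_weight.
import torch.nn.functional as F
from KD_Lib.KD.common.base_class import BaseClass

class MSEDistiller(BaseClass):
    # Toy distiller: supervised cross-entropy plus MSE between softened outputs.
    # Assumes temp and distil_weight are stored as attributes by BaseClass.
    def calculate_kd_loss(self, y_pred_student, y_pred_teacher, y_true):
        soft_student = F.softmax(y_pred_student / self.temp, dim=1)
        soft_teacher = F.softmax(y_pred_teacher / self.temp, dim=1)
        supervised = F.cross_entropy(y_pred_student, y_true)
        distillation = F.mse_loss(soft_student, soft_teacher)
        return (1 - self.distil_weight) * supervised + self.distil_weight * distillation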
Pruning¶
KD_Lib.Pruning.lottery_tickets¶
KD_Lib.Pruning.lottery_tickets.lottery_tickets module¶
-
class
KD_Lib.Pruning.lottery_tickets.lottery_tickets.
LotteryTicketsPruner
(model, train_loader, test_loader, loss_fn=CrossEntropyLoss(), device='cpu')[source]¶ Bases:
KD_Lib.Pruning.common.iterative_base_class.BaseIterativePruner
Implementation of Lottery Tickets Pruning for PyTorch models.
Parameters: - model (torch.nn.Module) – Model that needs to be pruned
- train_loader (torch.utils.data.DataLoader) – Dataloader for training
- test_loader (torch.utils.data.DataLoader) – Dataloader for validation/testing
- loss_fn (torch.nn.Module) – Loss function to be used for training
- device (torch.device) – Device used for implementation (“cpu” by default)
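A construction sketch using the documented parameters; the prune() call and its arguments are an assumption about the iterative pruner interface rather than a documented signature, so verify them against the API before use.
import torch.nn as nn
from KD_Lib.Pruning.lottery_tickets.lottery_tickets import LotteryTicketsPruner

# model, train_loader and test_loader are assumed to be defined as in the tutorials
pruner = LotteryTicketsPruner(model, train_loader, test_loader,
                              loss_fn=nn.CrossEntropyLoss(), device="cpu")
# The call below assumes an iterative prune() entry point; check the exact
# method name and arguments in the API reference
pruner.prune(num_iterations=5)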
Quantization¶
KD_Lib.Quantization.common¶
KD_Lib.Quantization.common.base_class module¶
-
class
KD_Lib.Quantization.common.base_class.
Quantizer
(model, qconfig, train_loader=None, test_loader=None, optimizer=None, criterion=None, device=device(type='cpu'))[source]¶ Bases:
object
Basic Implementation of Quantization for PyTorch models.
Parameters: - model (torch.nn.Module) – Model that needs to be quantized
- qconfig (Qconfig) – Configuration used for quantization
- train_loader (torch.utils.data.DataLoader) – DataLoader used for training
- test_loader (torch.utils.data.DataLoader) – DataLoader used for testing
- optimizer (torch.optim.*) – Optimizer for training
- criterion (Loss_fn) – Loss function used for calibration
- device (torch.device) – Device used for training (“cpu” or “cuda”)
KD_Lib.Quantization.dynamic¶
KD_Lib.Quantization.dynamic.dynamic_quantization module¶
-
class
KD_Lib.Quantization.dynamic.dynamic_quantization.
Dynamic_Quantizer
(model, test_loader, qconfig_spec=None)[source]¶ Bases:
KD_Lib.Quantization.common.base_class.Quantizer
Implementation of Dynamic Quantization for PyTorch models.
Parameters: - model (torch.nn.Module) – Model that needs to be quantized
- qconfig_spec (Qconfig_spec) – Qconfig spec
- test_loader (torch.utils.data.DataLoader) – DataLoader used for testing
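A minimal sketch using the documented constructor; restricting dynamic quantization to nn.Linear layers is a common choice, and quantize() is assumed to be provided by the Quantizer base class.
import torch
from KD_Lib.Quantization.dynamic.dynamic_quantization import Dynamic_Quantizer

# model and test_loader are assumed to be defined as in the tutorials
quantizer = Dynamic_Quantizer(model, test_loader, qconfig_spec={torch.nn.Linear})
# quantize() is assumed to be inherited from the Quantizer base class
quantized_model = quantizer.quantize()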
KD_Lib.Quantization.static¶
KD_Lib.Quantization.static.static_quantization module¶
-
class
KD_Lib.Quantization.static.static_quantization.
Static_Quantizer
(model, train_loader, test_loader, qconfig=QConfig(activation=functools.partial(<class 'torch.quantization.observer.MinMaxObserver'>, reduce_range=True), weight=functools.partial(<class 'torch.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)), criterion=CrossEntropyLoss(), device=device(type='cpu'))[source]¶ Bases:
KD_Lib.Quantization.common.base_class.Quantizer
Implementation of Static Quantization for PyTorch models.
Parameters: - model (torch.nn.Module) – Model that needs to be quantized
- qconfig (Qconfig) – Configuration used for quantization
- train_loader (torch.utils.data.DataLoader) – DataLoader used for training (calibration)
- test_loader (torch.utils.data.DataLoader) – DataLoader used for testing
- criterion (Loss_fn) – Loss function used for calibration
- device (torch.device) – Device used for training (“cpu” or “cuda”)
KD_Lib.Quantization.qat¶
KD_Lib.Quantization.qat.qat module¶
-
class
KD_Lib.Quantization.qat.qat.
QAT_Quantizer
(model, train_loader, test_loader, optimizer, qconfig=QConfig(activation=functools.partial(<class 'torch.quantization.fake_quantize.FakeQuantize'>, observer=<class 'torch.quantization.observer.MovingAverageMinMaxObserver'>, quant_min=0, quant_max=255, reduce_range=True), weight=functools.partial(<class 'torch.quantization.fake_quantize.FakeQuantize'>, observer=<class 'torch.quantization.observer.MovingAveragePerChannelMinMaxObserver'>, quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_channel_symmetric, reduce_range=False, ch_axis=0)), criterion=CrossEntropyLoss(), device=device(type='cpu'))[source]¶ Bases:
KD_Lib.Quantization.common.base_class.Quantizer
Implementation of Quantization-Aware Training (QAT) for PyTorch models.
Parameters: - model (torch.nn.Module) – (Quantizable) Model that needs to be quantized
- train_loader (torch.utils.data.DataLoader) – DataLoader used for training
- test_loader (torch.utils.data.DataLoader) – DataLoader used for testing
- optimizer (torch.optim.*) – Optimizer for training
- qconfig (Qconfig) – Configuration used for quantization
- criterion (Loss_fn) – Loss function used for training
- device (torch.device) – Device used for training (“cpu” or “cuda”)
-
quantize
(num_train_epochs=10, num_train_batches=10, param_freeze_epoch=3, bn_freeze_epoch=2)[source]¶ Function used for quantization
Parameters: - num_train_epochs (int) – Number of epochs used for training
- num_train_batches (int) – Number of batches used for training
- param_freeze_epoch (int) – Epoch after which the quantizer parameters need to be frozen
- bn_freeze_epoch (int) – Epoch after which batch norm mean and variance statistics are frozen
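A usage sketch based on the documented constructor and quantize() signature; model, train_loader and test_loader are assumed to be defined as in the tutorials, and the model is assumed to be quantizable (i.e. wrapped with QuantStub/DeQuantStub).
import torch
from KD_Lib.Quantization.qat.qat import QAT_Quantizer

# model, train_loader and test_loader are assumed to be defined as in the tutorials
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
quantizer = QAT_Quantizer(model, train_loader, test_loader, optimizer)
# Arguments follow the quantize() signature documented above
quantizer.quantize(num_train_epochs=10, num_train_batches=10,
                   param_freeze_epoch=3, bn_freeze_epoch=2)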
Models¶
KD_Lib.models.lenet module¶
-
class
KD_Lib.models.lenet.
LeNet
(img_size=32, num_classes=10, in_channels=3)[source]¶ Bases:
torch.nn.modules.module.Module
Implementation of a LeNet model
Parameters: - img_size (int) – Dimension of the input image
- num_classes (int) – Number of classes for classification
- in_channels (int) – Number of channels in input specimens
-
forward
(x)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
class
KD_Lib.models.lenet.
ModLeNet
(img_size=32, num_classes=10, in_channels=3)[source]¶ Bases:
torch.nn.modules.module.Module
Implementation of a ModLeNet model
Parameters: - img_size (int) – Dimension of the input image
- num_classes (int) – Number of classes for classification
- in_channels (int) – Number of channels in input specimens
-
forward
(x)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
KD_Lib.models.lstm module¶
-
class
KD_Lib.models.lstm.
LSTMNet
(input_dim=100, embed_dim=50, hidden_dim=32, num_classes=2, num_layers=5, dropout_prob=0, bidirectional=False, pad_idx=0)[source]¶ Bases:
torch.nn.modules.module.Module
Implementation of an LSTM model for classification
Parameters: - input_dim (int) – Size of the vocabulary
- embed_dim (int) – Embedding dimension (word vector size)
- hidden_dim (int) – Hidden dimension for the LSTM layers
- num_classes (int) – Number of classes for classification
- dropout_prob (float) – Dropout probability
- bidirectional (bool) – True if a bidirectional LSTM is needed
-
forward
(x, x_len)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
KD_Lib.models.nin module¶
-
class
KD_Lib.models.nin.
NetworkInNetwork
(num_classes=10, in_channels=3)[source]¶ Bases:
torch.nn.modules.module.Module
Implementation of a Network In Network model
Parameters: - num_classes (int) – Number of classes for classification
- in_channels (int) – Number of channels in input specimens
-
forward
(x)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
KD_Lib.models.resnet module¶
-
class
KD_Lib.models.resnet.
BasicBlock
(in_planes, planes, stride=1)[source]¶ Bases:
torch.nn.modules.module.Module
-
expansion
= 1¶
-
forward
(x)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
-
class
KD_Lib.models.resnet.
Bottleneck
(in_planes, planes, stride=1)[source]¶ Bases:
torch.nn.modules.module.Module
-
expansion
= 4¶
-
forward
(x)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
-
class
KD_Lib.models.resnet.
MeanResnet
(block, num_blocks, params, num_channel=3, num_classes=10)[source]¶ Bases:
KD_Lib.models.resnet.ResNet
-
forward
(x)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
-
class
KD_Lib.models.resnet.
ResNet
(block, num_blocks, params, num_channel=3, num_classes=10)[source]¶ Bases:
torch.nn.modules.module.Module
-
forward
(x, out_feature=False)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
-
KD_Lib.models.resnet.
ResNet101
(parameters, num_channel=3, num_classes=10, att=False, mean=False)[source]¶ Function that creates a ResNet 101 model
Parameters: - parameters (list or tuple) – List of parameters for the model
- num_channel (int) – Number of channels in input specimens
- num_classes (int) – Number of classes for classification
- att (bool) – True if attention needs to be used
- mean (bool) – True if a mean teacher model needs to be used
-
KD_Lib.models.resnet.
ResNet152
(parameters, num_channel=3, num_classes=10, att=False, mean=False)[source]¶ Function that creates a ResNet 152 model
Parameters: - parameters (list or tuple) – List of parameters for the model
- num_channel (int) – Number of channels in input specimens
- num_classes (int) – Number of classes for classification
- att (bool) – True if attention needs to be used
- mean (bool) – True if a mean teacher model needs to be used
-
KD_Lib.models.resnet.
ResNet18
(parameters, num_channel=3, num_classes=10, att=False, mean=False)[source]¶ Function that creates a ResNet 18 model
Parameters: - parameters (list or tuple) – List of parameters for the model
- num_channel (int) – Number of channels in input specimens
- num_classes (int) – Number of classes for classification
- att (bool) – True if attention needs to be used
- mean (bool) – True if a mean teacher model needs to be used
-
KD_Lib.models.resnet.
ResNet34
(parameters, num_channel=3, num_classes=10, att=False, mean=False)[source]¶ Function that creates a ResNet 34 model
Parameters: - parameters (list or tuple) – List of parameters for the model
- num_channel (int) – Number of channels in input specimens
- num_classes (int) – Number of classes for classification
- att (bool) – True if attention needs to be used
- mean (bool) – True if a mean teacher model needs to be used
-
KD_Lib.models.resnet.
ResNet50
(parameters, num_channel=3, num_classes=10, att=False, mean=False)[source]¶ Function that creates a ResNet 50 model
Parameters: - parameters (list or tuple) – List of parameters for the model
- num_channel (int) – Number of channels in input specimens
- num_classes (int) – Number of classes for classification
- att (bool) – True if attention needs to be used
- mean (bool) – True if a mean teacher model needs to be used
-
class
KD_Lib.models.resnet.
ResnetWithAT
(block, num_blocks, params, num_channel=3, num_classes=10)[source]¶ Bases:
KD_Lib.models.resnet.ResNet
-
forward
(x)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
KD_Lib.models.shallow module¶
-
class
KD_Lib.models.shallow.
Shallow
(img_size=28, hidden_size=800, num_classes=10, num_channels=1)[source]¶ Bases:
torch.nn.modules.module.Module
Implementation of a Shallow model
Parameters: - img_size (int) – Dimension of the input image
- hidden_size (int) – Hidden layer dimension
- num_classes (int) – Number of classes for classification
-
forward
(x)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.