# Common

## KD_Lib.KD.common.base_class module

class KD_Lib.KD.common.base_class.BaseClass(teacher_model, student_model, train_loader, val_loader, optimizer_teacher, optimizer_student, loss_fn=KLDivLoss(), temp=20.0, distil_weight=0.5, device='cpu', log=False, logdir='./Experiments')

Bases: object

Basic implementation of a general Knowledge Distillation framework.

Parameters:

- `teacher_model` (torch.nn.Module) – Teacher model
- `student_model` (torch.nn.Module) – Student model
- `train_loader` (torch.utils.data.DataLoader) – Dataloader for training
- `val_loader` (torch.utils.data.DataLoader) – Dataloader for validation/testing
- `optimizer_teacher` (torch.optim.*) – Optimizer used for training the teacher
- `optimizer_student` (torch.optim.*) – Optimizer used for training the student
- `loss_fn` (torch.nn.Module) – Loss function used for distillation
- `temp` (float) – Temperature parameter for distillation
- `distil_weight` (float) – Weight parameter for the distillation loss
- `device` (str) – Device used for training; 'cpu' for CPU and 'cuda' for GPU
- `log` (bool) – True if logging is required
- `logdir` (str) – Directory for storing logs
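The `temp` parameter controls how strongly the teacher's output distribution is softened before it is used as a target. The plain-Python sketch below (no PyTorch) illustrates the standard temperature-scaled softmax, softmax(z / T); it is not KD_Lib's code, and the logits are made-up values chosen for illustration.

```python
import math


def softmax_with_temperature(logits, temp=1.0):
    """Softmax of logits / temp; a higher temp gives a softer distribution."""
    scaled = [z / temp for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


teacher_logits = [8.0, 2.0, 0.5]  # hypothetical teacher outputs for one sample
print(softmax_with_temperature(teacher_logits, temp=1.0))   # nearly one-hot
print(softmax_with_temperature(teacher_logits, temp=20.0))  # much flatter
```

At the default `temp=20.0` the relative ordering of the classes is preserved, but the smaller probabilities become large enough to carry information about how the teacher ranks the wrong classes.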
calculate_kd_loss(y_pred_student, y_pred_teacher, y_true)

Custom loss function for calculating the KD loss in the various distillation implementations.

Parameters:

- `y_pred_student` (Tensor) – Predicted outputs from the student network
- `y_pred_teacher` (Tensor) – Predicted outputs from the teacher network
- `y_true` (Tensor) – True labels
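Since this method is where each implementation defines its own loss, here is one common formulation (Hinton-style distillation) written as a plain-Python sketch for a single sample. The way `temp` and `distil_weight` are combined below is an assumption about a typical implementation, not a description of what any particular KD_Lib class computes.

```python
import math


def softmax(logits, temp=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temp for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(y_pred_student, y_pred_teacher, y_true, temp=20.0, distil_weight=0.5):
    """Hinton-style KD loss for one sample (assumed formulation).

    y_pred_student, y_pred_teacher: lists of logits; y_true: class index.
    """
    p_student_soft = softmax(y_pred_student, temp)
    p_teacher_soft = softmax(y_pred_teacher, temp)
    # KL(teacher || student) between the softened distributions
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher_soft, p_student_soft))
    # Ordinary cross-entropy of the student (temperature 1) against the hard label
    ce = -math.log(softmax(y_pred_student)[y_true])
    # temp ** 2 keeps the soft-target gradients on a scale comparable to the hard loss
    return (1 - distil_weight) * ce + distil_weight * temp ** 2 * kl


student_logits = [2.0, 1.0, 0.1]
teacher_logits = [3.0, 1.0, 0.2]
print(kd_loss(student_logits, teacher_logits, y_true=0))
```

When the student's logits equal the teacher's, the KL term is zero and only the weighted cross-entropy remains.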
evaluate(teacher=False)

Evaluates the trained network and prints its accuracy.

Parameters:

- `teacher` (bool) – True if you want the accuracy of the teacher network
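As a rough illustration of the kind of number `evaluate` reports, the sketch below computes top-1 accuracy (argmax of each output compared with its label) in plain Python. Treating the reported accuracy as top-1 accuracy over the validation data is an assumption; the outputs and labels are made-up values.

```python
def accuracy(outputs, labels):
    """Fraction of samples whose highest-scoring class matches the label."""
    correct = 0
    for logits, label in zip(outputs, labels):
        predicted = max(range(len(logits)), key=lambda i: logits[i])  # argmax
        correct += int(predicted == label)
    return correct / len(labels)


outputs = [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.9, 0.3, 1.2], [0.1, 0.2, 3.0]]
labels = [0, 1, 0, 2]
print(accuracy(outputs, labels))  # 0.75 (the third sample is misclassified)
```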
get_parameters()

Gets the number of parameters of the teacher and student networks.
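In PyTorch the usual idiom for this count is `sum(p.numel() for p in model.parameters())`. The sketch below reproduces the same arithmetic in plain Python from a list of tensor shapes, using hypothetical MLP layer sizes (linear weights have shape `(out_features, in_features)`, biases `(out_features,)`) to show how much smaller a student typically is.

```python
from math import prod


def count_parameters(shapes):
    """Total number of scalar parameters, given the shape of each tensor."""
    return sum(prod(shape) for shape in shapes)


# Hypothetical networks: weight and bias shapes of each linear layer
teacher_shapes = [(1200, 784), (1200,), (1200, 1200), (1200,), (10, 1200), (10,)]
student_shapes = [(800, 784), (800,), (10, 800), (10,)]

print(count_parameters(teacher_shapes))  # 2395210
print(count_parameters(student_shapes))  # 636010
```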

post_epoch_call(epoch)

Hook for making any changes after an epoch is completed.

Parameters:

- `epoch` (int) – Current epoch number

Returns: None
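The sketch below shows the hook pattern this method suggests: a default no-op hook that the training loop calls after every epoch, and a subclass that overrides it. `Trainer` and `AnnealedTrainer` are hypothetical stand-ins written for illustration, not KD_Lib classes, and the temperature-halving schedule is just an example of a per-epoch change.

```python
class Trainer:
    """Hypothetical stand-in for a KD trainer (not KD_Lib's BaseClass)."""

    def __init__(self, temp=20.0):
        self.temp = temp

    def post_epoch_call(self, epoch):
        """Default hook: make no changes after an epoch."""

    def train(self, epochs):
        for epoch in range(epochs):
            # ... one epoch of training would run here ...
            self.post_epoch_call(epoch)


class AnnealedTrainer(Trainer):
    """Overrides the hook to halve the temperature every 2 epochs (never below 1)."""

    def post_epoch_call(self, epoch):
        if (epoch + 1) % 2 == 0:
            self.temp = max(1.0, self.temp / 2)


trainer = AnnealedTrainer(temp=20.0)
trainer.train(epochs=6)
print(trainer.temp)  # 2.5
```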

train_student(epochs=10, plot_losses=True, save_model=True, save_model_pth='./models/student.pt')

Trains the student network.

Parameters:

- `epochs` (int) – Number of epochs to train the student
- `plot_losses` (bool) – True if you want to plot the losses
- `save_model` (bool) – True if you want to save the student model
- `save_model_pth` (str) – Path where the student model is saved
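To give a rough picture of what student training involves, the plain-Python sketch below trains a linear softmax student by per-sample gradient descent on a Hinton-style KD loss, with the teacher's logits precomputed. The loss formulation, learning rate, data, and the name `train_student_sketch` are all assumptions for illustration; KD_Lib's actual loop runs on PyTorch models, dataloaders, and optimizers. The returned per-epoch average losses are the kind of curve a loss plot would show.

```python
import math


def softmax(logits, temp=1.0):
    scaled = [z / temp for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def train_student_sketch(data, teacher_logits, n_classes, epochs=10, lr=0.1,
                         temp=4.0, distil_weight=0.5):
    """Distil fixed teacher logits into a linear softmax student.

    data: list of (features, label); teacher_logits: one logit list per sample.
    Returns the average loss of each epoch.
    """
    n_features = len(data[0][0])
    weights = [[0.0] * n_features for _ in range(n_classes)]
    bias = [0.0] * n_classes
    epoch_losses = []
    for _ in range(epochs):
        total = 0.0
        for (x, y), t_logits in zip(data, teacher_logits):
            z = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]
            p, p_soft, q_soft = softmax(z), softmax(z, temp), softmax(t_logits, temp)
            ce = -math.log(p[y])
            kl = sum(q * math.log(q / ps) for q, ps in zip(q_soft, p_soft))
            total += (1 - distil_weight) * ce + distil_weight * temp ** 2 * kl
            # Gradient w.r.t. the logits: (1 - a)(p - onehot(y)) + a * T * (p_soft - q_soft)
            grad = [(1 - distil_weight) * (p[k] - (1.0 if k == y else 0.0))
                    + distil_weight * temp * (p_soft[k] - q_soft[k])
                    for k in range(n_classes)]
            # Plain SGD step on the weights and biases of the linear student
            for k in range(n_classes):
                for j in range(n_features):
                    weights[k][j] -= lr * grad[k] * x[j]
                bias[k] -= lr * grad[k]
        epoch_losses.append(total / len(data))
    return epoch_losses


data = [([1.0, 0.0], 0), ([0.0, 1.0], 1), ([0.9, 0.1], 0), ([0.1, 0.9], 1)]
teacher_logits = [[3.0, -1.0], [-1.0, 3.0], [2.5, -0.5], [-0.5, 2.5]]
losses = train_student_sketch(data, teacher_logits, n_classes=2, epochs=20)
print(losses[0], losses[-1])
```

The `temp` factor in the gradient comes from differentiating T² · KL(q ‖ softmax(z / T)) with respect to z, which gives T · (softmax(z / T) − q).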
train_teacher(epochs=20, plot_losses=True, save_model=True, save_model_pth='./models/teacher.pt')

Trains the teacher network.

Parameters:

- `epochs` (int) – Number of epochs to train the teacher
- `plot_losses` (bool) – True if you want to plot the losses
- `save_model` (bool) – True if you want to save the teacher model
- `save_model_pth` (str) – Path where the teacher model is saved