Common

KD_Lib.KD.common.base_class module

class KD_Lib.KD.common.base_class.BaseClass(teacher_model, student_model, train_loader, val_loader, optimizer_teacher, optimizer_student, loss_fn=KLDivLoss(), temp=20.0, distil_weight=0.5, device='cpu', log=False, logdir='./Experiments')

Bases: object

Basic implementation of a general Knowledge Distillation framework

Parameters:
  • teacher_model (torch.nn.Module) – Teacher model
  • student_model (torch.nn.Module) – Student model
  • train_loader (torch.utils.data.DataLoader) – Dataloader for training
  • val_loader (torch.utils.data.DataLoader) – Dataloader for validation/testing
  • optimizer_teacher (torch.optim.*) – Optimizer used for training the teacher
  • optimizer_student (torch.optim.*) – Optimizer used for training the student
  • loss_fn (torch.nn.Module) – Loss function used for distillation
  • temp (float) – Temperature parameter for distillation
  • distil_weight (float) – Weight parameter for the distillation loss
  • device (str) – Device used for training; ‘cpu’ for CPU and ‘cuda’ for GPU
  • log (bool) – True if logging is required
  • logdir (str) – Directory for storing logs
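The temp parameter controls how much the output distributions are softened before they are compared. The plain-Python sketch below assumes the usual knowledge-distillation convention of dividing logits by the temperature before the softmax; the logits are made-up values and the function is illustrative, not KD_Lib code.

```python
import math

def softmax_with_temperature(logits, temp):
    """Softmax of logits / temp; a larger temp gives a softer (flatter) distribution."""
    scaled = [z / temp for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [5.0, 2.0, 0.5]                       # hypothetical teacher logits
sharp = softmax_with_temperature(logits, 1.0)  # ordinary softmax
soft = softmax_with_temperature(logits, 20.0)  # 20.0 matches BaseClass's default temp
print([round(p, 3) for p in sharp])
print([round(p, 3) for p in soft])
```

At temp=20.0 the class ranking is unchanged, but the probabilities of the non-maximal classes grow, which is the extra information the student is meant to learn from.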
calculate_kd_loss(y_pred_student, y_pred_teacher, y_true)

Custom loss function that calculates the KD loss; each distillation implementation defines its own version.

Parameters:
  • y_pred_student (Tensor) – Predicted outputs from the student network
  • y_pred_teacher (Tensor) – Predicted outputs from the teacher network
  • y_true (Tensor) – True labels
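Because BaseClass leaves the loss to each implementation, the exact formula depends on the subclass. One common formulation, assumed here rather than taken from KD_Lib, mixes a hard-label cross-entropy term with a temperature-softened KL divergence, weighted by distil_weight and scaled by temp squared. The sketch handles a single example with plain Python lists, whereas the real method receives batched Tensors.

```python
import math

def softmax(logits, temp=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temp for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(y_pred_student, y_pred_teacher, y_true, temp=20.0, distil_weight=0.5):
    """Assumed KD loss for one example: weighted hard-label CE plus soft-target KL."""
    # Hard term: cross-entropy between the student's ordinary softmax and the true label.
    p_student = softmax(y_pred_student)
    hard = -math.log(p_student[y_true])
    # Soft term: KL(teacher || student) between the temperature-softened distributions.
    q_teacher = softmax(y_pred_teacher, temp)
    q_student = softmax(y_pred_student, temp)
    soft = sum(t * math.log(t / s) for t, s in zip(q_teacher, q_student) if t > 0)
    # temp ** 2 keeps the soft term's gradients comparable in size across temperatures.
    return (1 - distil_weight) * hard + distil_weight * temp ** 2 * soft

student_logits = [2.0, 0.5, -1.0]  # hypothetical values
teacher_logits = [3.0, 0.0, -2.0]  # hypothetical values
print(kd_loss(student_logits, teacher_logits, y_true=0))
```

With distil_weight=0.0 this reduces to plain cross-entropy on the labels; when the student's logits equal the teacher's, the soft term is zero.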
evaluate(teacher=False)

Evaluate the trained network and print its accuracy.

Parameters:
  • teacher (bool) – True to report the accuracy of the teacher network
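Assuming the reported metric is ordinary top-1 accuracy (the documentation only says "accuracy"), it is the fraction of examples whose highest-scoring output matches the label. The plain-Python sketch below uses made-up predictions and a hypothetical accuracy helper.

```python
def accuracy(logits_batch, labels):
    """Fraction of examples whose highest-scoring class equals the true label."""
    correct = 0
    for logits, label in zip(logits_batch, labels):
        predicted = max(range(len(logits)), key=lambda i: logits[i])  # argmax
        correct += int(predicted == label)
    return correct / len(labels)

preds = [[0.1, 2.0, -1.0], [3.0, 0.2, 0.1], [0.0, 0.5, 1.5], [1.0, 0.9, 0.0]]
labels = [1, 0, 2, 1]
print(f"Accuracy: {accuracy(preds, labels):.2%}")  # prints "Accuracy: 75.00%"
```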
get_parameters()

Get the number of parameters of the teacher and student networks.
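In PyTorch a parameter count is usually obtained with sum(p.numel() for p in model.parameters()); whether get_parameters does exactly this is an assumption. To show what such numbers amount to, the plain-Python sketch below counts the parameters of fully connected layers by hand (weights plus biases) for a hypothetical teacher MLP and a smaller hypothetical student MLP.

```python
def linear_params(n_in, n_out, bias=True):
    """Parameters of a fully connected layer: an n_in x n_out weight matrix plus n_out biases."""
    return n_in * n_out + (n_out if bias else 0)

def mlp_params(layer_sizes):
    """Total parameters of an MLP given its layer widths, e.g. [784, 1200, 1200, 10]."""
    return sum(linear_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:]))

teacher = mlp_params([784, 1200, 1200, 10])  # hypothetical teacher architecture
student = mlp_params([784, 800, 800, 10])    # hypothetical student architecture
print(teacher, student)  # prints "2395210 1276810"
```

Comparing the two counts is the usual way to report how much the student compresses the teacher.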

post_epoch_call(epoch)

Hook for any changes to be made after an epoch is completed.

Parameters:
  • epoch (int) – Current epoch number
Returns: None
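A hook like this is typically used by subclassing and overriding it. The sketch below uses hypothetical Trainer and AnnealedTrainer classes (not KD_Lib) and assumes, from the description above, that the training loop calls the hook once after every epoch; lowering the temperature each epoch is only an illustrative choice of per-epoch change.

```python
class Trainer:
    """Hypothetical stand-in for a training loop that exposes a post-epoch hook."""

    def __init__(self, temp=20.0):
        self.temp = temp

    def post_epoch_call(self, epoch):
        """Default hook: does nothing."""
        pass

    def train(self, epochs):
        for epoch in range(epochs):
            # ... one pass over the training data would go here ...
            self.post_epoch_call(epoch)


class AnnealedTrainer(Trainer):
    """Hypothetical subclass that lowers the temperature after every epoch."""

    def __init__(self, temp=20.0, decay=0.9):
        super().__init__(temp)
        self.decay = decay
        self.history = []  # (epoch, temperature after the update)

    def post_epoch_call(self, epoch):
        self.temp *= self.decay
        self.history.append((epoch, self.temp))


trainer = AnnealedTrainer()
trainer.train(epochs=3)
print(trainer.history)
```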

train_student(epochs=10, plot_losses=True, save_model=True, save_model_pth='./models/student.pt')

Train the student network.

Parameters:
  • epochs (int) – Number of epochs to train the student
  • plot_losses (bool) – True to plot the losses
  • save_model (bool) – True to save the student model
  • save_model_pth (str) – Path where the student model is saved
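To illustrate what training the student against a distillation objective does, the sketch below assumes the common weighted hard/soft loss (not necessarily what a given KD_Lib subclass uses) and replaces the student network with a bare vector of logits updated by plain gradient descent. The actual method trains a network with the optimizer_student passed to the constructor; every name and value here is hypothetical. The gradients used are d(CE)/dz = p - y and d(temp² · KL)/dz = temp · (q_student - q_teacher).

```python
import math

def softmax(logits, temp=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temp for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss_and_grad(z_student, z_teacher, label, temp, distil_weight):
    """Assumed combined KD loss for one example and its gradient w.r.t. the student logits."""
    p = softmax(z_student)          # student, temp = 1
    q_s = softmax(z_student, temp)  # student, softened
    q_t = softmax(z_teacher, temp)  # teacher, softened
    hard = -math.log(p[label])
    soft = sum(t * math.log(t / s) for t, s in zip(q_t, q_s))
    loss = (1 - distil_weight) * hard + distil_weight * temp ** 2 * soft
    grad = []
    for i in range(len(z_student)):
        one_hot = 1.0 if i == label else 0.0
        # d(hard)/dz_i = p_i - y_i ; d(temp^2 * soft)/dz_i = temp * (q_s_i - q_t_i)
        grad.append((1 - distil_weight) * (p[i] - one_hot)
                    + distil_weight * temp * (q_s[i] - q_t[i]))
    return loss, grad

teacher_logits = [4.0, 1.0, 0.0]  # hypothetical output of an already-trained teacher
student_logits = [0.0, 0.0, 0.0]  # toy "student": free logits instead of a network
label, temp, distil_weight, lr = 0, 20.0, 0.5, 0.5

initial_loss, _ = kd_loss_and_grad(student_logits, teacher_logits, label, temp, distil_weight)
for _ in range(200):
    _, grad = kd_loss_and_grad(student_logits, teacher_logits, label, temp, distil_weight)
    student_logits = [z - lr * g for z, g in zip(student_logits, grad)]
final_loss, _ = kd_loss_and_grad(student_logits, teacher_logits, label, temp, distil_weight)
print(f"loss {initial_loss:.4f} -> {final_loss:.4f}")
```

The order matters: the teacher's outputs are treated as fixed targets, which is why the teacher is trained first.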
train_teacher(epochs=20, plot_losses=True, save_model=True, save_model_pth='./models/teacher.pt')

Train the teacher network.

Parameters:
  • epochs (int) – Number of epochs to train the teacher
  • plot_losses (bool) – True to plot the losses
  • save_model (bool) – True to save the teacher model
  • save_model_pth (str) – Path where the teacher model is saved
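Both training methods default to saving under ./models/. This page does not say whether that directory is created automatically, so a cautious caller can create the parent directory of save_model_pth beforehand. The standard-library sketch below does this with a hypothetical helper, ensure_parent_dir.

```python
import os

def ensure_parent_dir(path):
    """Create the directory that will hold `path` (no-op if it already exists)."""
    parent = os.path.dirname(path)
    if parent:  # a bare filename such as "teacher.pt" has no parent to create
        os.makedirs(parent, exist_ok=True)
    return path

teacher_pth = ensure_parent_dir("./models/teacher.pt")
student_pth = ensure_parent_dir("./models/student.pt")
# These paths could then be passed as save_model_pth to train_teacher / train_student.
```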