Common
KD_Lib.KD.common.base_class module
class KD_Lib.KD.common.base_class.BaseClass(teacher_model, student_model, train_loader, val_loader, optimizer_teacher, optimizer_student, loss_fn=KLDivLoss(), temp=20.0, distil_weight=0.5, device='cpu', log=False, logdir='./Experiments')
Bases: object
Basic implementation of a general Knowledge Distillation framework
Parameters: - teacher_model (torch.nn.Module) – Teacher model
- student_model (torch.nn.Module) – Student model
- train_loader (torch.utils.data.DataLoader) – Dataloader for training
- val_loader (torch.utils.data.DataLoader) – Dataloader for validation/testing
- optimizer_teacher (torch.optim.*) – Optimizer used for training the teacher
- optimizer_student (torch.optim.*) – Optimizer used for training the student
- loss_fn (torch.nn.Module) – Loss function used for distillation
- temp (float) – Temperature parameter for distillation
- distil_weight (float) – Weight parameter for the distillation loss
- device (str) – Device used for training; 'cpu' for CPU and 'cuda' for GPU
- log (bool) – True if logging is required
- logdir (str) – Directory for storing logs
calculate_kd_loss(y_pred_student, y_pred_teacher, y_true)
Custom loss function to calculate the KD loss for various implementations
Parameters: - y_pred_student (Tensor) – Predicted outputs from the student network
- y_pred_teacher (Tensor) – Predicted outputs from the teacher network
- y_true (Tensor) – True labels
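As an illustration of how temp and distil_weight can enter such a loss, here is a minimal pure-Python sketch of one common formulation: a temperature-softened KL term combined with a hard-label cross-entropy term. This is an assumption made for illustration, not necessarily what this class or its subclasses compute. The helper names softmax and kd_loss are hypothetical, and the sketch works on plain lists of logits rather than Tensors.

```python
import math


def softmax(logits, temp=1.0):
    """Softmax of logits / temp, shifted by the max for numerical stability."""
    scaled = [z / temp for z in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits, teacher_logits, true_label, temp=20.0, distil_weight=0.5):
    """Hypothetical KD loss for a single example (assumed formulation)."""
    # Soften both distributions with the same temperature.
    p_teacher = softmax(teacher_logits, temp)
    p_student = softmax(student_logits, temp)

    # KL(teacher || student) between the softened distributions.
    kl = sum(
        pt * math.log(pt / ps)
        for pt, ps in zip(p_teacher, p_student)
        if pt > 0.0
    )

    # Ordinary cross-entropy against the true label at temperature 1.
    ce = -math.log(softmax(student_logits)[true_label])

    # The temp**2 factor is a common convention to keep the soft-target
    # term on a scale comparable to the hard-label term.
    return distil_weight * temp * temp * kl + (1.0 - distil_weight) * ce
```

When the student and teacher logits are identical the KL term vanishes, so only the weighted cross-entropy term remains.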
evaluate(teacher=False)
Evaluate method for printing accuracies of the trained network
Parameters: teacher (bool) – True if you want accuracy of the teacher network
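The accuracy reported by an evaluation pass is typically top-1 accuracy: the fraction of examples whose highest-scoring class matches the label. The sketch below shows that computation on plain Python lists. The function name top1_accuracy and the (scores, labels) batch layout are assumptions for illustration and are not part of this class.

```python
def top1_accuracy(batches):
    """Fraction of examples whose arg-max score equals the true label.

    `batches` is an iterable of (scores, labels) pairs, where `scores`
    is a list of per-class score lists and `labels` a list of ints.
    """
    correct = 0
    total = 0
    for scores, labels in batches:
        for row, label in zip(scores, labels):
            # Index of the largest score is the predicted class.
            predicted = max(range(len(row)), key=row.__getitem__)
            correct += int(predicted == label)
            total += 1
    # Avoid dividing by zero on an empty loader.
    return correct / total if total else 0.0
```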
post_epoch_call(epoch)
Any changes to be made after an epoch is completed.
Parameters: epoch (int) – Current epoch number
Returns: None
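A per-epoch hook of this kind is usually overridden in a subclass. The following sketch shows the pattern with a stand-in trainer. The classes Trainer and DecayingTrainer, the lr attribute, and the train loop are hypothetical and only mimic the shape of the hook; they are not KD_Lib code.

```python
class Trainer:
    """Stand-in trainer with an overridable per-epoch hook."""

    def __init__(self, lr=1.0):
        self.lr = lr
        self.history = []

    def post_epoch_call(self, epoch):
        """Default hook: change nothing after the epoch."""

    def train(self, epochs):
        for epoch in range(epochs):
            # Record the learning rate used during this epoch.
            self.history.append((epoch, self.lr))
            # Give subclasses a chance to adjust state between epochs.
            self.post_epoch_call(epoch)


class DecayingTrainer(Trainer):
    def post_epoch_call(self, epoch):
        # Halve the learning rate after every completed epoch.
        self.lr *= 0.5
```

Keeping the adjustment in the hook leaves the training loop itself unchanged across variants.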
train_student(epochs=10, plot_losses=True, save_model=True, save_model_pth='./models/student.pt')
Function that trains the student
Parameters: - epochs (int) – Number of epochs you want to train the student
- plot_losses (bool) – True if you want to plot the losses
- save_model (bool) – True if you want to save the student model
- save_model_pth (str) – Path where you want to save the student model
train_teacher(epochs=20, plot_losses=True, save_model=True, save_model_pth='./models/teacher.pt')
Function that trains the teacher
Parameters: - epochs (int) – Number of epochs you want to train the teacher
- plot_losses (bool) – True if you want to plot the losses
- save_model (bool) – True if you want to save the teacher model
- save_model_pth (str) – Path where you want to store the teacher model