KD_Lib.KD.text.BERT2LSTM package

Submodules

KD_Lib.KD.text.BERT2LSTM.bert2lstm module

class KD_Lib.KD.text.BERT2LSTM.bert2lstm.BERT2LSTM(student_model, distill_train_loader, distill_val_loader, optimizer_student, train_df, val_df, num_classes=2, seed=42, distil_weight=0.5, device='cpu', log=False, logdir='./Experiments', max_seq_length=128)

Bases: KD_Lib.KD.common.base_class.BaseClass

Implementation of Knowledge distillation from the paper “Distilling Task-Specific Knowledge from BERT into Simple Neural Networks” https://arxiv.org/pdf/1903.12136.pdf

Parameters:
  • (torch.nn.Module) (student_model) – Student model
  • (torch.utils.data.DataLoader) (distill_train_loader) – Student training dataloader for distillation
  • (torch.utils.data.DataLoader) (distill_val_loader) – Student testing/validation dataloader
  • (pandas.DataFrame) (train_df) – Dataframe for training the teacher model
  • (pandas.DataFrame) (val_df) – Dataframe for validating the teacher model
  • (torch.nn.Module) (loss_fn) – Loss function
  • (float) – Temperature parameter for distillation
  • (float) (distil_weight) – Weight parameter for distillation loss
  • (str) (device) – Device used for training; ‘cpu’ for CPU and ‘cuda’ for GPU
  • (bool) (log) – True if logging is required
  • (str) (logdir) – Directory for storing logs
calculate_kd_loss(y_pred_student, y_pred_teacher, y_true)

Function used for calculating the KD loss during distillation

Parameters:
  • (torch.FloatTensor) (y_pred_student) – Prediction made by the student model
  • (torch.FloatTensor) (y_pred_teacher) – Prediction made by the teacher model
  • (torch.FloatTensor) (y_true) – Original label
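
The referenced paper combines a supervised cross-entropy term with a mean-squared-error term between student and teacher logits. The following is a minimal, framework-free sketch of that kind of objective for a single example, so the arithmetic is visible. It is not the library's implementation: the helper names (softmax_cross_entropy, mse, kd_loss) are hypothetical, plain Python lists stand in for tensors, and the assumption that distil_weight scales the teacher-matching term (with 1 - distil_weight on the supervised term) has not been checked against the source code.

```python
import math


def softmax_cross_entropy(logits, label):
    """Cross-entropy of one example, computed directly from raw logits."""
    m = max(logits)  # subtract the max before exponentiating, for stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def mse(a, b):
    """Mean squared error between two equally long logit vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def kd_loss(y_pred_student, y_pred_teacher, y_true, distil_weight=0.5):
    """Weighted sum of a supervised term and a teacher-matching term."""
    # Assumption: distil_weight weights the teacher-matching (MSE) term and
    # (1 - distil_weight) weights the supervised (cross-entropy) term.
    hard = softmax_cross_entropy(y_pred_student, y_true)
    soft = mse(y_pred_student, y_pred_teacher)
    return (1.0 - distil_weight) * hard + distil_weight * soft
```

Under this assumption, distil_weight=0 reduces to ordinary supervised training, while distil_weight=1 ignores the labels and only matches the teacher's logits.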
evaluate_student(verbose=True)

Function used for evaluating student

Parameters:
  • (bool) (verbose) – True if the accuracy should be printed, else False
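
As a rough illustration of the accuracy that verbose controls, the sketch below computes classification accuracy from per-class scores. The function name accuracy and the list-based inputs are hypothetical; the library works on tensors and dataloaders, and the exact format of what it prints is not documented here.

```python
def accuracy(logits_batch, labels, verbose=True):
    """Fraction of examples whose highest-scoring class equals the label."""
    correct = 0
    for logits, label in zip(logits_batch, labels):
        # Index of the largest score is taken as the predicted class.
        predicted = max(range(len(logits)), key=lambda i: logits[i])
        correct += int(predicted == label)
    acc = correct / len(labels)
    if verbose:
        print(f"Accuracy: {acc:.4f}")
    return acc
```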
evaluate_teacher(val_batch_size=16, verbose=True)

Function used for evaluating teacher

Parameters:
  • (int) (max_seq_length) – Maximum sequence length parameter for generating dataloaders
  • (int) (val_batch_size) – Batch size parameter for generating dataloaders
  • (bool) (verbose) – True if the accuracy should be printed, else False
set_seed(seed)
train_student(epochs=10, plot_losses=True, save_model=True, save_model_pth='./models/student.pth')

Function that will be training the student

Parameters:
  • (int) (epochs) – Number of epochs you want to train the student
  • (bool) (plot_losses) – True if you want to plot the losses
  • (bool) (save_model) – True if you want to save the student model
  • (str) (save_model_pth) – Path where you want to save the student model
train_teacher(epochs=1, plot_losses=True, save_model=True, save_model_pth='./models/teacher.pt', train_batch_size=16, batch_print_freq=40, val_batch_size=16)

Function that will be training the teacher

Parameters:
  • (int) (epochs) – Number of epochs you want to train the teacher
  • (bool) (plot_losses) – True if you want to plot the losses
  • (bool) (save_model) – True if you want to save the teacher model
  • (str) (save_model_pth) – Path where you want to store the teacher model
  • (int) (train_batch_size) – Batch size parameter for generating dataloaders
  • (int) (batch_print_freq) – Frequency at which batch number needs to be printed per epoch
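
The methods documented above suggest a typical order of calls: train and evaluate the teacher, then distill into the student and evaluate it. The sketch below wraps that order in a hypothetical helper, run_pipeline, which accepts any object exposing these four methods. The keyword arguments are taken from the signatures on this page; the ordering itself is an assumption about intended usage rather than something the page states.

```python
def run_pipeline(distiller, teacher_epochs=1, student_epochs=10):
    """Run teacher training, then student distillation, on a distiller object.

    `distiller` is assumed to be an already constructed BERT2LSTM instance
    (or anything with the same four methods).
    """
    # Train the teacher first so the student has predictions to learn from.
    distiller.train_teacher(
        epochs=teacher_epochs, plot_losses=False, save_model=False
    )
    distiller.evaluate_teacher(verbose=True)

    # Distill the trained teacher into the student, then check the student.
    distiller.train_student(
        epochs=student_epochs, plot_losses=False, save_model=False
    )
    distiller.evaluate_student(verbose=True)
```

Passing plot_losses=False and save_model=False keeps the run free of side effects (no plots and no files written under ./models), which is convenient for a quick smoke test.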

KD_Lib.KD.text.BERT2LSTM.utils module

Module contents