KD_Lib.KD.text.BERT2LSTM package
Submodules
KD_Lib.KD.text.BERT2LSTM.bert2lstm module
-
class KD_Lib.KD.text.BERT2LSTM.bert2lstm.BERT2LSTM(student_model, distill_train_loader, distill_val_loader, optimizer_student, train_df, val_df, num_classes=2, seed=42, distil_weight=0.5, device='cpu', log=False, logdir='./Experiments', max_seq_length=128)
Bases: KD_Lib.KD.common.base_class.BaseClass
Implementation of knowledge distillation from the paper “Distilling Task-Specific Knowledge from BERT into Simple Neural Networks” (https://arxiv.org/pdf/1903.12136.pdf).
Parameters:
- student_model (torch.nn.Module) – Student model
- distill_train_loader (torch.utils.data.DataLoader) – Student training dataloader for distillation
- distill_val_loader (torch.utils.data.DataLoader) – Student testing/validation dataloader
- train_df (pandas.DataFrame) – Dataframe for training the teacher model
- val_df (pandas.DataFrame) – Dataframe for validating the teacher model
- loss_fn (torch.nn.Module) – Loss function (documented here, but absent from the signature above)
- (float) – Temperature parameter for distillation (documented without a name; absent from the signature above)
- distil_weight (float) – Weight parameter for the distillation loss
- device (str) – Device used for training: 'cpu' for CPU, 'cuda' for GPU
- log (bool) – True if logging is required
- logdir (str) – Directory for storing logs
-
calculate_kd_loss(y_pred_student, y_pred_teacher, y_true)
Function used for calculating the KD loss during distillation.
Parameters:
- y_pred_student (torch.FloatTensor) – Prediction made by the student model
- y_pred_teacher (torch.FloatTensor) – Prediction made by the teacher model
- y_true (torch.FloatTensor) – Original label
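The cited paper combines a cross-entropy term on the true labels with a mean-squared-error term between student and teacher logits. The following is a minimal, framework-free sketch of that kind of objective for a single example, not the library's code. The helper names (`cross_entropy`, `mse`, `kd_loss`) are hypothetical, and the assumption that `distil_weight` scales the teacher-matching term while `1 - distil_weight` scales the cross-entropy term may not match how `calculate_kd_loss` weights them.

```python
import math


def cross_entropy(logits, label):
    """Cross-entropy of one example, computed from raw logits."""
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def mse(a, b):
    """Mean squared error between two equal-length logit vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def kd_loss(student_logits, teacher_logits, label, distil_weight=0.5):
    """Weighted sum of a logit-matching term and a hard-label term.

    Assumption: distil_weight multiplies the student/teacher MSE term and
    (1 - distil_weight) multiplies the cross-entropy term.
    """
    soft = mse(student_logits, teacher_logits)
    hard = cross_entropy(student_logits, label)
    return distil_weight * soft + (1.0 - distil_weight) * hard


# Example: a two-class problem where the student is less confident than the teacher.
loss = kd_loss([1.0, -1.0], [3.0, -1.0], label=0, distil_weight=0.5)
print(f"kd loss: {loss:.4f}")
```

With `distil_weight=1.0` the sketch reduces to pure logit matching, and with `distil_weight=0.0` it reduces to ordinary supervised cross-entropy.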
-
evaluate_student(verbose=True)
Function used for evaluating the student.
Parameters:
- verbose (bool) – True if the accuracy needs to be printed, else False
-
evaluate_teacher(val_batch_size=16, verbose=True)
Function used for evaluating the teacher.
Parameters:
- max_seq_length (int) – Maximum sequence length parameter for generating dataloaders (a constructor argument; not part of this method's signature)
- val_batch_size (int) – Batch size parameter for generating dataloaders
- verbose (bool) – True if the accuracy needs to be printed, else False
-
train_student(epochs=10, plot_losses=True, save_model=True, save_model_pth='./models/student.pth')
Function used for training the student.
Parameters:
- epochs (int) – Number of epochs you want to train the student
- plot_losses (bool) – True if you want to plot the losses
- save_model (bool) – True if you want to save the student model
- save_model_pth (str) – Path where you want to save the student model
-
train_teacher(epochs=1, plot_losses=True, save_model=True, save_model_pth='./models/teacher.pt', train_batch_size=16, batch_print_freq=40, val_batch_size=16)
Function used for training the teacher.
Parameters:
- epochs (int) – Number of epochs you want to train the teacher
- plot_losses (bool) – True if you want to plot the losses
- save_model (bool) – True if you want to save the teacher model
- save_model_pth (str) – Path where you want to store the teacher model
- train_batch_size (int) – Batch size parameter for generating dataloaders
- batch_print_freq (int) – Frequency at which the batch number is printed per epoch
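The methods documented on this page suggest a workflow of training the teacher, distilling into the student, and evaluating both. The sketch below wires those calls together using only the argument names from the signatures on this page. The wrapper `run_bert2lstm_distillation` is hypothetical, the Adam optimizer and its learning rate are placeholder choices, and the required contents of the dataframes and dataloaders are not specified here, so they are taken as ready-made inputs.

```python
def run_bert2lstm_distillation(student_model, distill_train_loader,
                               distill_val_loader, train_df, val_df,
                               device="cpu"):
    """Train the teacher, distill into the student, then evaluate both."""
    # Imported locally so that defining this helper does not require the
    # heavy dependencies to be installed.
    import torch
    from KD_Lib.KD.text.BERT2LSTM.bert2lstm import BERT2LSTM

    # Assumption: Adam with lr=1e-3 is an illustrative choice only.
    optimizer_student = torch.optim.Adam(student_model.parameters(), lr=1e-3)

    # Keyword names follow the constructor signature documented above.
    distiller = BERT2LSTM(
        student_model=student_model,
        distill_train_loader=distill_train_loader,
        distill_val_loader=distill_val_loader,
        optimizer_student=optimizer_student,
        train_df=train_df,
        val_df=val_df,
        num_classes=2,
        distil_weight=0.5,
        device=device,
        max_seq_length=128,
    )

    # Train the teacher first, then distill into the student.
    distiller.train_teacher(epochs=1, plot_losses=False, save_model=False)
    distiller.train_student(epochs=10, plot_losses=False, save_model=False)

    # Report accuracy for both models.
    distiller.evaluate_teacher(verbose=True)
    distiller.evaluate_student(verbose=True)
    return distiller
```

Plotting and saving are switched off here so the sketch has no filesystem side effects; pass `save_model=True` with a `save_model_pth` to keep the trained weights.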