import os
from copy import deepcopy
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
class BaseClass:
    """
    Basic implementation of a general Knowledge Distillation framework

    :param teacher_model (torch.nn.Module): Teacher model; may be None, in which case a
        warning is printed and ``self.teacher_model`` is set to None
    :param student_model (torch.nn.Module): Student model
    :param train_loader (torch.utils.data.DataLoader): Dataloader for training
    :param val_loader (torch.utils.data.DataLoader): Dataloader for validation/testing
    :param optimizer_teacher (torch.optim.*): Optimizer used for training teacher
    :param optimizer_student (torch.optim.*): Optimizer used for training student
    :param loss_fn (torch.nn.Module): Loss Function used for distillation; when left as
        None a new ``nn.KLDivLoss()`` is created for this instance
    :param temp (float): Temperature parameter for distillation
    :param distil_weight (float): Weight parameter for distillation loss
    :param device (str): Device used for training; 'cpu' for cpu and 'cuda' for gpu.
        Any other value, or 'cuda' when CUDA is unavailable, falls back to cpu
    :param log (bool): True if logging required
    :param logdir (str): Directory for storing logs
    """

    def __init__(
        self,
        teacher_model,
        student_model,
        train_loader,
        val_loader,
        optimizer_teacher,
        optimizer_student,
        loss_fn=None,
        temp=20.0,
        distil_weight=0.5,
        device="cpu",
        log=False,
        logdir="./Experiments",
    ):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer_teacher = optimizer_teacher
        self.optimizer_student = optimizer_student
        self.temp = temp
        self.distil_weight = distil_weight
        self.log = log
        self.logdir = logdir

        if self.log:
            self.writer = SummaryWriter(logdir)

        # Always resolves to a usable device, so the .to() calls below cannot
        # fail on a missing attribute.
        self.device = self._resolve_device(device)

        # Identity check rather than truthiness: container modules define
        # __len__, so an empty one would be treated as "no teacher".
        if teacher_model is not None:
            self.teacher_model = teacher_model.to(self.device)
        else:
            print("Warning!!! Teacher is NONE.")
            self.teacher_model = None

        self.student_model = student_model.to(self.device)

        # Build the default loss per instance; a module created in the
        # signature would be shared (and moved between devices) by all instances.
        if loss_fn is None:
            loss_fn = nn.KLDivLoss()
        self.loss_fn = loss_fn.to(self.device)
        self.ce_fn = nn.CrossEntropyLoss().to(self.device)

    @staticmethod
    def _resolve_device(device):
        """
        Map the requested device to a torch.device, defaulting to cpu.

        :param device (str): Requested device, 'cpu' or 'cuda'
        :return (torch.device): 'cuda' only if requested and available, else 'cpu'
        """
        # str() lets an equivalent torch.device object be accepted as well.
        name = str(device)
        if name == "cpu":
            return torch.device("cpu")
        if name == "cuda" and torch.cuda.is_available():
            return torch.device("cuda")
        print("Either an invalid device or CUDA is not available. Defaulting to CPU.")
        return torch.device("cpu")

    @staticmethod
    def _ensure_parent_dir(path):
        """
        Create the directory that will contain ``path`` if it has one.

        :param path (str): File path whose parent directory should exist
        """
        parent = os.path.dirname(path)
        # A bare file name has an empty dirname; os.makedirs("") would raise.
        if parent:
            os.makedirs(parent, exist_ok=True)

    def _log_epoch(self, role, loss, train_acc, val_acc, step):
        """
        Write the per-epoch scalars to the summary writer when logging is on.

        :param role (str): 'Teacher' or 'Student', used as the tag suffix
        :param loss (float): Summed training loss of the epoch
        :param train_acc (float): Training accuracy of the epoch
        :param val_acc (float): Validation accuracy after the epoch
        :param step (int): Index of the epoch, used as the global step
        """
        if not self.log:
            return
        self.writer.add_scalar("Training loss/{}".format(role), loss, step)
        self.writer.add_scalar("Training accuracy/{}".format(role), train_acc, step)
        self.writer.add_scalar("Validation accuracy/{}".format(role), val_acc, step)

    def train_teacher(
        self,
        epochs=20,
        plot_losses=True,
        save_model=True,
        save_model_pth="./models/teacher.pt",
    ):
        """
        Function that will be training the teacher

        The weights with the best validation accuracy are kept in
        ``self.best_teacher_model_weights`` and loaded back into the teacher
        once training is finished.

        :param epochs (int): Number of epochs you want to train the teacher
        :param plot_losses (bool): True if you want to plot the losses
        :param save_model (bool): True if you want to save the teacher model
        :param save_model_pth (str): Path where you want to store the teacher model
        """
        loss_arr = []
        length_of_dataset = len(self.train_loader.dataset)
        best_acc = 0.0
        self.best_teacher_model_weights = deepcopy(self.teacher_model.state_dict())
        self._ensure_parent_dir(save_model_pth)

        print("Training Teacher... ")

        for ep in range(epochs):
            # Re-assert train mode each epoch in case post_epoch_call changed it.
            self.teacher_model.train()
            epoch_loss = 0.0
            correct = 0

            for data, label in self.train_loader:
                data = data.to(self.device)
                label = label.to(self.device)

                out = self.teacher_model(data)
                # Some models return (logits, extras); only the logits are used.
                if isinstance(out, tuple):
                    out = out[0]

                pred = out.argmax(dim=1, keepdim=True)
                correct += pred.eq(label.view_as(pred)).sum().item()

                loss = self.ce_fn(out, label)

                self.optimizer_teacher.zero_grad()
                loss.backward()
                self.optimizer_teacher.step()

                epoch_loss += loss.item()

            epoch_acc = correct / length_of_dataset

            # evaluate() works on a copy, so the teacher's mode is untouched.
            epoch_val_acc = self.evaluate(teacher=True)
            if epoch_val_acc > best_acc:
                best_acc = epoch_val_acc
                self.best_teacher_model_weights = deepcopy(
                    self.teacher_model.state_dict()
                )

            # Log against the current epoch, not the total number of epochs.
            self._log_epoch("Teacher", epoch_loss, epoch_acc, epoch_val_acc, ep)

            loss_arr.append(epoch_loss)
            print(
                "Epoch: {}, Loss: {}, Accuracy: {}".format(
                    ep + 1, epoch_loss, epoch_acc
                )
            )

            self.post_epoch_call(ep)

        self.teacher_model.load_state_dict(self.best_teacher_model_weights)
        if save_model:
            torch.save(self.teacher_model.state_dict(), save_model_pth)
        if plot_losses:
            plt.plot(loss_arr)

    def _train_student(
        self,
        epochs=10,
        plot_losses=True,
        save_model=True,
        save_model_pth="./models/student.pt",
    ):
        """
        Function to train student model - for internal use only.

        The weights with the best validation accuracy are kept in
        ``self.best_student_model_weights`` and loaded back into the student
        once training is finished.

        :param epochs (int): Number of epochs you want to train the student
        :param plot_losses (bool): True if you want to plot the losses
        :param save_model (bool): True if you want to save the student model
        :param save_model_pth (str): Path where you want to save the student model
        """
        self.teacher_model.eval()
        loss_arr = []
        length_of_dataset = len(self.train_loader.dataset)
        best_acc = 0.0
        self.best_student_model_weights = deepcopy(self.student_model.state_dict())
        self._ensure_parent_dir(save_model_pth)

        print("Training Student...")

        for ep in range(epochs):
            # _evaluate_model() below switches the student to eval mode, so
            # train mode has to be restored at the start of every epoch.
            self.student_model.train()
            epoch_loss = 0.0
            correct = 0

            for data, label in self.train_loader:
                data = data.to(self.device)
                label = label.to(self.device)

                student_out = self.student_model(data)
                teacher_out = self.teacher_model(data)

                # The raw outputs are passed on so that implementations can
                # use any extra values returned alongside the logits.
                loss = self.calculate_kd_loss(student_out, teacher_out, label)

                if isinstance(student_out, tuple):
                    student_out = student_out[0]

                pred = student_out.argmax(dim=1, keepdim=True)
                correct += pred.eq(label.view_as(pred)).sum().item()

                self.optimizer_student.zero_grad()
                loss.backward()
                self.optimizer_student.step()

                epoch_loss += loss.item()

            epoch_acc = correct / length_of_dataset

            _, epoch_val_acc = self._evaluate_model(self.student_model, verbose=True)
            if epoch_val_acc > best_acc:
                best_acc = epoch_val_acc
                self.best_student_model_weights = deepcopy(
                    self.student_model.state_dict()
                )

            # Log against the current epoch, not the total number of epochs.
            self._log_epoch("Student", epoch_loss, epoch_acc, epoch_val_acc, ep)

            loss_arr.append(epoch_loss)
            print(
                "Epoch: {}, Loss: {}, Accuracy: {}".format(
                    ep + 1, epoch_loss, epoch_acc
                )
            )

        self.student_model.load_state_dict(self.best_student_model_weights)
        if save_model:
            torch.save(self.student_model.state_dict(), save_model_pth)
        if plot_losses:
            plt.plot(loss_arr)

    def train_student(
        self,
        epochs=10,
        plot_losses=True,
        save_model=True,
        save_model_pth="./models/student.pt",
    ):
        """
        Function that will be training the student

        :param epochs (int): Number of epochs you want to train the student
        :param plot_losses (bool): True if you want to plot the losses
        :param save_model (bool): True if you want to save the student model
        :param save_model_pth (str): Path where you want to save the student model
        """
        self._train_student(epochs, plot_losses, save_model, save_model_pth)

    def calculate_kd_loss(self, y_pred_student, y_pred_teacher, y_true):
        """
        Custom loss function to calculate the KD loss for various implementations

        Must be overridden; this base version always raises NotImplementedError.

        :param y_pred_student (Tensor): Predicted outputs from the student network
        :param y_pred_teacher (Tensor): Predicted outputs from the teacher network
        :param y_true (Tensor): True labels
        """
        raise NotImplementedError

    def _evaluate_model(self, model, verbose=True):
        """
        Evaluate the given model's accuracy over val set.
        For internal use only.

        The model is switched to eval mode and left in that mode.

        :param model (nn.Module): Model to be used for evaluation
        :param verbose (bool): Display Accuracy
        :return (list, float): Per-batch model outputs and the accuracy
        """
        model.eval()
        length_of_dataset = len(self.val_loader.dataset)
        correct = 0
        outputs = []

        with torch.no_grad():
            for data, target in self.val_loader:
                data = data.to(self.device)
                target = target.to(self.device)

                output = model(data)
                # Some models return (logits, extras); only the logits are used.
                if isinstance(output, tuple):
                    output = output[0]
                outputs.append(output)

                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        accuracy = correct / length_of_dataset

        if verbose:
            print("-" * 80)
            print("Validation Accuracy: {}".format(accuracy))

        return outputs, accuracy

    def evaluate(self, teacher=False):
        """
        Evaluate method for printing accuracies of the trained network

        The evaluation runs on a deep copy, so the stored model keeps its
        current train/eval mode.

        :param teacher (bool): True if you want accuracy of the teacher network
        :return (float): Accuracy over the validation set
        """
        source = self.teacher_model if teacher else self.student_model
        model = deepcopy(source).to(self.device)
        _, accuracy = self._evaluate_model(model)
        return accuracy

    def get_parameters(self):
        """
        Get the number of parameters for the teacher and the student network
        """
        teacher_params = sum(p.numel() for p in self.teacher_model.parameters())
        student_params = sum(p.numel() for p in self.student_model.parameters())

        print("-" * 80)
        print("Total parameters for the teacher network are: {}".format(teacher_params))
        print("Total parameters for the student network are: {}".format(student_params))

    def post_epoch_call(self, epoch):
        """
        Any changes to be made after an epoch is completed.

        Called at the end of every teacher-training epoch; does nothing here.

        :param epoch (int) : current epoch number
        :return : nothing (void)
        """
        pass