Self Training using KD_Lib

Paper

  • The student model is first trained in the normal way to obtain a pre-trained model, which is then used as the teacher to train itself by transferring soft targets

To use the self training algorithm to train a student on MNIST for 5 epochs -

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import SelfTraining

# Define datasets, dataloaders, models and optimizers

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

# Set device to be trained on

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define student model

student_model = <your model>

# Define optimizer

student_optimizer = optim.SGD(student_model.parameters(), lr=0.01)


# Train using KD_Lib

distiller = SelfTraining(student_model, train_loader, test_loader, student_optimizer,
                         device=device)
distiller.train_student(epochs=5)                                      # Train the student model
distiller.evaluate()                                                    # Evaluate the student model