KD_Lib.Pruning.lottery_tickets

KD_Lib.Pruning.lottery_tickets.lottery_tickets module

class KD_Lib.Pruning.lottery_tickets.lottery_tickets.LotteryTicketsPruner(model, train_loader, test_loader, loss_fn=CrossEntropyLoss(), device='cpu')

Bases: KD_Lib.Pruning.common.iterative_base_class.BaseIterativePruner

Implementation of Lottery Tickets Pruning for PyTorch models.

Parameters:
  • model (torch.nn.Module) – Model to be pruned
  • train_loader (torch.utils.data.DataLoader) – Dataloader for training
  • test_loader (torch.utils.data.DataLoader) – Dataloader for validation/testing
  • loss_fn (torch.nn.Module) – Loss function to be used for training
  • device (str or torch.device) – Device on which the model is placed during training and pruning (“cpu” by default)
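The general lottery-ticket procedure this class is named after can be illustrated with a small, framework-free sketch: train the masked network from its original initialization, zero a fixed percentage of the surviving weights with the smallest magnitudes, and repeat; the "winning ticket" is the surviving weights rewound to their initial values. This is not KD_Lib's code. The helpers `prune_alive_by_magnitude`, `toy_train` and `lottery_ticket_search` are hypothetical, `toy_train` merely stands in for a real training loop, and whether KD_Lib prunes per layer or globally and exactly how it rewinds weights are assumptions not stated on this page.

```python
import random


def prune_alive_by_magnitude(weights, mask, prune_percent):
    """Zero out prune_percent % of the currently alive weights with the smallest |w|."""
    alive = [i for i, m in enumerate(mask) if m]
    n_prune = int(len(alive) * prune_percent / 100)
    # Sort alive indices by absolute magnitude, smallest first.
    alive.sort(key=lambda i: abs(weights[i]))
    new_mask = list(mask)
    for i in alive[:n_prune]:
        new_mask[i] = 0
    return new_mask


def toy_train(weights, mask):
    """Hypothetical stand-in for training: perturbs alive weights, keeps pruned ones at 0."""
    return [w + random.gauss(0, 0.1) if m else 0.0 for w, m in zip(weights, mask)]


def lottery_ticket_search(init_weights, prune_percent, iterations):
    """Sketch of iterative magnitude pruning with rewinding to the initialization."""
    mask = [1] * len(init_weights)
    for _ in range(iterations):
        # 1. Start each round from the ORIGINAL initialization, masked.
        start = [w * m for w, m in zip(init_weights, mask)]
        trained = toy_train(start, mask)
        # 2. Prune the smallest-magnitude surviving weights of the trained network.
        mask = prune_alive_by_magnitude(trained, mask, prune_percent)
    # 3. The "winning ticket": surviving weights rewound to their initial values.
    ticket = [w * m for w, m in zip(init_weights, mask)]
    return ticket, mask


random.seed(0)
init = [random.uniform(-1, 1) for _ in range(100)]
ticket, mask = lottery_ticket_search(init, prune_percent=10, iterations=3)
print(sum(mask), "of", len(mask), "weights survive")
```

With `prune_percent=10` and three rounds, the alive count goes 100 → 90 → 81 → 73 regardless of the random seed, because each round removes `int(10% of the alive count)` weights; only which weights survive depends on the seed.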

prune_model(prune_percent=10)

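Prunes the model iteratively; in each pruning iteration, prune_percent percent of the weights that are still alive are set to zero.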

Parameters: prune_percent (int) – Percentage of the currently alive (unpruned) weights to set to zero in each pruning iteration (10 by default)
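Because prune_percent applies to the weights still alive rather than to the original total, sparsity compounds: after n iterations roughly (1 − prune_percent/100)^n of the original weights remain. The sketch below works through that arithmetic; `remaining_fraction` and `rounds_to_reach` are hypothetical helpers, not part of KD_Lib, and the real pruner rounds to whole weights, so its counts may differ slightly.

```python
def remaining_fraction(prune_percent, iterations):
    # Each round keeps (1 - p/100) of the weights alive before it,
    # so the surviving fraction shrinks geometrically.
    return (1 - prune_percent / 100) ** iterations


def rounds_to_reach(target_sparsity, prune_percent):
    # Smallest number of rounds whose cumulative sparsity (fraction of the
    # original weights zeroed) is at least target_sparsity.
    if not 0 <= target_sparsity < 1:
        raise ValueError("target_sparsity must be in [0, 1)")
    if not 0 < prune_percent <= 100:
        raise ValueError("prune_percent must be in (0, 100]")
    rounds = 0
    while 1 - remaining_fraction(prune_percent, rounds) < target_sparsity:
        rounds += 1
    return rounds


for n in (1, 5, 10, 20):
    print(f"after {n:2d} rounds: {1 - remaining_fraction(10, n):.1%} of weights pruned")
print("rounds for 90% sparsity at 10%/round:", rounds_to_reach(0.9, 10))
```

With the default of 10, about 41% of the weights are pruned after 5 iterations, 65% after 10 and 88% after 20, and reaching 90% sparsity takes 22 iterations; a larger prune_percent reaches a target sparsity in fewer, more aggressive rounds.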