KD_Lib.Pruning.lottery_tickets
KD_Lib.Pruning.lottery_tickets.lottery_tickets module
class KD_Lib.Pruning.lottery_tickets.lottery_tickets.LotteryTicketsPruner(model, train_loader, test_loader, loss_fn=CrossEntropyLoss(), device='cpu')
Bases: KD_Lib.Pruning.common.iterative_base_class.BaseIterativePruner
Implementation of Lottery Tickets Pruning for PyTorch models.
Parameters:
- model (torch.nn.Module) – Model to be pruned
- train_loader (torch.utils.data.DataLoader) – DataLoader for training
- test_loader (torch.utils.data.DataLoader) – DataLoader for validation/testing
- loss_fn (torch.nn.Module) – Loss function used for training (CrossEntropyLoss() by default)
- device (torch.device) – Device on which training and evaluation run (“cpu” by default)
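
For orientation, the general lottery-ticket procedure (train, prune the smallest-magnitude surviving weights, reset the survivors to their initial values, repeat) can be sketched in plain Python. This is an illustrative sketch only and is not KD_Lib's implementation: `magnitude_mask`, `lottery_ticket_prune`, `toy_train`, and the flat weight list are all hypothetical names and simplifications introduced here, and `toy_train` is a stand-in for a real training loop over `train_loader`.

```python
from typing import Callable, List, Tuple


def magnitude_mask(weights: List[float], mask: List[int], prune_fraction: float) -> List[int]:
    """Return a copy of `mask` with the smallest-magnitude surviving weights removed."""
    alive = [i for i, m in enumerate(mask) if m]
    n_prune = int(len(alive) * prune_fraction)
    # Sort surviving indices by absolute weight, smallest first
    to_prune = sorted(alive, key=lambda i: abs(weights[i]))[:n_prune]
    new_mask = list(mask)
    for i in to_prune:
        new_mask[i] = 0
    return new_mask


def lottery_ticket_prune(
    init_weights: List[float],
    train_fn: Callable[[List[float], List[int]], List[float]],
    num_iterations: int,
    prune_fraction: float,
) -> Tuple[List[int], List[float]]:
    """Iterative magnitude pruning that rewinds to the initial weights each round."""
    mask = [1] * len(init_weights)
    for _ in range(num_iterations):
        # Rewind: every round starts from the original initialization, masked
        weights = [w * m for w, m in zip(init_weights, mask)]
        trained = train_fn(weights, mask)
        # Prune based on the magnitudes of the *trained* weights
        mask = magnitude_mask(trained, mask, prune_fraction)
    # The "winning ticket": the final mask applied to the initial weights
    ticket = [w * m for w, m in zip(init_weights, mask)]
    return mask, ticket


def toy_train(weights: List[float], mask: List[int]) -> List[float]:
    """Hypothetical stand-in for training: move each surviving weight halfway to a fixed target."""
    targets = [0.0, 2.0, -0.1, 1.5, 0.05, -3.0, 0.2, 0.8]
    return [m * (w + 0.5 * (t - w)) for w, t, m in zip(weights, targets, mask)]


init = [0.1, -0.2, 0.3, 0.4, -0.5, 0.6, -0.7, 0.8]
mask, ticket = lottery_ticket_prune(init, toy_train, num_iterations=2, prune_fraction=0.25)
print(mask)
print(ticket)
```

In a real setting the weights would be the tensors of a `torch.nn.Module`, with one mask per parameter tensor, and `train_fn` would run the optimizer over the training data while keeping pruned entries at zero.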