# Source code for KD_Lib.Pruning.lottery_tickets.lottery_tickets

import copy

import numpy as np
import torch
import torch.nn as nn

from ..common import BaseIterativePruner


class LotteryTicketsPruner(BaseIterativePruner):
    """
    Implementation of Lottery Tickets Pruning for PyTorch models.

    At construction time a deep copy of the model's ``state_dict`` is stored
    in ``initial_state_dict``. Each call to :meth:`prune_model` zeroes the
    smallest-magnitude surviving weights and resets every other weight (and
    every bias) to the value held in that snapshot.

    :param model: Model that needs to be pruned
    :type model: torch.nn.Module
    :param train_loader: Dataloader for training
    :type train_loader: torch.utils.data.DataLoader
    :param test_loader: Dataloader for validation/testing
    :type test_loader: torch.utils.data.DataLoader
    :param loss_fn: Loss function to be used for training. The default
        instance is created once, when the class body is evaluated, and is
        shared by every pruner that does not pass its own.
    :type loss_fn: torch.nn.Module
    :param device: Device used for implementation ("cpu" by default)
    :type device: torch.device
    """

    def __init__(
        self,
        model,
        train_loader,
        test_loader,
        loss_fn=nn.CrossEntropyLoss(),
        device="cpu",
    ):
        super().__init__(model, train_loader, test_loader, loss_fn, device)
        # NOTE(review): relies on BaseIterativePruner.__init__ having set
        # ``self.model`` -- confirm against the base class.
        # Deep copy so that later training cannot alter the snapshot.
        self.initial_state_dict = copy.deepcopy(self.model.state_dict())

    def prune_model(self, prune_percent=10):
        """
        Function used for pruning

        For every parameter whose name contains ``"weight"``: weights whose
        magnitude is below the ``prune_percent``-th percentile of the
        currently non-zero magnitudes are set to zero, and all remaining
        weights are reset to their values in ``initial_state_dict``. A
        weight tensor that is already entirely zero is left untouched.

        For every parameter whose name contains ``"bias"``: the parameter is
        reset to a copy of its value in ``initial_state_dict``.

        :param prune_percent: Pruning percent per iteration (percentage of
            alive weights to zero per pruning iteration)
        :type prune_percent: int
        """
        for name, param in self.model.named_parameters():
            if "weight" in name:
                magnitudes = np.abs(param.data.cpu().numpy())
                alive = magnitudes[np.nonzero(magnitudes)]
                # Skip fully pruned tensors: a percentile of an empty array
                # is undefined and would otherwise raise or produce NaN,
                # which would resurrect every pruned weight.
                if alive.size:
                    threshold = np.percentile(alive, prune_percent)
                    # Move the snapshot to host memory explicitly; NumPy
                    # cannot read a tensor that lives on another device.
                    initial_values = (
                        self.initial_state_dict[name].detach().cpu().numpy()
                    )
                    # np.where allocates a new array, so the result never
                    # shares memory with the stored snapshot.
                    rewound = np.where(magnitudes < threshold, 0, initial_values)
                    param.data = torch.from_numpy(rewound).to(param.device)

            if "bias" in name:
                # Clone so the live parameter does not share storage with the
                # snapshot; otherwise in-place updates to the parameter would
                # overwrite the saved initial values.
                param.data = (
                    self.initial_state_dict[name]
                    .detach()
                    .clone()
                    .to(param.device)
                )