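"""Source code for KD_Lib.models.resnet.

Residual building blocks (``BasicBlock``, ``Bottleneck``), the ``ResNet``
backbone with its ``ResnetWithAT`` and ``MeanResnet`` variants, and
ResNet-18/34/50/101/152 factory functions collected in ``resnet_book``.
"""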

import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions (ResNet-18/34 style).

    The block outputs ``expansion * planes`` (== ``planes``) channels. If the
    stride is not 1 or the channel count changes, the skip path becomes a
    1x1 convolution followed by batch-norm. This lets the skip path be added
    to the main path. Otherwise the skip path is the identity.

    :param in_planes (int): Number of input channels
    :param planes (int): Number of channels produced by each 3x3 convolution
    :param stride (int): Stride of the first 3x3 convolution and of the
        projection shortcut, when one is built
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main path. Submodules are created in the same order as before, so
        # state_dict keys and parameter initialisation order are unchanged.
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # Skip path.
        if stride == 1 and in_planes == out_planes:
            # An empty Sequential returns its input unchanged (identity).
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(hidden))
        return F.relu(main + self.shortcut(x))
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce, 3x3, 1x1 expand (ResNet-50+).

    The block outputs ``expansion * planes`` (== ``4 * planes``) channels.
    The stride is applied by the 3x3 convolution. If the stride is not 1 or
    the input channel count differs from the output channel count, the skip
    path becomes a 1x1 convolution followed by batch-norm. Otherwise the
    skip path is the identity.

    :param in_planes (int): Number of input channels
    :param planes (int): Bottleneck width (channels of the first two convs)
    :param stride (int): Stride of the 3x3 convolution and of the
        projection shortcut, when one is built
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main path. Construction order is kept stable for state_dict
        # compatibility and identical parameter initialisation.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Skip path.
        if stride == 1 and in_planes == out_planes:
            # An empty Sequential returns its input unchanged (identity).
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Run the three conv/bn stages, add the skip path, then apply ReLU."""
        hidden = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            hidden = F.relu(norm(conv(hidden)))
        main = self.bn3(self.conv3(hidden))
        return F.relu(main + self.shortcut(x))
class ResNet(nn.Module):
    """ResNet backbone: 3x3 stem, four residual stages, average pool, linear head.

    :param block: Residual block class. It is called as
        ``block(in_planes, planes, stride)`` and must expose an integer
        ``expansion`` attribute (e.g. ``BasicBlock`` or ``Bottleneck``).
    :param num_blocks (list or tuple): Number of blocks in each of the four stages
    :param params (list or tuple): ``params[0]`` is the stem width and
        ``params[1]``..``params[4]`` are the widths of the four stages
    :param num_channel (int): Number of channels in the input
    :param num_classes (int): Number of output classes

    NOTE: the stem has stride 1 and stages 2-4 each have stride 2, so the last
    feature map is 1/8 of the input resolution. ``forward`` then applies a
    fixed 4x4 average pool. This gives a 1x1 map for 28-32 px inputs. Other
    sizes may yield a feature vector that does not match ``self.linear``.
    """

    def __init__(self, block, num_blocks, params, num_channel=3, num_classes=10):
        super().__init__()
        stem_width = params[0]
        # Running channel count. _make_layer reads it and advances it.
        self.in_planes = stem_width
        self.conv1 = nn.Conv2d(
            num_channel, stem_width, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(stem_width)
        # Only the first stage keeps the resolution. The later stages halve it.
        self.layer1 = self._make_layer(block, params[1], num_blocks[0], 1)
        self.layer2 = self._make_layer(block, params[2], num_blocks[1], 2)
        self.layer3 = self._make_layer(block, params[3], num_blocks[2], 2)
        self.layer4 = self._make_layer(block, params[4], num_blocks[3], 2)
        self.linear = nn.Linear(params[4] * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks. Only the first one uses ``stride``.

        ``self.in_planes`` is updated after every block. Each block therefore
        has to be constructed before the next one is built.
        """
        block_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x, out_feature=False):
        """Return logits, or ``(logits, feature)`` when ``out_feature`` is truthy.

        ``feature`` is the flattened, average-pooled output of the last stage.
        """
        hidden = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            hidden = stage(hidden)
        pooled = F.avg_pool2d(hidden, 4)
        feature = pooled.view(pooled.size(0), -1)
        logits = self.linear(feature)
        return (logits, feature) if out_feature else logits
class ResnetWithAT(ResNet):
    """``ResNet`` whose forward pass also returns the output of every stage.

    It has the same architecture and parameters as ``ResNet``. Only
    ``forward`` differs. The factory functions select this class when
    attention is requested, so the stage outputs are presumably consumed by
    an attention-transfer loss; confirm against the callers.
    """

    def forward(self, x):
        """Return ``(logits, at1, at2, at3, at4)``.

        ``at1``..``at4`` are the raw outputs of ``layer1``..``layer4``.
        """
        stem = F.relu(self.bn1(self.conv1(x)))
        at1 = self.layer1(stem)
        at2 = self.layer2(at1)
        at3 = self.layer3(at2)
        at4 = self.layer4(at3)
        pooled = F.avg_pool2d(at4, 4)
        logits = self.linear(pooled.view(pooled.size(0), -1))
        return logits, at1, at2, at3, at4
class MeanResnet(ResNet):
    """``ResNet`` with a second linear head on the same pooled feature.

    The factory functions select this class when a mean teacher model is
    requested. ``linear`` and ``linear2`` share the backbone but have
    independent weights.
    """

    def __init__(self, block, num_blocks, params, num_channel=3, num_classes=10):
        super().__init__(block, num_blocks, params, num_channel, num_classes)
        feature_dim = params[4] * block.expansion
        self.linear2 = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return ``(self.linear(feature), self.linear2(feature))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            hidden = stage(hidden)
        pooled = F.avg_pool2d(hidden, 4)
        feature = pooled.view(pooled.size(0), -1)
        return self.linear(feature), self.linear2(feature)
def ResNet18(parameters, num_channel=3, num_classes=10, att=False, mean=False):
    """
    Function that creates a ResNet 18 model

    Uses ``BasicBlock`` with ``[2, 2, 2, 2]`` blocks per stage. If only ``att``
    is truthy, a ``ResnetWithAT`` is built. If only ``mean`` is truthy, a
    ``MeanResnet`` is built. If both or neither are set, a plain ``ResNet``
    is returned.

    :param parameters (list or tuple): List of parameters for the model
        (stem width followed by the four stage widths)
    :param num_channel (int): Number of channels in input specimens
    :param num_classes (int): Number of classes for classification
    :param att (bool): True if attention needs to be used
    :param mean (bool): True if mean teacher model needs to be used
    """
    variants = {(True, False): ResnetWithAT, (False, True): MeanResnet}
    model_cls = variants.get((bool(att), bool(mean)), ResNet)
    return model_cls(
        BasicBlock, [2, 2, 2, 2], parameters, num_channel, num_classes=num_classes
    )
def ResNet34(parameters, num_channel=3, num_classes=10, att=False, mean=False):
    """
    Function that creates a ResNet 34 model

    Uses ``BasicBlock`` with ``[3, 4, 6, 3]`` blocks per stage. If only ``att``
    is truthy, a ``ResnetWithAT`` is built. If only ``mean`` is truthy, a
    ``MeanResnet`` is built. If both or neither are set, a plain ``ResNet``
    is returned.

    :param parameters (list or tuple): List of parameters for the model
        (stem width followed by the four stage widths)
    :param num_channel (int): Number of channels in input specimens
    :param num_classes (int): Number of classes for classification
    :param att (bool): True if attention needs to be used
    :param mean (bool): True if mean teacher model needs to be used
    """
    variants = {(True, False): ResnetWithAT, (False, True): MeanResnet}
    model_cls = variants.get((bool(att), bool(mean)), ResNet)
    return model_cls(
        BasicBlock, [3, 4, 6, 3], parameters, num_channel, num_classes=num_classes
    )
def ResNet50(parameters, num_channel=3, num_classes=10, att=False, mean=False):
    """
    Function that creates a ResNet 50 model

    Uses ``Bottleneck`` with ``[3, 4, 6, 3]`` blocks per stage. If only ``att``
    is truthy, a ``ResnetWithAT`` is built. If only ``mean`` is truthy, a
    ``MeanResnet`` is built. If both or neither are set, a plain ``ResNet``
    is returned.

    :param parameters (list or tuple): List of parameters for the model
        (stem width followed by the four stage widths)
    :param num_channel (int): Number of channels in input specimens
    :param num_classes (int): Number of classes for classification
    :param att (bool): True if attention needs to be used
    :param mean (bool): True if mean teacher model needs to be used
    """
    variants = {(True, False): ResnetWithAT, (False, True): MeanResnet}
    model_cls = variants.get((bool(att), bool(mean)), ResNet)
    return model_cls(
        Bottleneck, [3, 4, 6, 3], parameters, num_channel, num_classes=num_classes
    )
def ResNet101(parameters, num_channel=3, num_classes=10, att=False, mean=False):
    """
    Function that creates a ResNet 101 model

    Uses ``Bottleneck`` with ``[3, 4, 23, 3]`` blocks per stage. If only
    ``att`` is truthy, a ``ResnetWithAT`` is built. If only ``mean`` is
    truthy, a ``MeanResnet`` is built. If both or neither are set, a plain
    ``ResNet`` is returned.

    :param parameters (list or tuple): List of parameters for the model
        (stem width followed by the four stage widths)
    :param num_channel (int): Number of channels in input specimens
    :param num_classes (int): Number of classes for classification
    :param att (bool): True if attention needs to be used
    :param mean (bool): True if mean teacher model needs to be used
    """
    variants = {(True, False): ResnetWithAT, (False, True): MeanResnet}
    model_cls = variants.get((bool(att), bool(mean)), ResNet)
    return model_cls(
        Bottleneck, [3, 4, 23, 3], parameters, num_channel, num_classes=num_classes
    )
def ResNet152(parameters, num_channel=3, num_classes=10, att=False, mean=False):
    """
    Function that creates a ResNet 152 model

    Uses ``Bottleneck`` with ``[3, 8, 36, 3]`` blocks per stage. If only
    ``att`` is truthy, a ``ResnetWithAT`` is built. If only ``mean`` is
    truthy, a ``MeanResnet`` is built. If both or neither are set, a plain
    ``ResNet`` is returned.

    :param parameters (list or tuple): List of parameters for the model
        (stem width followed by the four stage widths)
    :param num_channel (int): Number of channels in input specimens
    :param num_classes (int): Number of classes for classification
    :param att (bool): True if attention needs to be used
    :param mean (bool): True if mean teacher model needs to be used
    """
    variants = {(True, False): ResnetWithAT, (False, True): MeanResnet}
    model_cls = variants.get((bool(att), bool(mean)), ResNet)
    return model_cls(
        Bottleneck, [3, 8, 36, 3], parameters, num_channel, num_classes=num_classes
    )
# Maps a ResNet depth, given as a string, to the factory that builds that model.
resnet_book = {
    "18": ResNet18,
    "34": ResNet34,
    "50": ResNet50,
    "101": ResNet101,
    "152": ResNet152,
}