import math

import torch.nn as nn

from .registry import MODEL


# --- Gaussian initialization ---
def init_layer(l):
    # He initialization: n is the fan-out of the conv
    # (kernel area * out_channels), as in He et al., 2015.
    if isinstance(l, nn.Conv2d):
        n = l.kernel_size[0] * l.kernel_size[1] * l.out_channels
        l.weight.data.normal_(0, math.sqrt(2.0 / float(n)))
    elif isinstance(l, nn.BatchNorm2d):
        l.weight.data.fill_(1)
        l.bias.data.fill_(0)
    elif isinstance(l, nn.Linear):
        # Linear weights keep PyTorch's default init; only the bias is zeroed.
        if l.bias is not None:
            l.bias.data.fill_(0)


class Flatten(nn.Module):
    """Flattens every dimension except the batch (equivalent to nn.Flatten)."""

    def forward(self, x):
        return x.view(x.size(0), -1)


class SimpleBlock(nn.Module):
    """Basic residual block: two 3x3 convs plus a shortcut connection."""

    def __init__(self, in_channels, out_channels, half_res, leakyrelu=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3,
                      stride=2 if half_res else 1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True) if leakyrelu else nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.relu = nn.LeakyReLU(0.2, inplace=True) if leakyrelu else nn.ReLU(inplace=True)
        # A projection shortcut is needed whenever the residual branch changes
        # the channel count or the spatial resolution; an identity shortcut
        # would not match the main path's output shape in either case.
        if in_channels != out_channels or half_res:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=2 if half_res else 1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        out = self.block(x)
        return self.relu(out + self.shortcut(x))


class ResNet(nn.Module):
    def __init__(self, block, layers, dims, num_classes=None,
                 classifier_type="linear", flatten=True, leakyrelu=False):
        super().__init__()
        assert len(layers) == 4, 'Can have only four stages'
        self.inplanes = 64
        # Standard ResNet stem: 7x7 stride-2 conv followed by a stride-2 max pool.
        self.start = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(self.inplanes),
            nn.LeakyReLU(0.2, inplace=True) if leakyrelu else nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        trunk = []
        in_channels = self.inplanes
        for i in range(4):
            for j in range(layers[i]):
                # Halve the resolution at the first block of stages 2-4;
                # the stem has already downsampled before stage 1.
                half_res = i >= 1 and j == 0
                trunk.append(block(in_channels, dims[i], half_res, leakyrelu))
                in_channels = dims[i]

        if flatten:
            # AvgPool2d(7) assumes a 224x224 input, which leaves a 7x7 map here.
            trunk.append(nn.AvgPool2d(7))
            trunk.append(Flatten())

        if num_classes is not None:
            if classifier_type == "linear":
                trunk.append(nn.Linear(in_channels, num_classes))
            elif classifier_type == "distlinear":
                # Placeholder: no head is appended for the distance-based
                # classifier; see the reference sketch below the model builders.
                pass
            else:
                raise ValueError(f"invalid classifier_type: {classifier_type}")

        self.trunk = nn.Sequential(*trunk)
        self.apply(init_layer)

    def forward(self, x):
        return self.trunk(self.start(x))


@MODEL.register_module()
def resnet10(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
    return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512],
                  num_classes, classifier_type, flatten, leakyrelu)


@MODEL.register_module()
def resnet18(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
    return ResNet(SimpleBlock, [2, 2, 2, 2], [64, 128, 256, 512],
                  num_classes, classifier_type, flatten, leakyrelu)


@MODEL.register_module()
def resnet34(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
    return ResNet(SimpleBlock, [3, 4, 6, 3], [64, 128, 256, 512],
                  num_classes, classifier_type, flatten, leakyrelu)
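
# ---------------------------------------------------------------------------
# Reference sketches (not part of the model definitions above).
#
# 1) "distlinear" head. The builder deliberately appends no module for
#    classifier_type="distlinear". In the few-shot-learning baselines that use
#    this name, it usually denotes a cosine-similarity classifier over
#    L2-normalized features and weights. The class below is one such sketch
#    under that assumption; it is NOT necessarily the head this repo attaches
#    elsewhere, and the scale factor is a common but arbitrary choice.
class CosineClassifier(nn.Module):  # hypothetical, for illustration only
    def __init__(self, in_features, num_classes, scale=10.0):
        super().__init__()
        # A bias-free Linear just stores the class-weight matrix.
        self.fc = nn.Linear(in_features, num_classes, bias=False)
        self.scale = scale

    def forward(self, x):
        # Scaled cosine similarity: normalize both features and weights,
        # then take their inner products. The scale keeps the resulting
        # softmax from being overly flat.
        x = nn.functional.normalize(x, dim=-1)
        w = nn.functional.normalize(self.fc.weight, dim=-1)
        return self.scale * nn.functional.linear(x, w)


# 2) Minimal smoke test. Because of the relative registry import at the top,
#    run this as a module from the package root, e.g. `python -m <package>.resnet`
#    (the package name is an assumption). A 224x224 input is assumed so that
#    AvgPool2d(7) reduces the final 7x7 feature map to 1x1.
if __name__ == "__main__":
    import torch

    model = resnet18(num_classes=10)
    x = torch.randn(2, 3, 224, 224)
    print(model(x).shape)  # expected: torch.Size([2, 10])

    # Feature extractor plus the hypothetical cosine head:
    backbone = resnet18(num_classes=None)  # trunk ends at Flatten -> 512-d features
    head = CosineClassifier(512, 10)
    print(head(backbone(x)).shape)  # expected: torch.Size([2, 10])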