import math

import torch.nn as nn

from .registry import MODEL


# --- Gaussian initialization ---
def init_layer(l):
    # He (Kaiming) initialization: the fan is computed from the output
    # channels (fan-out), which is what the formula below uses.
    if isinstance(l, nn.Conv2d):
        n = l.kernel_size[0] * l.kernel_size[1] * l.out_channels
        l.weight.data.normal_(0, math.sqrt(2.0 / float(n)))
    elif isinstance(l, nn.BatchNorm2d):
        # Start BatchNorm as the identity transform.
        l.weight.data.fill_(1)
        l.bias.data.fill_(0)
    elif isinstance(l, nn.Linear):
        # Only the bias is reset; the weight keeps PyTorch's default init.
        l.bias.data.fill_(0)


class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # Collapse everything after the batch dimension
        # (equivalent to nn.Flatten() in current PyTorch).
        return x.view(x.size(0), -1)


class SimpleBlock(nn.Module):
    # Basic two-conv residual block. `half_res=True` halves the spatial
    # resolution by giving the first conv (and the shortcut) stride 2.
    def __init__(self, in_channels, out_channels, half_res, leakyrelu=False):
        super(SimpleBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True) if not leakyrelu else nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.relu = nn.ReLU(inplace=True) if not leakyrelu else nn.LeakyReLU(0.2, inplace=True)
        if in_channels != out_channels:
            # 1x1 projection shortcut to match channel count (and stride).
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, 2 if half_res else 1, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        o = self.block(x)
        return self.relu(o + self.shortcut(x))


class ResNet(nn.Module):
    def __init__(self, block, layers, dims, num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
        super().__init__()
        assert len(layers) == 4, 'Can have only four stages'
        self.inplanes = 64

        # Standard ResNet stem: 7x7 stride-2 conv, then a stride-2 max pool.
        self.start = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(self.inplanes),
            nn.ReLU(inplace=True) if not leakyrelu else nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        trunk = []
        in_channels = self.inplanes
        for i in range(4):
            for j in range(layers[i]):
                # The first block of every stage after the first halves the resolution.
                half_res = i >= 1 and j == 0
                trunk.append(block(in_channels, dims[i], half_res, leakyrelu))
                in_channels = dims[i]
        if flatten:
            # Assumes 224x224 inputs, which leave a 7x7 feature map here.
            trunk.append(nn.AvgPool2d(7))
            trunk.append(Flatten())
        if num_classes is not None:
            if classifier_type == "linear":
                trunk.append(nn.Linear(in_channels, num_classes))
            elif classifier_type == "distlinear":
                # Placeholder: no distance-based head is appended here;
                # see the sketch after this class.
                pass
            else:
                raise ValueError(f"invalid classifier_type: {classifier_type}")
        self.trunk = nn.Sequential(*trunk)
        self.apply(init_layer)

    def forward(self, x):
        return self.trunk(self.start(x))
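

# The "distlinear" branch above appends nothing. A minimal sketch of a
# distance-based (cosine-similarity) head in that spirit is given below;
# the class name `CosineClassifier` and the fixed scale factor are
# assumptions for illustration, not part of the original file.
class CosineClassifier(nn.Module):
    def __init__(self, in_features, num_classes, scale=10.0):
        super().__init__()
        # Weight-only linear layer; its rows act as class prototypes.
        self.fc = nn.Linear(in_features, num_classes, bias=False)
        self.scale = scale  # assumed temperature on the cosine scores

    def forward(self, x):
        # Score = scaled cosine similarity between normalized features
        # and normalized class weights.
        x = nn.functional.normalize(x, dim=-1)
        w = nn.functional.normalize(self.fc.weight, dim=-1)
        return self.scale * x.matmul(w.t())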


@MODEL.register_module()
def resnet10(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
    return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], num_classes, classifier_type, flatten, leakyrelu)


@MODEL.register_module()
def resnet18(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
    return ResNet(SimpleBlock, [2, 2, 2, 2], [64, 128, 256, 512], num_classes, classifier_type, flatten, leakyrelu)


@MODEL.register_module()
def resnet34(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
    return ResNet(SimpleBlock, [3, 4, 6, 3], [64, 128, 256, 512], num_classes, classifier_type, flatten, leakyrelu)
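

# Quick smoke test: a minimal sketch, not part of the original module. It
# assumes 224x224 RGB inputs (so the 7x7 average pool matches the feature
# map) and must be run with the package context intact, since this file
# uses the relative import `from .registry import MODEL`.
if __name__ == "__main__":
    import torch

    model = resnet18(num_classes=10)
    x = torch.randn(2, 3, 224, 224)
    print(model(x).shape)  # expected: torch.Size([2, 10])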