238 lines
9.6 KiB
Python
238 lines
9.6 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from .residual_generator import ResidualBlock
|
|
from model.registry import MODEL
|
|
|
|
|
|
class RhoClipper(object):
    """Callable that clamps a module's ``rho`` tensor into a closed interval.

    An instance is called with a single module; modules without a ``rho``
    attribute are left untouched.
    """

    def __init__(self, clip_min, clip_max):
        """Store the clamp bounds; ``clip_min`` must be strictly below ``clip_max``."""
        self.clip_min = clip_min
        self.clip_max = clip_max
        assert clip_min < clip_max

    def __call__(self, module):
        """Clamp ``module.rho`` to ``[clip_min, clip_max]`` when it exists."""
        if not hasattr(module, 'rho'):
            return
        # Rebind ``.data`` so the clamp is applied outside of autograd tracking.
        module.rho.data = module.rho.data.clamp(self.clip_min, self.clip_max)
|
|
|
|
|
|
@MODEL.register_module("UGATIT-Generator")
class Generator(nn.Module):
    """Encoder / CAM-attention / AdaILN-decoder generator.

    Pipeline built by this class:

    1. ``down_encoder``: 7x7 conv, two stride-2 3x3 convs (each followed by
       instance norm and ReLU), then ``num_blocks`` ``ResidualBlock`` modules.
    2. Class-activation-map attention: the encoder features are re-weighted by
       the weights of ``gap_fc`` and ``gmp_fc``, concatenated and fused by
       ``conv1x1``.
    3. ``fc`` followed by ``gamma`` / ``beta`` produces the per-channel affine
       parameters consumed by every ``ResnetAdaILNBlock`` in ``up_bottleneck``.
    4. ``up_decoder``: two nearest-neighbour x2 upsampling stages with ``ILN``
       and a final 7x7 conv with ``tanh``.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the generated image.
        base_channels: Channel width after the first convolution.
        num_blocks: Number of residual blocks in the encoder and of AdaILN
            residual blocks in the bottleneck. Must be non-negative.
        img_size: Spatial size of the (assumed square) input image. Only used
            when ``light`` is False, to size the first fully connected layer.
        light: When True, the attention features are globally average pooled
            before ``fc``, so ``fc`` does not depend on ``img_size``.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False):
        assert (num_blocks >= 0)
        super(Generator, self).__init__()
        self.input_channels = in_channels
        self.output_channels = out_channels
        self.base_channels = base_channels
        self.num_blocks = num_blocks
        self.img_size = img_size
        self.light = light

        down_encoder = [
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3,
                      padding_mode="reflect", bias=False),
            nn.InstanceNorm2d(base_channels),
            nn.ReLU(True),
        ]

        n_down_sampling = 2
        for i in range(n_down_sampling):
            mult = 2 ** i
            down_encoder += [
                nn.Conv2d(base_channels * mult, base_channels * mult * 2, kernel_size=3, stride=2,
                          padding=1, bias=False, padding_mode="reflect"),
                nn.InstanceNorm2d(base_channels * mult * 2),
                nn.ReLU(True),
            ]

        # Down-Sampling Bottleneck
        mult = 2 ** n_down_sampling
        for _ in range(num_blocks):
            # TODO: change ResnetBlock to ResidualBlock, check use_bias param
            down_encoder += [ResidualBlock(base_channels * mult, use_bias=False)]
        self.down_encoder = nn.Sequential(*down_encoder)

        # Class Activation Map
        self.gap_fc = nn.Linear(base_channels * mult, 1, bias=False)
        self.gmp_fc = nn.Linear(base_channels * mult, 1, bias=False)
        self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True)
        self.relu = nn.ReLU(True)

        # Gamma, Beta block
        if self.light:
            fc_in_features = base_channels * mult
        else:
            # Each stride-2 conv (kernel 3, padding 1) maps a side of length s
            # to ceil(s / 2). Deriving the size this way equals img_size // 4
            # for multiples of 4 and stays correct for other sizes, where the
            # previous ``img_size // mult * img_size // mult`` expression
            # produced a mismatching ``in_features``.
            # NOTE(review): assumes ResidualBlock keeps the spatial size.
            feature_size = img_size
            for _ in range(n_down_sampling):
                feature_size = (feature_size + 1) // 2
            fc_in_features = feature_size * feature_size * base_channels * mult
        fc = [
            nn.Linear(fc_in_features, base_channels * mult, bias=False),
            nn.ReLU(True),
            nn.Linear(base_channels * mult, base_channels * mult, bias=False),
            nn.ReLU(True),
        ]
        self.fc = nn.Sequential(*fc)

        self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
        self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)

        # Up-Sampling Bottleneck
        self.up_bottleneck = nn.ModuleList(
            [ResnetAdaILNBlock(base_channels * mult, use_bias=False) for _ in range(num_blocks)])

        # Up-Sampling
        up_decoder = []
        for i in range(n_down_sampling):
            mult = 2 ** (n_down_sampling - i)
            up_decoder += [
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(base_channels * mult, base_channels * mult // 2, kernel_size=3, stride=1,
                          padding=1, padding_mode="reflect", bias=False),
                ILN(base_channels * mult // 2),
                nn.ReLU(True),
            ]

        up_decoder += [
            nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3,
                      padding_mode="reflect", bias=False),
            nn.Tanh(),
        ]
        self.up_decoder = nn.Sequential(*up_decoder)

    def forward(self, x):
        """Translate a batch of images.

        Args:
            x: 4-D input batch (batch, channels, height, width).

        Returns:
            A tuple ``(image, cam_logit, heatmap)``: the decoded image, the
            concatenated average-pool / max-pool CAM logits (two columns per
            sample), and the channel-summed attention feature map.
        """
        x = self.down_encoder(x)

        # Average-pool branch: logit from pooled features, then re-weight the
        # feature map channel-wise with the same linear weights.
        gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
        gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
        gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)

        # Max-pool branch, same scheme.
        gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
        gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)

        # Fuse both attention branches back to the bottleneck width.
        x = torch.cat([gap, gmp], 1)
        x = self.relu(self.conv1x1(x))

        heatmap = torch.sum(x, dim=1, keepdim=True)

        if self.light:
            x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1)
            x_ = self.fc(x_.view(x_.shape[0], -1))
        else:
            x_ = self.fc(x.view(x.shape[0], -1))
        gamma, beta = self.gamma(x_), self.beta(x_)

        for ub in self.up_bottleneck:
            x = ub(x, gamma, beta)

        x = self.up_decoder(x)
        return x, cam_logit, heatmap
|
|
|
|
|
|
class ResnetAdaILNBlock(nn.Module):
    """Residual block of two 3x3 convolutions, each normalised by ``AdaILN``.

    Args:
        dim: Number of input and output channels.
        use_bias: Whether the two convolutions carry a bias term.
    """

    def __init__(self, dim, use_bias):
        super().__init__()
        # Both convolutions share the same size-preserving configuration.
        conv_options = dict(kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect")

        self.conv1 = nn.Conv2d(dim, dim, **conv_options)
        self.norm1 = AdaILN(dim)
        self.relu1 = nn.ReLU(True)

        self.conv2 = nn.Conv2d(dim, dim, **conv_options)
        self.norm2 = AdaILN(dim)

    def forward(self, x, gamma, beta):
        """Return ``x`` plus the conv/AdaILN branch modulated by ``gamma`` and ``beta``."""
        branch = self.relu1(self.norm1(self.conv1(x), gamma, beta))
        branch = self.norm2(self.conv2(branch), gamma, beta)
        return branch + x
|
|
|
|
|
|
def instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
    """Blend instance- and layer-normalised ``x``, then apply a channel affine.

    Args:
        x: 4-D tensor; statistics are taken over dims ``[2, 3]`` (instance
            norm) and ``[1, 2, 3]`` (layer norm).
        gamma: 2-D scale; two trailing singleton dims are appended before use.
        beta: 2-D shift; two trailing singleton dims are appended before use.
        rho: 4-D mixing weight, expanded along dim 0 to the batch size of
            ``x``. ``rho`` weights the instance-norm term and ``1 - rho`` the
            layer-norm term.
        eps: Constant added to the variance before the square root.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Instance-norm statistics: per sample and per channel.
    inst_mean = torch.mean(x, dim=[2, 3], keepdim=True)
    inst_var = torch.var(x, dim=[2, 3], keepdim=True)
    normed_instance = (x - inst_mean) / torch.sqrt(inst_var + eps)

    # Layer-norm statistics: per sample over every channel and position.
    layer_mean = torch.mean(x, dim=[1, 2, 3], keepdim=True)
    layer_var = torch.var(x, dim=[1, 2, 3], keepdim=True)
    normed_layer = (x - layer_mean) / torch.sqrt(layer_var + eps)

    mix = rho.expand(x.shape[0], -1, -1, -1)
    blended = mix * normed_instance + (1 - mix) * normed_layer

    scale = gamma.unsqueeze(2).unsqueeze(3)
    shift = beta.unsqueeze(2).unsqueeze(3)
    return blended * scale + shift
|
|
|
|
|
|
class AdaILN(nn.Module):
    """Instance/layer normalisation whose scale and shift are passed in per call.

    Args:
        num_features: Number of channels; ``rho`` has shape
            ``(1, num_features, 1, 1)``.
        eps: Constant forwarded to ``instance_layer_normalization``.
        default_rho: Initial value of every entry of ``rho``.
    """

    def __init__(self, num_features, eps=1e-5, default_rho=0.9):
        super().__init__()
        self.eps = eps
        # Fill the tensor before wrapping it, so the parameter starts at default_rho.
        initial_rho = torch.Tensor(1, num_features, 1, 1).fill_(default_rho)
        self.rho = nn.Parameter(initial_rho)

    def forward(self, x, gamma, beta):
        """Normalise ``x`` with the learned ``rho`` and the given ``gamma`` / ``beta``."""
        return instance_layer_normalization(x, gamma, beta, self.rho, self.eps)
|
|
|
|
|
|
class ILN(nn.Module):
    """Instance/layer normalisation with its own learnable scale and shift.

    Args:
        num_features: Number of channels. ``rho`` has shape
            ``(1, num_features, 1, 1)``; ``gamma`` and ``beta`` have shape
            ``(1, num_features)``.
        eps: Constant forwarded to ``instance_layer_normalization``.
    """

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Initial values: rho = 0, gamma = 1, beta = 0.
        self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1).fill_(0.0))
        self.gamma = nn.Parameter(torch.Tensor(1, num_features).fill_(1.0))
        self.beta = nn.Parameter(torch.Tensor(1, num_features).fill_(0.0))

    def forward(self, x):
        """Normalise ``x`` using the module's own ``gamma``, ``beta`` and ``rho``."""
        batch = x.shape[0]
        # The helper expects one gamma/beta row per sample.
        scale = self.gamma.expand(batch, -1)
        shift = self.beta.expand(batch, -1)
        return instance_layer_normalization(x, scale, shift, self.rho, self.eps)
|
|
|
|
|
|
@MODEL.register_module("UGATIT-Discriminator")
class Discriminator(nn.Module):
    """Spectral-normalised convolutional discriminator with CAM attention.

    Args:
        in_channels: Number of channels of the input image.
        base_channels: Output width of the first convolution block.
        num_blocks: Controls the encoder depth: one initial block,
            ``num_blocks - 3`` stride-2 blocks that double the width, and a
            final stride-1 block that doubles it once more.
    """

    def __init__(self, in_channels, base_channels=64, num_blocks=5):
        super().__init__()

        stages = [self.build_conv_block(in_channels, base_channels)]
        # Stride-2 stages; stage ``idx`` maps base * 2**(idx - 1) -> base * 2**idx channels.
        stages.extend(
            self.build_conv_block(base_channels * 2 ** (idx - 1), base_channels * 2 ** (idx - 1) * 2)
            for idx in range(1, num_blocks - 2)
        )
        # Final stage keeps the resolution (stride 1) while doubling the width.
        last_scale = 2 ** (num_blocks - 2 - 1)
        stages.append(
            self.build_conv_block(base_channels * last_scale, base_channels * last_scale * 2, stride=1))
        self.encoder = nn.Sequential(*stages)

        # Class Activation Map
        cam_channels = base_channels * 2 ** (num_blocks - 2)
        self.gap_fc = nn.utils.spectral_norm(nn.Linear(cam_channels, 1, bias=False))
        self.gmp_fc = nn.utils.spectral_norm(nn.Linear(cam_channels, 1, bias=False))
        self.conv1x1 = nn.Conv2d(cam_channels * 2, cam_channels, kernel_size=1, stride=1, bias=True)
        self.leaky_relu = nn.LeakyReLU(0.2, True)

        self.conv = nn.utils.spectral_norm(
            nn.Conv2d(cam_channels, 1, kernel_size=4, stride=1, padding=1, bias=False, padding_mode="reflect"))

    @staticmethod
    def build_conv_block(in_channels, out_channels, kernel_size=4, stride=2, padding=1, padding_mode="reflect"):
        """Return a spectral-normalised conv followed by LeakyReLU(0.2)."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                         bias=True, padding=padding, padding_mode=padding_mode)
        return nn.Sequential(nn.utils.spectral_norm(conv), nn.LeakyReLU(0.2, True))

    def forward(self, x, return_heatmap=False):
        """Score a batch of images.

        Args:
            x: 4-D input batch (batch, channels, height, width).
            return_heatmap: When True, also return the channel-summed
                attention feature map.

        Returns:
            ``(logits, cam_logit)`` or, with ``return_heatmap``,
            ``(logits, cam_logit, heatmap)``.
        """
        features = self.encoder(x)
        n = features.size(0)

        # Average-pool branch: B x C x 1 x 1, avg of per channel
        pooled_avg = torch.nn.functional.adaptive_avg_pool2d(features, 1)
        gap_logit = self.gap_fc(pooled_avg.view(n, -1))
        gap = features * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)

        # Max-pool branch, same scheme.
        pooled_max = torch.nn.functional.adaptive_max_pool2d(features, 1)
        gmp_logit = self.gmp_fc(pooled_max.view(n, -1))
        gmp = features * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)

        attended = torch.cat([gap, gmp], 1)
        attended = self.leaky_relu(self.conv1x1(attended))

        if not return_heatmap:
            return self.conv(attended), cam_logit

        heatmap = torch.sum(attended, dim=1, keepdim=True)
        return self.conv(attended), cam_logit, heatmap
|