import torch
import torch.nn as nn

from .residual_generator import ResidualBlock
from model.registry import MODEL


class RhoClipper:
    """Callable that clamps a module's ``rho`` tensor into ``[clip_min, clip_max]``.

    Modules without a ``rho`` attribute are left untouched, so an instance can
    be handed to ``nn.Module.apply`` to visit every sub-module of a network.
    """

    def __init__(self, clip_min, clip_max):
        self.clip_min = clip_min
        self.clip_max = clip_max
        assert clip_min < clip_max, "clip_min must be smaller than clip_max"

    def __call__(self, module):
        if hasattr(module, 'rho'):
            # Written through ``.data`` so the clamp is not tracked by autograd.
            module.rho.data = module.rho.data.clamp(self.clip_min, self.clip_max)


def _class_activation_attention(x, gap_fc, gmp_fc):
    """Compute CAM logits and the weight-scaled feature maps for ``x``.

    Args:
        x: 4-D feature map; dimension 1 must match ``in_features`` of both
            linear layers.
        gap_fc: single-output linear layer fed with the global-average-pooled
            features.
        gmp_fc: single-output linear layer fed with the global-max-pooled
            features.

    Returns:
        ``(cam_logit, features)``: ``cam_logit`` concatenates both logits along
        dim 1; ``features`` concatenates ``x`` scaled channel-wise by each
        layer's weight along dim 1 (twice the channels of ``x``).
    """
    batch_size = x.shape[0]

    # Each layer's ``weight`` is read only after the layer has been called,
    # exactly as in the original statement order; this matters for layers
    # wrapped by ``nn.utils.spectral_norm`` (see its documentation).
    gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
    gap_logit = gap_fc(gap.view(batch_size, -1))
    gap = x * gap_fc.weight.unsqueeze(2).unsqueeze(3)

    gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
    gmp_logit = gmp_fc(gmp.view(batch_size, -1))
    gmp = x * gmp_fc.weight.unsqueeze(2).unsqueeze(3)

    cam_logit = torch.cat([gap_logit, gmp_logit], 1)
    return cam_logit, torch.cat([gap, gmp], 1)


@MODEL.register_module("UGATIT-Generator")
class Generator(nn.Module):
    """Encoder / CAM / AdaILN-bottleneck / decoder generator.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
        base_channels: channels produced by the first convolution; doubled by
            each of the two down-sampling convolutions.
        num_blocks: number of residual blocks in the encoder and, separately,
            of ``ResnetAdaILNBlock`` modules in the bottleneck.
        img_size: spatial size of the (square) input; only used to size the
            first fully connected layer when ``light`` is false.
        light: if true, features are globally average-pooled before the fully
            connected layers, making them independent of ``img_size``.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6,
                 img_size=256, light=False):
        assert num_blocks >= 0
        super().__init__()
        self.input_channels = in_channels
        self.output_channels = out_channels
        self.base_channels = base_channels
        self.num_blocks = num_blocks
        self.img_size = img_size
        self.light = light

        down_encoder = [
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3,
                      padding_mode="reflect", bias=False),
            nn.InstanceNorm2d(base_channels),
            nn.ReLU(True),
        ]

        n_down_sampling = 2
        for i in range(n_down_sampling):
            mult = 2 ** i
            down_encoder += [
                nn.Conv2d(base_channels * mult, base_channels * mult * 2, kernel_size=3,
                          stride=2, padding=1, bias=False, padding_mode="reflect"),
                nn.InstanceNorm2d(base_channels * mult * 2),
                nn.ReLU(True),
            ]

        # Down-Sampling Bottleneck
        mult = 2 ** n_down_sampling
        bottleneck_channels = base_channels * mult
        for _ in range(num_blocks):
            # TODO: change ResnetBlock to ResidualBlock, check use_bias param
            down_encoder += [ResidualBlock(bottleneck_channels, use_bias=False)]
        self.down_encoder = nn.Sequential(*down_encoder)

        # Class Activation Map
        self.gap_fc = nn.Linear(bottleneck_channels, 1, bias=False)
        self.gmp_fc = nn.Linear(bottleneck_channels, 1, bias=False)
        self.conv1x1 = nn.Conv2d(bottleneck_channels * 2, bottleneck_channels,
                                 kernel_size=1, stride=1, bias=True)
        self.relu = nn.ReLU(True)

        # Gamma, Beta block
        if self.light:
            fc_in_features = bottleneck_channels
        else:
            # Spatial size after the stride-2, kernel-3, padding-1 convolutions:
            # each maps H to (H + 1) // 2.  For img_size divisible by 4 this is
            # img_size // 4, the same value the previous
            # ``img_size // mult * img_size // mult`` expression produced; for
            # other sizes that expression (evaluated left to right) did not
            # match the flattened feature map.
            # NOTE(review): assumes ResidualBlock preserves spatial size.
            feature_size = img_size
            for _ in range(n_down_sampling):
                feature_size = (feature_size + 1) // 2
            fc_in_features = feature_size * feature_size * bottleneck_channels
        self.fc = nn.Sequential(
            nn.Linear(fc_in_features, bottleneck_channels, bias=False),
            nn.ReLU(True),
            nn.Linear(bottleneck_channels, bottleneck_channels, bias=False),
            nn.ReLU(True),
        )
        self.gamma = nn.Linear(bottleneck_channels, bottleneck_channels, bias=False)
        self.beta = nn.Linear(bottleneck_channels, bottleneck_channels, bias=False)

        # Up-Sampling Bottleneck
        self.up_bottleneck = nn.ModuleList(
            [ResnetAdaILNBlock(bottleneck_channels, use_bias=False) for _ in range(num_blocks)])

        # Up-Sampling
        up_decoder = []
        for i in range(n_down_sampling):
            mult = 2 ** (n_down_sampling - i)
            up_decoder += [
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(base_channels * mult, base_channels * mult // 2, kernel_size=3,
                          stride=1, padding=1, padding_mode="reflect", bias=False),
                ILN(base_channels * mult // 2),
                nn.ReLU(True),
            ]

        up_decoder += [
            nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3,
                      padding_mode="reflect", bias=False),
            nn.Tanh(),
        ]
        self.up_decoder = nn.Sequential(*up_decoder)

    def forward(self, x):
        """Run the generator.

        Returns:
            ``(image, cam_logit, heatmap)``: the decoder output, the
            concatenated average/max-pool CAM logits, and the channel-wise sum
            of the attention features (one channel, ``keepdim=True``).
        """
        x = self.down_encoder(x)

        cam_logit, x = _class_activation_attention(x, self.gap_fc, self.gmp_fc)
        x = self.relu(self.conv1x1(x))

        heatmap = torch.sum(x, dim=1, keepdim=True)

        if self.light:
            x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1)
            x_ = self.fc(x_.view(x_.shape[0], -1))
        else:
            x_ = self.fc(x.view(x.shape[0], -1))
        gamma, beta = self.gamma(x_), self.beta(x_)

        for ub in self.up_bottleneck:
            x = ub(x, gamma, beta)

        x = self.up_decoder(x)
        return x, cam_logit, heatmap


class ResnetAdaILNBlock(nn.Module):
    """Residual block: conv -> AdaILN -> ReLU -> conv -> AdaILN, plus skip.

    Args:
        dim: number of input and output channels.
        use_bias: whether the two 3x3 convolutions have a bias term.
    """

    def __init__(self, dim, use_bias):
        super().__init__()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias,
                               padding_mode="reflect")
        self.norm1 = AdaILN(dim)
        self.relu1 = nn.ReLU(True)

        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias,
                               padding_mode="reflect")
        self.norm2 = AdaILN(dim)

    def forward(self, x, gamma, beta):
        """Apply the block; ``gamma`` and ``beta`` are passed to both AdaILN layers."""
        out = self.conv1(x)
        out = self.norm1(out, gamma, beta)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.norm2(out, gamma, beta)

        return out + x


def instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
    """Blend instance- and layer-normalised ``x``, then scale and shift.

    Args:
        x: 4-D tensor; statistics are taken over dims ``[2, 3]`` (per sample and
            channel) and over dims ``[1, 2, 3]`` (per sample).
        gamma: 2-D scale; expanded with two trailing singleton dims before use.
        beta: 2-D shift; expanded the same way as ``gamma``.
        rho: mixing weight, expanded along dim 0 to the batch size of ``x``;
            ``rho`` weights the instance-normalised term and ``1 - rho`` the
            layer-normalised term.
        eps: added to the variances before the square root.

    Returns:
        Tensor with the same shape as ``x``.
    """
    in_mean = torch.mean(x, dim=[2, 3], keepdim=True)
    in_var = torch.var(x, dim=[2, 3], keepdim=True)
    out_in = (x - in_mean) / torch.sqrt(in_var + eps)

    ln_mean = torch.mean(x, dim=[1, 2, 3], keepdim=True)
    ln_var = torch.var(x, dim=[1, 2, 3], keepdim=True)
    out_ln = (x - ln_mean) / torch.sqrt(ln_var + eps)

    # Expanded once and reused for both terms of the blend.
    rho_batch = rho.expand(x.shape[0], -1, -1, -1)
    out = rho_batch * out_in + (1 - rho_batch) * out_ln
    out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
    return out


class AdaILN(nn.Module):
    """Instance-layer normalisation whose scale and shift are supplied per call.

    Args:
        num_features: number of channels; ``rho`` has shape
            ``(1, num_features, 1, 1)``.
        eps: numerical-stability term forwarded to the normalisation.
        default_rho: initial value of every entry of ``rho``.
    """

    def __init__(self, num_features, eps=1e-5, default_rho=0.9):
        super().__init__()
        self.eps = eps
        self.rho = nn.Parameter(torch.empty(1, num_features, 1, 1))
        nn.init.constant_(self.rho, default_rho)

    def forward(self, x, gamma, beta):
        """Normalise ``x`` using the caller-provided ``gamma`` and ``beta``."""
        return instance_layer_normalization(x, gamma, beta, self.rho, self.eps)


class ILN(nn.Module):
    """Instance-layer normalisation with its own learnable scale and shift.

    ``rho`` starts at 0, ``gamma`` at 1 and ``beta`` at 0.

    Args:
        num_features: number of channels.
        eps: numerical-stability term forwarded to the normalisation.
    """

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.rho = nn.Parameter(torch.empty(1, num_features, 1, 1))
        self.gamma = nn.Parameter(torch.empty(1, num_features))
        self.beta = nn.Parameter(torch.empty(1, num_features))
        nn.init.constant_(self.rho, 0.0)
        nn.init.constant_(self.gamma, 1.0)
        nn.init.constant_(self.beta, 0.0)

    def forward(self, x):
        """Normalise ``x``; ``gamma``/``beta`` are expanded to the batch size."""
        batch_size = x.shape[0]
        return instance_layer_normalization(
            x,
            self.gamma.expand(batch_size, -1),
            self.beta.expand(batch_size, -1),
            self.rho,
            self.eps,
        )


@MODEL.register_module("UGATIT-Discriminator")
class Discriminator(nn.Module):
    """Spectral-normalised convolutional discriminator with a CAM branch.

    Args:
        in_channels: channels of the input image.
        base_channels: channels produced by the first conv block; doubled by
            every following block.
        num_blocks: controls encoder depth; the encoder holds
            ``num_blocks - 2`` stride-2 blocks followed by one stride-1 block.

    Raises:
        ValueError: if ``num_blocks`` is smaller than 3.
    """

    def __init__(self, in_channels, base_channels=64, num_blocks=5):
        # Smaller values make ``2 ** (num_blocks - 3)`` fractional, which used
        # to surface as an opaque error from inside ``nn.Conv2d``.
        if num_blocks < 3:
            raise ValueError(f"num_blocks must be at least 3, got {num_blocks}")
        super().__init__()
        encoder = [self.build_conv_block(in_channels, base_channels)]

        for i in range(1, num_blocks - 2):
            mult = 2 ** (i - 1)
            encoder.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2))

        mult = 2 ** (num_blocks - 2 - 1)
        encoder.append(
            self.build_conv_block(base_channels * mult, base_channels * mult * 2, stride=1))

        self.encoder = nn.Sequential(*encoder)

        # Class Activation Map
        mult = 2 ** (num_blocks - 2)
        self.gap_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False))
        self.gmp_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False))
        self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1,
                                 stride=1, bias=True)
        self.leaky_relu = nn.LeakyReLU(0.2, True)

        self.conv = nn.utils.spectral_norm(
            nn.Conv2d(base_channels * mult, 1, kernel_size=4, stride=1, padding=1, bias=False,
                      padding_mode="reflect"))

    @staticmethod
    def build_conv_block(in_channels, out_channels, kernel_size=4, stride=2, padding=1,
                         padding_mode="reflect"):
        """Return a spectral-normalised ``Conv2d`` (with bias) followed by LeakyReLU(0.2)."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                         bias=True, padding=padding, padding_mode=padding_mode)
        return nn.Sequential(
            nn.utils.spectral_norm(conv),
            nn.LeakyReLU(0.2, True),
        )

    def forward(self, x, return_heatmap=False):
        """Run the discriminator.

        Args:
            x: input image batch.
            return_heatmap: if true, also return the channel-wise sum of the
                attention features.

        Returns:
            ``(logits, cam_logit)`` or, with ``return_heatmap``,
            ``(logits, cam_logit, heatmap)``.
        """
        x = self.encoder(x)

        cam_logit, x = _class_activation_attention(x, self.gap_fc, self.gmp_fc)
        x = self.leaky_relu(self.conv1x1(x))

        if return_heatmap:
            heatmap = torch.sum(x, dim=1, keepdim=True)
            return self.conv(x), cam_logit, heatmap
        return self.conv(x), cam_logit