import torch
import torch.nn as nn
import torch.nn.functional as F

from model import MODEL
from model.base.module import Conv2dBlock, LinearBlock
from model.image_translation.CycleGAN import Encoder, Decoder


class CAMClassifier(nn.Module):
    """CAM (Class Activation Map) attention head used by U-GAT-IT.

    Re-weights the input feature map with the weights of two auxiliary
    classifiers (a global-average-pooling branch and a global-max-pooling
    branch), fuses both branches back to ``in_channels`` with a 1x1
    convolution, and also returns the concatenated auxiliary logits for the
    CAM loss.
    """

    def __init__(self, in_channels, activation_type="ReLU"):
        super(CAMClassifier, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.avg_fc = nn.Linear(in_channels, 1, bias=False)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.max_fc = nn.Linear(in_channels, 1, bias=False)
        self.fusion_conv = Conv2dBlock(in_channels * 2, in_channels, kernel_size=1, stride=1, bias=True,
                                       activation_type=activation_type, norm_type="NONE")

    def forward(self, x):
        # Auxiliary classification logits from the pooled features.
        avg_logit = self.avg_fc(self.avg_pool(x).view(x.size(0), -1))
        max_logit = self.max_fc(self.max_pool(x).view(x.size(0), -1))
        # Channel-wise attention: scale the feature map by the classifier
        # weights, then fuse the two branches back to `in_channels`.
        avg_attention = x * self.avg_fc.weight.unsqueeze(2).unsqueeze(3)
        max_attention = x * self.max_fc.weight.unsqueeze(2).unsqueeze(3)
        x = self.fusion_conv(torch.cat([avg_attention, max_attention], dim=1))
        return x, torch.cat([avg_logit, max_logit], dim=1)


@MODEL.register_module("UGATIT-Generator")
class Generator(nn.Module):
    """U-GAT-IT generator.

    A CycleGAN-style encoder is followed by a CAM attention head; the decoder
    uses AdaILN residual blocks whose gamma/beta are predicted from the
    attended features. With ``light=True`` the gamma/beta MLP operates on a
    globally pooled feature vector instead of the full flattened feature map,
    which greatly reduces memory usage.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False,
                 activation_type="ReLU", norm_type="IN", padding_mode='reflect', pre_activation=False):
        super(Generator, self).__init__()
        self.light = light
        n_down_sampling = 2
        self.encoder = Encoder(in_channels, base_channels, n_down_sampling, num_blocks,
                               padding_mode=padding_mode, activation_type=activation_type,
                               down_conv_norm_type=norm_type, down_conv_kernel_size=3,
                               res_norm_type=norm_type, pre_activation=pre_activation)
        mult = 2 ** n_down_sampling
        self.cam = CAMClassifier(base_channels * mult, activation_type)

        # Gamma, Beta block: predicts the AdaILN parameters of the decoder.
        if self.light:
            self.fc = nn.Sequential(
                LinearBlock(base_channels * mult, base_channels * mult, False, "ReLU", "NONE"),
                LinearBlock(base_channels * mult, base_channels * mult, False, "ReLU", "NONE")
            )
        else:
            self.fc = nn.Sequential(
                LinearBlock((img_size // mult) * (img_size // mult) * base_channels * mult,
                            base_channels * mult, False, "ReLU", "NONE"),
                LinearBlock(base_channels * mult, base_channels * mult, False, "ReLU", "NONE")
            )
        self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
        self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)

        self.decoder = Decoder(
            base_channels * mult, out_channels, n_down_sampling, num_blocks,
            activation_type=activation_type, padding_mode=padding_mode,
            up_conv_kernel_size=3, up_conv_norm_type="ILN", res_norm_type="AdaILN",
            pre_activation=pre_activation
        )

    def forward(self, x):
        x = self.encoder(x)
        x, cam_logit = self.cam(x)
        # Attention heatmap over the CAM-weighted features (for visualization).
        heatmap = torch.sum(x, dim=1, keepdim=True)
        if self.light:
            x_ = F.adaptive_avg_pool2d(x, (1, 1))
            x_ = self.fc(x_.view(x_.shape[0], -1))
        else:
            x_ = self.fc(x.view(x.shape[0], -1))
        gamma, beta = self.gamma(x_), self.beta(x_)
        # Condition every AdaILN layer in the decoder's residual blocks.
        for blk in self.decoder.residual_blocks:
            blk.conv1.normalization.set_condition(gamma, beta)
            blk.conv2.normalization.set_condition(gamma, beta)
        return self.decoder(x), cam_logit, heatmap


@MODEL.register_module("UGATIT-Discriminator")
class Discriminator(nn.Module):
    """PatchGAN-style discriminator with a CAM attention head.

    Returns a patch-level realness map and the CAM logits; the attention
    heatmap is returned as well when ``return_heatmap=True``.
    """

    def __init__(self, in_channels, base_channels=64, num_blocks=5, activation_type="LeakyReLU", norm_type="NONE",
                 padding_mode='reflect'):
        super().__init__()
        sequence = [Conv2dBlock(in_channels, base_channels, kernel_size=4, stride=2, padding=1,
                                padding_mode=padding_mode, activation_type=activation_type, norm_type=norm_type)]
        # Strided blocks that double the channel count while halving resolution.
        sequence += [Conv2dBlock(base_channels * (2 ** i), base_channels * (2 ** i) * 2, kernel_size=4, stride=2,
                                 padding=1, padding_mode=padding_mode, activation_type=activation_type,
                                 norm_type=norm_type)
                     for i in range(num_blocks - 3)]
        # Final stride-1 block before the CAM head.
        sequence.append(
            Conv2dBlock(base_channels * (2 ** (num_blocks - 3)), base_channels * (2 ** (num_blocks - 2)),
                        kernel_size=4, stride=1, padding=1, padding_mode=padding_mode,
                        activation_type=activation_type, norm_type=norm_type)
        )
        self.sequence = nn.Sequential(*sequence)
        mult = 2 ** (num_blocks - 2)
        self.cam = CAMClassifier(base_channels * mult, activation_type)
        self.conv = nn.Conv2d(base_channels * mult, 1, kernel_size=4, stride=1, padding=1, bias=False,
                              padding_mode="reflect")

    def forward(self, x, return_heatmap=False):
        x = self.sequence(x)
        x, cam_logit = self.cam(x)
        if return_heatmap:
            heatmap = torch.sum(x, dim=1, keepdim=True)
            return self.conv(x), cam_logit, heatmap
        else:
            return self.conv(x), cam_logit
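

# ---------------------------------------------------------------------------
# Minimal shape-check sketch (not part of the models above). It assumes that
# the Encoder/Decoder from model.image_translation.CycleGAN and the blocks
# from model.base.module behave as used in this file, in particular that the
# generator encoder downsamples a 256x256 input by a factor of 4 and that the
# decoder exposes `residual_blocks` with AdaILN layers supporting
# `set_condition`. The expected shapes noted below are therefore assumptions,
# not guarantees; run this only as a quick sanity check.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    gen = Generator(in_channels=3, out_channels=3, base_channels=64, num_blocks=6, img_size=256, light=True)
    dis = Discriminator(in_channels=3, base_channels=64, num_blocks=5)

    img = torch.randn(2, 3, 256, 256)
    with torch.no_grad():
        fake, gen_cam_logit, gen_heatmap = gen(img)
        # Assumed shapes: fake (2, 3, 256, 256), cam_logit (2, 2), heatmap (2, 1, 64, 64).
        print(fake.shape, gen_cam_logit.shape, gen_heatmap.shape)

        score, dis_cam_logit = dis(fake)
        # Assumed shapes: a patch score map of roughly (2, 1, 30, 30) and cam_logit (2, 2).
        print(score.shape, dis_cam_logit.shape)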