139 lines
5.6 KiB
Python
139 lines
5.6 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
from model import MODEL
|
|
from model.base.module import Conv2dBlock, LinearBlock
|
|
from model.image_translation.CycleGAN import Encoder, Decoder
|
|
|
|
|
|
class RhoClipper(object):
    """Callable that clamps a module's ``rho`` values into a fixed range.

    Calling an instance on a module that has a ``rho`` attribute replaces
    ``module.rho.data`` with a copy clamped to ``[clip_min, clip_max]``.
    Modules without a ``rho`` attribute are left untouched.
    """

    def __init__(self, clip_min, clip_max):
        """Store the clamp bounds; ``clip_min`` must be strictly below ``clip_max``."""
        self.clip_min = clip_min
        self.clip_max = clip_max
        assert clip_min < clip_max

    def __call__(self, module):
        """Clamp ``module.rho.data`` if the module exposes ``rho``; returns None."""
        # Guard clause: nothing to do for modules that carry no ``rho``.
        if not hasattr(module, 'rho'):
            return
        module.rho.data = module.rho.data.clamp(self.clip_min, self.clip_max)
|
|
|
|
|
|
class CAMClassifier(nn.Module):
    """Class-activation-map head with an average-pool and a max-pool branch.

    Each branch pools the input to 1x1, feeds it through a bias-free linear
    layer to get one logit, and reuses that layer's weight to rescale the
    input channels. The two rescaled maps are concatenated on the channel
    axis and passed through a 1x1 ``Conv2dBlock`` that maps
    ``2 * in_channels`` back to ``in_channels``.
    """

    def __init__(self, in_channels, activation_type="ReLU"):
        """Build the two pooling/linear branches and the 1x1 fusion block."""
        super().__init__()
        # Creation order is kept as-is (avg branch, max branch, fusion).
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.avg_fc = nn.Linear(in_channels, 1, bias=False)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.max_fc = nn.Linear(in_channels, 1, bias=False)
        self.fusion_conv = Conv2dBlock(
            in_channels * 2, in_channels,
            kernel_size=1, stride=1, bias=True,
            activation_type=activation_type, norm_type="NONE",
        )

    def forward(self, x):
        """Return ``(fused_features, logits)``.

        ``logits`` concatenates the average-branch and max-branch logits
        along dim 1, in that order.
        """
        batch = x.size(0)
        branches = ((self.avg_pool, self.avg_fc), (self.max_pool, self.max_fc))

        logits = []
        reweighted = []
        for pool, fc in branches:
            logits.append(fc(pool(x).view(batch, -1)))
            # fc.weight is (1, in_channels); add two trailing axes so it
            # broadcasts against x as a per-channel scale.
            reweighted.append(x * fc.weight[:, :, None, None])

        fused = self.fusion_conv(torch.cat(reweighted, dim=1))
        return fused, torch.cat(logits, 1)
|
|
|
|
|
|
@MODEL.register_module("UGATIT-Generator")
class Generator(nn.Module):
    """Encoder -> CAM head -> conditioned decoder, registered as "UGATIT-Generator".

    The encoder output goes through ``CAMClassifier``; the resulting features
    are flattened (or globally average-pooled first when ``light`` is set) and
    fed to a two-layer ``fc`` stack. Bias-free linear layers then produce
    ``gamma`` and ``beta``, which are handed to the normalization layers of
    every decoder residual block before decoding.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False,
                 activation_type="ReLU", norm_type="IN", padding_mode='reflect', pre_activation=False):
        """Build encoder, CAM head, gamma/beta projection and decoder."""
        super().__init__()

        self.light = light

        n_down_sampling = 2
        mult = 2 ** n_down_sampling
        feature_channels = base_channels * mult

        self.encoder = Encoder(
            in_channels, base_channels, n_down_sampling, num_blocks,
            padding_mode=padding_mode, activation_type=activation_type,
            down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type,
            pre_activation=pre_activation,
        )
        self.cam = CAMClassifier(feature_channels, activation_type)

        # Gamma, Beta block: only the width of the first layer depends on `light`.
        if self.light:
            # Light mode feeds globally pooled features, one value per channel.
            fc_in_features = feature_channels
        else:
            # Full mode feeds the flattened feature map.
            # NOTE(review): this evaluates left to right, i.e.
            # ((img_size // mult) * img_size) // mult * base_channels * mult,
            # which equals (img_size // mult) ** 2 * base_channels * mult only
            # when img_size is divisible by mult.
            fc_in_features = img_size // mult * img_size // mult * base_channels * mult
        self.fc = nn.Sequential(
            LinearBlock(fc_in_features, feature_channels, False, "ReLU", "NONE"),
            LinearBlock(feature_channels, feature_channels, False, "ReLU", "NONE"),
        )

        self.gamma = nn.Linear(feature_channels, feature_channels, bias=False)
        self.beta = nn.Linear(feature_channels, feature_channels, bias=False)

        self.decoder = Decoder(
            feature_channels, out_channels, n_down_sampling, num_blocks,
            activation_type=activation_type, padding_mode=padding_mode,
            up_conv_kernel_size=3, up_conv_norm_type="ILN",
            res_norm_type="AdaILN", pre_activation=pre_activation,
        )

    def forward(self, x):
        """Return ``(decoded_output, cam_logit, heatmap)``.

        ``heatmap`` is the channel-wise sum of the CAM-fused features with the
        channel axis kept as size 1.
        """
        features, cam_logit = self.cam(self.encoder(x))
        heatmap = features.sum(dim=1, keepdim=True)

        batch = features.shape[0]
        if self.light:
            pooled = torch.nn.functional.adaptive_avg_pool2d(features, (1, 1))
            style = self.fc(pooled.view(batch, -1))
        else:
            style = self.fc(features.view(batch, -1))
        gamma = self.gamma(style)
        beta = self.beta(style)

        # Push the same gamma/beta into both convs of every residual block
        # before the decoder runs.
        for blk in self.decoder.residual_blocks:
            for conv in (blk.conv1, blk.conv2):
                conv.normalization.set_condition(gamma, beta)

        return self.decoder(features), cam_logit, heatmap
|
|
|
|
|
|
@MODEL.register_module("UGATIT-Discriminator")
class Discriminator(nn.Module):
    """Convolutional discriminator with a CAM head, registered as "UGATIT-Discriminator".

    A stack of 4x4 ``Conv2dBlock`` layers (stride 2, with a final stride-1
    block) feeds a ``CAMClassifier``; a bias-free 4x4 convolution then maps
    the fused features to a single output channel.
    """

    def __init__(self, in_channels, base_channels=64, num_blocks=5,
                 activation_type="LeakyReLU", norm_type="NONE", padding_mode='reflect'):
        """Build the conv stack, the CAM head and the final 1-channel conv."""
        super().__init__()

        # Keyword arguments shared by every Conv2dBlock in the stack.
        shared = dict(
            kernel_size=4, padding=1, padding_mode=padding_mode,
            activation_type=activation_type, norm_type=norm_type,
        )

        layers = [Conv2dBlock(in_channels, base_channels, stride=2, **shared)]

        # Stride-2 blocks, each doubling the channel count.
        for i in range(num_blocks - 3):
            width = base_channels * (2 ** i)
            layers.append(Conv2dBlock(width, width * 2, stride=2, **shared))

        # Last block doubles the channels once more but keeps stride 1.
        layers.append(Conv2dBlock(
            base_channels * (2 ** (num_blocks - 3)), base_channels * (2 ** (num_blocks - 2)),
            stride=1, **shared,
        ))
        self.sequence = nn.Sequential(*layers)

        mult = 2 ** (num_blocks - 2)
        self.cam = CAMClassifier(base_channels * mult, activation_type)
        # NOTE(review): this conv hard-codes "reflect" instead of using the
        # `padding_mode` argument - confirm whether that is intentional.
        self.conv = nn.Conv2d(base_channels * mult, 1, kernel_size=4, stride=1, padding=1, bias=False,
                              padding_mode="reflect")

    def forward(self, x, return_heatmap=False):
        """Return ``(output, cam_logit)``, plus ``heatmap`` if ``return_heatmap`` is set.

        ``heatmap`` is the channel-wise sum of the CAM-fused features with the
        channel axis kept as size 1.
        """
        features, cam_logit = self.cam(self.sequence(x))

        if not return_heatmap:
            return self.conv(features), cam_logit

        heatmap = torch.sum(features, dim=1, keepdim=True)
        return self.conv(features), cam_logit, heatmap
|