diff --git a/model/GAN/TAHG.py b/model/GAN/TAHG.py
new file mode 100644
index 0000000..963230c
--- /dev/null
+++ b/model/GAN/TAHG.py
@@ -0,0 +1,188 @@
+import torch
+import torch.nn as nn
+from .residual_generator import ResidualBlock
+from model.registry import MODEL
+from torchvision.models import vgg19
+from model.normalization import select_norm_layer
+
+
+class VGG19StyleEncoder(nn.Module):
+    """Style encoder that fuses trainable conv features with frozen VGG19 features."""
+
+    def __init__(self, in_channels, base_channels=64, style_dim=512, padding_mode='reflect', norm_type="NONE",
+                 vgg19_layers=(0, 5, 10, 19)):
+        super().__init__()
+        self.vgg19_layers = vgg19_layers
+        # Frozen VGG19 feature extractor, truncated after the last requested layer.
+        self.vgg19 = vgg19(pretrained=True).features[:vgg19_layers[-1] + 1]
+        self.vgg19.requires_grad_(False)
+
+        norm_layer = select_norm_layer(norm_type)
+
+        self.conv0 = nn.Sequential(
+            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
+                      bias=True),
+            norm_layer(base_channels),
+            nn.ReLU(True),
+        )
+        # At stage i the input is the previous stage's output concatenated with the
+        # VGG19 feature map of the same resolution, giving base_channels * 2**i channels.
+        self.conv = nn.ModuleList([
+            nn.Sequential(
+                nn.Conv2d(base_channels * (2 ** i), base_channels * (2 ** i), kernel_size=4, stride=2, padding=1,
+                          padding_mode=padding_mode, bias=True),
+                norm_layer(base_channels * (2 ** i)),
+                nn.ReLU(True),
+            ) for i in range(1, 4)
+        ])
+        self.pool = nn.AdaptiveAvgPool2d(1)
+        self.conv1x1 = nn.Conv2d(base_channels * (2 ** 4), style_dim, kernel_size=1, stride=1, padding=0)
+
+    def fixed_style_features(self, x):
+        # Collect the frozen VGG19 activations at the requested layer indices.
+        features = []
+        for i in range(len(self.vgg19)):
+            x = self.vgg19[i](x)
+            if i in self.vgg19_layers:
+                features.append(x)
+        return features
+
+    def forward(self, x):
+        fsf = self.fixed_style_features(x)
+        x = self.conv0(x)
+        for i, layer in enumerate(self.conv):
+            x = layer(torch.cat([x, fsf[i]], dim=1))
+        x = self.pool(torch.cat([x, fsf[-1]], dim=1))
+        x = self.conv1x1(x)
+        return x.view(x.size(0), -1)
+
+
+class ContentEncoder(nn.Module):
+    def __init__(self, in_channels, base_channels=64, num_blocks=8, padding_mode='reflect', norm_type="IN"):
+        super().__init__()
+        norm_layer = select_norm_layer(norm_type)
+
+        self.start_conv = nn.Sequential(
+            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
+                      bias=True),
+            norm_layer(num_features=base_channels),
+            nn.ReLU(inplace=True)
+        )
+
+        # down sampling
+        submodules = []
+        num_down_sampling = 2
+        for i in range(num_down_sampling):
+            multiple = 2 ** i
+            submodules += [
+                nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
+                          kernel_size=4, stride=2, padding=1, bias=True),
+                norm_layer(num_features=base_channels * multiple * 2),
+                nn.ReLU(inplace=True)
+            ]
+        self.encoder = nn.Sequential(*submodules)
+        # Channels double at every downsampling step.
+        res_block_channels = base_channels * 2 ** num_down_sampling
+        self.resnet = nn.Sequential(
+            *[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
+
+    def forward(self, x):
+        x = self.start_conv(x)
+        x = self.encoder(x)
+        x = self.resnet(x)
+        return x
+
+
+class Decoder(nn.Module):
+    def __init__(self, out_channels, base_channels=64, num_down_sampling=2, padding_mode='reflect', norm_type="LN"):
+        super().__init__()
+        norm_layer = select_norm_layer(norm_type)
+        use_bias = norm_type == "IN"
+        # up sampling
+        submodules = []
+        for i in range(num_down_sampling):
+            multiple = 2 ** (num_down_sampling - i)
+            submodules += [
+                nn.Upsample(scale_factor=2),
+                nn.Conv2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=5, stride=1,
+                          padding=2, padding_mode=padding_mode, bias=use_bias),
+                norm_layer(num_features=base_channels * multiple // 2),
+                nn.ReLU(inplace=True),
+            ]
+        self.decoder = nn.Sequential(*submodules)
+        self.end_conv = nn.Sequential(
+            nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
+            nn.Tanh()
+        )
+
+    def forward(self, x):
+        x = self.decoder(x)
+        x = self.end_conv(x)
+        return x
+
+
+class Fusion(nn.Module):
+    """MLP that maps a style code to the concatenated AdaIN parameters."""
+
+    def __init__(self, in_features, out_features, base_features, n_blocks, norm_type="NONE"):
+        super().__init__()
+        norm_layer = select_norm_layer(norm_type)
+        self.start_fc = nn.Sequential(
+            nn.Linear(in_features, base_features),
+            norm_layer(base_features),
+            nn.ReLU(True),
+        )
+        self.fcs = nn.Sequential(*[
+            nn.Sequential(
+                nn.Linear(base_features, base_features),
+                norm_layer(base_features),
+                nn.ReLU(True),
+            ) for _ in range(n_blocks - 2)
+        ])
+        self.end_fc = nn.Sequential(
+            nn.Linear(base_features, out_features),
+        )
+
+    def forward(self, x):
+        x = self.start_fc(x)
+        x = self.fcs(x)
+        return self.end_fc(x)
+
+
+@MODEL.register_module("TAHG-Generator")
+class Generator(nn.Module):
+    def __init__(self, style_in_channels, content_in_channels, out_channels, style_dim=512, num_blocks=8,
+                 base_channels=64, padding_mode="reflect"):
+        super().__init__()
+        self.num_blocks = num_blocks
+        self.style_encoder = VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
+                                               padding_mode=padding_mode, norm_type="NONE")
+        self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
+                                              padding_mode=padding_mode, norm_type="IN")
+        res_block_channels = 2 ** 2 * base_channels
+        self.adain_res = nn.ModuleList([
+            ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
+        ])
+        self.decoders = nn.ModuleDict({
+            "a": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode),
+            "b": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode)
+        })
+
+        self.fc = nn.Sequential(
+            nn.Linear(style_dim, style_dim),
+            nn.ReLU(True),
+        )
+        # Each AdaIN residual block has two norm layers; each norm layer consumes
+        # res_block_channels * 2 style parameters (per-channel scale and shift).
+        self.fusion = Fusion(style_dim, num_blocks * 2 * res_block_channels * 2, base_features=256, n_blocks=3,
+                             norm_type="NONE")
+
+    def forward(self, content_img, style_img, which_decoder: str = "a"):
+        x = self.content_encoder(content_img)
+        styles = self.fusion(self.fc(self.style_encoder(style_img)))
+        styles = torch.chunk(styles, self.num_blocks * 2, dim=1)
+        for i, ar in enumerate(self.adain_res):
+            ar.norm1.set_style(styles[2 * i])
+            ar.norm2.set_style(styles[2 * i + 1])
+            x = ar(x)
+        return self.decoders[which_decoder](x)
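
A minimal, hypothetical smoke test for the new generator is sketched below. It assumes the package layout implied by the diff (`model.GAN.TAHG`), that `select_norm_layer("NONE")` returns an identity-style layer, and that `ResidualBlock` configured with "AdaIN" exposes `norm1`/`norm2` modules with a `set_style` method, as the forward pass requires. The 3-channel 256x256 inputs are illustrative, not taken from the diff; note that 3 input channels are mandatory for the style branch, since the raw style image is fed to the frozen VGG19.

    import torch
    from model.GAN.TAHG import Generator  # module path assumed from the diff

    # Constructing the Generator downloads VGG19 ImageNet weights on first use.
    gen = Generator(style_in_channels=3, content_in_channels=3, out_channels=3).eval()
    content = torch.randn(1, 3, 256, 256)  # content image batch
    style = torch.randn(1, 3, 256, 256)    # style image batch

    with torch.no_grad():
        out_a = gen(content, style, which_decoder="a")
        out_b = gen(content, style, which_decoder="b")

    # The decoder's two upsampling steps undo the content encoder's two
    # downsampling steps, so each output should be (1, 3, 256, 256).
    print(out_a.shape, out_b.shape)

The `which_decoder` switch selects between the two output heads registered in the `nn.ModuleDict`, so both decoders share one content encoder, style encoder, and AdaIN backbone.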