import torch
import torch.nn as nn
import torch.nn.functional as F

from model import MODEL
from model.normalization import AdaptiveInstanceNorm2d
from model.normalization import select_norm_layer


def conv_block(use_spectral, in_channels, out_channels, **kwargs):
    """Plain Conv2d, optionally wrapped in spectral normalization."""
    conv = nn.Conv2d(in_channels, out_channels, **kwargs)
    return nn.utils.spectral_norm(conv) if use_spectral else conv


class ResBlock(nn.Module):
    """Residual block with a 1x1 skip projection, used in the content and style streams."""

    def __init__(self, in_channels, out_channels, padding_mode='zeros', norm_type="IN",
                 use_bias=None, use_spectral=True):
        super().__init__()
        self.padding_mode = padding_mode
        self.use_bias = use_bias
        self.use_spectral = use_spectral
        if use_bias is None:
            # Use a bias only with instance norm, which has no affine parameters of its own.
            self.use_bias = norm_type == "IN"
        norm_layer = select_norm_layer(norm_type)
        self.main = nn.Sequential(
            self.conv_block(in_channels, in_channels),
            norm_layer(in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            self.conv_block(in_channels, out_channels),
            norm_layer(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.skip = nn.Sequential(
            self.conv_block(in_channels, out_channels, kernel_size=1, padding=0),
            norm_layer(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def conv_block(self, in_channels, out_channels, kernel_size=3, padding=1):
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding,
                         padding_mode=self.padding_mode, bias=self.use_bias)
        return nn.utils.spectral_norm(conv) if self.use_spectral else conv

    def forward(self, x):
        return self.main(x) + self.skip(x)


class Interpolation(nn.Module):
    """nn.Module wrapper around F.interpolate so resampling can sit inside nn.Sequential."""

    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode,
                             align_corners=self.align_corners, recompute_scale_factor=False)

    def __repr__(self):
        return (f"Interpolation(scale_factor={self.scale_factor}, mode={self.mode}, "
                f"align_corners={self.align_corners})")


class FADE(nn.Module):
    """Feature-adaptive denormalization: normalizes x with a parameter-free BatchNorm, then
    modulates it with per-pixel scale (alpha) and shift (beta) predicted from a feature map."""

    def __init__(self, use_spectral, features_channels, in_channels, affine=False,
                 track_running_stats=True):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features=in_channels, affine=affine,
                                 track_running_stats=track_running_stats)
        self.alpha_conv = conv_block(use_spectral, features_channels, in_channels,
                                     kernel_size=3, padding=1, padding_mode="zeros")
        self.beta_conv = conv_block(use_spectral, features_channels, in_channels,
                                    kernel_size=3, padding=1, padding_mode="zeros")

    def forward(self, x, feature):
        alpha = self.alpha_conv(feature)
        beta = self.beta_conv(feature)
        x = self.bn(x)
        return alpha * x + beta


class FADEResBlock(nn.Module):
    """Residual block whose normalization layers are FADE layers conditioned on a content
    feature map; the output is upsampled by a factor of 2."""

    def __init__(self, use_spectral, features_channels, in_channels, out_channels):
        super().__init__()
        self.main = nn.Sequential(
            FADE(use_spectral, features_channels, in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            conv_block(use_spectral, in_channels, in_channels, kernel_size=3, padding=1),
            FADE(use_spectral, features_channels, in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            conv_block(use_spectral, in_channels, out_channels, kernel_size=3, padding=1),
        )
        self.skip = nn.Sequential(
            FADE(use_spectral, features_channels, in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            conv_block(use_spectral, in_channels, out_channels, kernel_size=1, padding=0),
        )
        self.up_sample = Interpolation(2, mode="nearest")

    @staticmethod
    def forward_with_fade(module, x, feature):
        # FADE layers need the conditioning feature map; every other layer takes x alone.
        for layer in module:
            if isinstance(layer, FADE):
                x = layer(x, feature)
            else:
                x = layer(x)
        return x

    def forward(self, x, feature):
        # Residual sum of the main branch and the 1x1 skip branch, then 2x upsampling.
        out = self.forward_with_fade(self.main, x, feature) + self.forward_with_fade(self.skip, x, feature)
        return self.up_sample(out)


@MODEL.register_module("TSIT-Generator")
class TSITGenerator(nn.Module):
    """TSIT-style generator: two downsampling streams encode the content and style images
    into feature pyramids, and a decoder grows an image from noise, injecting style
    statistics via AdaIN and content features via FADE residual blocks."""

    def __init__(self, num_blocks=7, base_channels=64, content_in_channels=3,
                 style_in_channels=3, out_channels=3, use_spectral=True,
                 input_layer_type="conv1x1"):
        super().__init__()
        self.num_blocks = num_blocks
        self.base_channels = base_channels
        self.use_spectral = use_spectral
        self.content_input_layer = self.build_input_layer(content_in_channels, base_channels, input_layer_type)
        self.style_input_layer = self.build_input_layer(style_in_channels, base_channels, input_layer_type)
        self.content_stream = self.build_stream()
        self.style_stream = self.build_stream()
        self.generator = self.build_generator()
        self.end_conv = nn.Sequential(
            conv_block(use_spectral, base_channels, out_channels, kernel_size=7, padding=3, padding_mode="zeros"),
            nn.Tanh()
        )

    def build_generator(self):
        stream_sequence = []
        # Channel multipliers mirror build_stream in reverse and are capped at 2 ** 4.
        multiple_now = min(2 ** self.num_blocks, 2 ** 4)
        for i in range(1, self.num_blocks + 1):
            m = self.num_blocks - i
            multiple_prev = multiple_now
            multiple_now = min(2 ** m, 2 ** 4)
            stream_sequence.append(nn.Sequential(
                AdaptiveInstanceNorm2d(multiple_prev * self.base_channels),
                FADEResBlock(self.use_spectral, multiple_prev * self.base_channels,
                             multiple_prev * self.base_channels, multiple_now * self.base_channels)
            ))
        return nn.ModuleList(stream_sequence)

    def build_input_layer(self, in_channels, out_channels, input_layer_type="conv7x7"):
        if input_layer_type == "conv7x7":
            return nn.Sequential(
                conv_block(self.use_spectral, in_channels, out_channels, kernel_size=7, stride=1,
                           padding_mode="zeros", padding=3, bias=True),
                select_norm_layer("IN")(out_channels),
                nn.ReLU(inplace=True)
            )
        elif input_layer_type == "conv1x1":
            return conv_block(self.use_spectral, in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            raise NotImplementedError(f"Unsupported input_layer_type: {input_layer_type}")

    def build_stream(self):
        multiple_now = 1
        stream_sequence = []
        for i in range(1, self.num_blocks + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 4)
            stream_sequence.append(nn.Sequential(
                Interpolation(scale_factor=0.5, mode="nearest"),
                ResBlock(multiple_prev * self.base_channels, multiple_now * self.base_channels,
                         use_spectral=self.use_spectral)
            ))
        return nn.ModuleList(stream_sequence)

    def forward(self, content_img, style_img):
        c = self.content_input_layer(content_img)
        s = self.style_input_layer(style_img)
        # Encode both images into multi-scale feature pyramids (finest to coarsest).
        content_features = []
        style_features = []
        for i in range(self.num_blocks):
            s = self.style_stream[i](s)
            c = self.content_stream[i](c)
            content_features.append(c)
            style_features.append(s)
        # Decode from noise at the coarsest resolution, consuming the pyramids in reverse.
        z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
        for i in range(self.num_blocks):
            m = -i - 1
            layer = self.generator[i]
            # AdaIN style code: per-channel std and mean of the style feature at this scale.
            layer[0].set_style(torch.cat(torch.std_mean(style_features[m], dim=[2, 3]), dim=1))
            z = layer[0](z)
            z = layer[1](z, content_features[m])
        return self.end_conv(z)
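

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). It assumes the
# local `model` package imported above is importable and uses hypothetical
# 3-channel 256x256 inputs; the spatial size must be divisible by
# 2 ** num_blocks (128 here) so the downsampling streams and the upsampling
# decoder stay aligned.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    generator = TSITGenerator(num_blocks=7, base_channels=64)
    content = torch.randn(1, 3, 256, 256)
    style = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        fake = generator(content, style)
    print(fake.shape)  # expected: torch.Size([1, 3, 256, 256])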