diff --git a/model/GAN/TSIT.py b/model/GAN/TSIT.py index de9d467..3422a51 100644 --- a/model/GAN/TSIT.py +++ b/model/GAN/TSIT.py @@ -3,44 +3,7 @@ import torch.nn as nn import torch.nn.functional as F from model import MODEL -from model.normalization import select_norm_layer - - -class ResBlock(nn.Module): - def __init__(self, in_channels, out_channels, padding_mode='zeros', norm_type="IN", use_bias=None, - use_spectral=True): - super().__init__() - self.padding_mode = padding_mode - self.use_bias = use_bias - self.use_spectral = use_spectral - if use_bias is None: - # Only for IN, use bias since it does not have affine parameters. - self.use_bias = norm_type == "IN" - norm_layer = select_norm_layer(norm_type) - self.main = nn.Sequential( - self.conv_block(in_channels, in_channels), - norm_layer(in_channels), - nn.LeakyReLU(0.2, inplace=True), - self.conv_block(in_channels, out_channels), - norm_layer(out_channels), - nn.LeakyReLU(0.2, inplace=True), - ) - self.skip = nn.Sequential( - self.conv_block(in_channels, out_channels, padding=0, kernel_size=1), - norm_layer(out_channels), - nn.LeakyReLU(0.2, inplace=True), - ) - - def conv_block(self, in_channels, out_channels, kernel_size=3, padding=1): - conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, - padding_mode=self.padding_mode, bias=self.use_bias) - if self.use_spectral: - return nn.utils.spectral_norm(conv) - else: - return conv - - def forward(self, x): - return self.main(x) + self.skip(x) +from model.base.module import Conv2dBlock, ResidualBlock, ReverseResidualBlock class Interpolation(nn.Module): @@ -58,104 +21,41 @@ class Interpolation(nn.Module): return f"DownSampling(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})" -class FADE(nn.Module): - def __init__(self, use_spectral, features_channels, in_channels, affine=False, track_running_stats=True): - super().__init__() - # self.norm = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats) - self.norm = nn.InstanceNorm2d(num_features=in_channels) - - self.alpha_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1, - padding_mode="zeros") - self.beta_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1, - padding_mode="zeros") - - def forward(self, x, feature): - alpha = self.alpha_conv(feature) - beta = self.beta_conv(feature) - x = self.norm(x) - return alpha * x + beta - - -class FADEResBlock(nn.Module): - def __init__(self, use_spectral, features_channels, in_channels, out_channels): - super().__init__() - self.main = nn.Sequential( - FADE(use_spectral, features_channels, in_channels), - nn.LeakyReLU(0.2, inplace=True), - conv_block(use_spectral, in_channels, in_channels, kernel_size=3, padding=1), - FADE(use_spectral, features_channels, in_channels), - nn.LeakyReLU(0.2, inplace=True), - conv_block(use_spectral, in_channels, out_channels, kernel_size=3, padding=1), - ) - self.skip = nn.Sequential( - FADE(use_spectral, features_channels, in_channels), - nn.LeakyReLU(0.2, inplace=True), - conv_block(use_spectral, in_channels, out_channels, kernel_size=1, padding=0), - ) - self.up_sample = Interpolation(2, mode="nearest") - - @staticmethod - def forward_with_fade(module, x, feature): - for layer in module: - if layer.__class__.__name__ == "FADE": - x = layer(x, feature) - else: - x = layer(x) - return x - - def forward(self, x, feature): - out = self.forward_with_fade(self.main, x, feature) + self.forward_with_fade(self.main, x, feature) - return self.up_sample(out) - - -def conv_block(use_spectral, in_channels, out_channels, **kwargs): - conv = nn.Conv2d(in_channels, out_channels, **kwargs) - return nn.utils.spectral_norm(conv) if use_spectral else conv - - @MODEL.register_module("TSIT-Generator") -class TSITGenerator(nn.Module): - def __init__(self, num_blocks=7, base_channels=64, content_in_channels=3, style_in_channels=3, - out_channels=3, use_spectral=True, input_layer_type="conv1x1"): +class Generator(nn.Module): + def __init__(self, content_in_channels=3, out_channels=3, base_channels=64, num_blocks=7, + padding_mode="reflect", activation_type="ReLU"): super().__init__() self.num_blocks = num_blocks self.base_channels = base_channels - self.use_spectral = use_spectral - self.content_input_layer = self.build_input_layer(content_in_channels, base_channels, input_layer_type) - self.content_stream = self.build_stream() - self.generator = self.build_generator() - self.end_conv = nn.Sequential( - conv_block(use_spectral, base_channels, out_channels, kernel_size=7, padding=3, padding_mode="zeros"), - nn.Tanh() - ) + self.content_stream = self.build_stream(padding_mode, activation_type) + self.start_conv = Conv2dBlock(content_in_channels, base_channels, activation_type=activation_type, + norm_type="IN", kernel_size=7, padding_mode=padding_mode, padding=3) - def build_generator(self): - stream_sequence = [] + sequence = [] multiple_now = min(2 ** self.num_blocks, 2 ** 4) for i in range(1, self.num_blocks + 1): m = self.num_blocks - i multiple_prev = multiple_now multiple_now = min(2 ** m, 2 ** 4) - stream_sequence.append( - FADEResBlock(self.use_spectral, multiple_prev * self.base_channels, multiple_prev * self.base_channels, - multiple_now * self.base_channels)) - return nn.ModuleList(stream_sequence) + sequence.append(nn.Sequential( + ReverseResidualBlock( + multiple_prev * base_channels, multiple_now * base_channels, + padding_mode=padding_mode, norm_type="FADE", + additional_norm_kwargs=dict( + condition_in_channels=multiple_prev * base_channels, + base_norm_type="BN", + padding_mode=padding_mode + ) + ), + Interpolation(2, mode="nearest") + )) + self.generator = nn.Sequential(*sequence) + self.end_conv = Conv2dBlock(base_channels, out_channels, activation_type="Tanh", + kernel_size=7, padding_mode=padding_mode, padding=3) - def build_input_layer(self, in_channels, out_channels, input_layer_type="conv7x7"): - if input_layer_type == "conv7x7": - return nn.Sequential( - conv_block(self.use_spectral, in_channels, out_channels, kernel_size=7, stride=1, - padding_mode="zeros", padding=3, bias=True), - select_norm_layer("IN")(out_channels), - nn.ReLU(inplace=True) - ) - elif input_layer_type == "conv1x1": - return conv_block(self.use_spectral, in_channels, out_channels, kernel_size=1, stride=1, padding=0) - else: - raise NotImplemented - - def build_stream(self): + def build_stream(self, padding_mode, activation_type): multiple_now = 1 stream_sequence = [] for i in range(1, self.num_blocks + 1): @@ -163,21 +63,26 @@ class TSITGenerator(nn.Module): multiple_now = min(2 ** i, 2 ** 4) stream_sequence.append(nn.Sequential( Interpolation(scale_factor=0.5, mode="nearest"), - ResBlock(multiple_prev * self.base_channels, multiple_now * self.base_channels, - use_spectral=self.use_spectral) + ResidualBlock( + multiple_prev * self.base_channels, multiple_now * self.base_channels, + padding_mode=padding_mode, activation_type=activation_type, norm_type="IN") )) return nn.ModuleList(stream_sequence) - def forward(self, content_img): - c = self.content_input_layer(content_img) + def forward(self, content, z=None): + c = self.start_conv(content) content_features = [] for i in range(self.num_blocks): c = self.content_stream[i](c) content_features.append(c) - z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device) + if z is None: + z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device) for i in range(self.num_blocks): m = - i - 1 - layer = self.generator[i] - z = layer(z, content_features[m]) - return self.end_conv(z) + res_block = self.generator[i][0] + res_block.conv1.normalization.set_feature(content_features[m]) + res_block.conv2.normalization.set_feature(content_features[m]) + if res_block.learn_skip_connection: + res_block.res_conv.normalization.set_feature(content_features[m]) + return self.end_conv(self.generator(z))