diff --git a/model/base/module.py b/model/base/module.py
index 7a7a765..c5502f5 100644
--- a/model/base/module.py
+++ b/model/base/module.py
@@ -70,7 +70,7 @@ class Conv2dBlock(nn.Module):
 
 class ResidualBlock(nn.Module):
     def __init__(self, num_channels, out_channels=None, padding_mode='reflect',
-                 activation_type="ReLU", out_activation_type=None, norm_type="IN"):
+                 activation_type="ReLU", norm_type="IN", out_activation_type=None):
         super().__init__()
 
         self.norm_type = norm_type
diff --git a/model/image_translation/CycleGAN.py b/model/image_translation/CycleGAN.py
index e69de29..10cd4b6 100644
--- a/model/image_translation/CycleGAN.py
+++ b/model/image_translation/CycleGAN.py
@@ -0,0 +1,68 @@
+import torch.nn as nn
+
+from model.base.module import Conv2dBlock, ResidualBlock
+
+
+class Encoder(nn.Module):
+    def __init__(self, in_channels, base_channels, num_conv, num_res, max_down_sampling_multiple=2,
+                 padding_mode='reflect', activation_type="ReLU",
+                 down_conv_norm_type="IN", down_conv_kernel_size=3,
+                 res_norm_type="IN"):
+        super().__init__()
+
+        sequence = [Conv2dBlock(
+            in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
+            activation_type=activation_type, norm_type=down_conv_norm_type
+        )]
+        multiple_now = 1
+        for i in range(1, num_conv + 1):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 2 ** max_down_sampling_multiple)
+            sequence.append(Conv2dBlock(
+                multiple_prev * base_channels, multiple_now * base_channels,
+                kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode=padding_mode,
+                activation_type=activation_type, norm_type=down_conv_norm_type
+            ))
+        self.out_channels = multiple_now * base_channels
+        sequence += [
+            ResidualBlock(self.out_channels, self.out_channels, padding_mode, activation_type, norm_type=res_norm_type)
+            for _ in range(num_res)
+        ]
+        self.sequence = nn.Sequential(*sequence)
+
+    def forward(self, x):
+        return self.sequence(x)
+
+
+class Decoder(nn.Module):
+    def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
+                 activation_type="ReLU", padding_mode='reflect',
+                 up_conv_kernel_size=5, up_conv_norm_type="LN",
+                 res_norm_type="AdaIN"):
+        super().__init__()
+        self.residual_blocks = nn.ModuleList([
+            ResidualBlock(in_channels, in_channels, padding_mode, activation_type, norm_type=res_norm_type)
+            for _ in range(num_residual_blocks)
+        ])
+
+        sequence = list()
+        channels = in_channels
+        for _ in range(num_up_sampling):
+            sequence.append(nn.Sequential(
+                nn.Upsample(scale_factor=2),
+                Conv2dBlock(channels, channels // 2,
+                            kernel_size=up_conv_kernel_size, stride=1,
+                            padding=up_conv_kernel_size // 2, padding_mode=padding_mode,
+                            activation_type=activation_type, norm_type=up_conv_norm_type),
+            ))
+            channels = channels // 2
+        sequence.append(Conv2dBlock(channels, out_channels,
+                                    kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
+                                    activation_type="Tanh", norm_type="NONE"))
+
+        self.up_sequence = nn.Sequential(*sequence)
+
+    def forward(self, x):
+        for blk in self.residual_blocks:
+            x = blk(x)
+        return self.up_sequence(x)
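Reviewer note, not part of the patch: the shared Encoder above doubles the channel count at each stride-2 conv and caps growth at 2 ** max_down_sampling_multiple. A minimal standalone sketch of that schedule, with plain nn.Conv2d standing in for the repo's Conv2dBlock (which this diff does not show) and hyperparameters assumed for illustration only:

import torch
import torch.nn as nn

# Assumed values for illustration: 64 base channels, 3 down-sampling convs, cap exponent 2.
base_channels, num_conv, max_mult = 64, 3, 2

layers = [nn.Conv2d(3, base_channels, kernel_size=7, stride=1, padding=3)]
multiple_now = 1
for i in range(1, num_conv + 1):
    multiple_prev = multiple_now
    # channel growth is capped at 2 ** max_mult, mirroring Encoder.__init__ above
    multiple_now = min(2 ** i, 2 ** max_mult)
    layers.append(nn.Conv2d(multiple_prev * base_channels, multiple_now * base_channels,
                            kernel_size=3, stride=2, padding=1))
encoder = nn.Sequential(*layers)

y = encoder(torch.randn(1, 3, 256, 256))
print(y.shape)  # torch.Size([1, 256, 32, 32]): channels go 64 -> 128 -> 256 -> 256

With num_conv=3 the third down-sampling conv keeps 256 channels because the multiple has already hit the 2 ** max_mult cap; Encoder records this final width in self.out_channels so callers such as StyleEncoder below can size their heads without recomputing it.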
diff --git a/model/image_translation/MUNIT.py b/model/image_translation/MUNIT.py
index 8dc3af0..d655058 100644
--- a/model/image_translation/MUNIT.py
+++ b/model/image_translation/MUNIT.py
@@ -2,99 +2,29 @@ import torch
 import torch.nn as nn
 
 from model import MODEL
-from model.base.module import Conv2dBlock, ResidualBlock, LinearBlock
-
-
-def _get_down_sampling_sequence(in_channels, base_channels, num_conv, max_down_sampling_multiple=2,
-                                padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
-    sequence = [Conv2dBlock(
-        in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
-        activation_type=activation_type, norm_type=norm_type
-    )]
-    multiple_now = 1
-    for i in range(1, num_conv + 1):
-        multiple_prev = multiple_now
-        multiple_now = min(2 ** i, 2 ** max_down_sampling_multiple)
-        sequence.append(Conv2dBlock(
-            multiple_prev * base_channels, multiple_now * base_channels,
-            kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
-            activation_type=activation_type, norm_type=norm_type
-        ))
-    return sequence, multiple_now * base_channels
+from model.base.module import LinearBlock
+from model.image_translation.CycleGAN import Encoder, Decoder
 
 
 class StyleEncoder(nn.Module):
     def __init__(self, in_channels, out_dim, num_conv, base_channels=64, max_down_sampling_multiple=2,
                  padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
         super().__init__()
-
-        sequence, last_channels = _get_down_sampling_sequence(
-            in_channels, base_channels, num_conv,
-            max_down_sampling_multiple, padding_mode, activation_type, norm_type
+        self.down_encoder = Encoder(
+            in_channels, base_channels, num_conv, num_res=0, max_down_sampling_multiple=max_down_sampling_multiple,
+            padding_mode=padding_mode, activation_type=activation_type,
+            down_conv_norm_type=norm_type, down_conv_kernel_size=4,
         )
+        sequence = list()
         sequence.append(nn.AdaptiveAvgPool2d(1))
         # conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code
-        sequence.append(nn.Conv2d(last_channels, out_dim, kernel_size=1, stride=1, padding=0))
+        sequence.append(nn.Conv2d(self.down_encoder.out_channels, out_dim, kernel_size=1, stride=1, padding=0))
         self.sequence = nn.Sequential(*sequence)
 
     def forward(self, image):
         return self.sequence(image).view(image.size(0), -1)
 
 
-class ContentEncoder(nn.Module):
-    def __init__(self, in_channels, num_down_sampling, num_residual_blocks, base_channels=64,
-                 max_down_sampling_multiple=2,
-                 padding_mode='reflect', activation_type="ReLU", norm_type="IN"):
-        super().__init__()
-
-        sequence, last_channels = _get_down_sampling_sequence(
-            in_channels, base_channels, num_down_sampling,
-            max_down_sampling_multiple, padding_mode, activation_type, norm_type
-        )
-
-        sequence += [ResidualBlock(last_channels, last_channels, padding_mode, activation_type, norm_type) for _ in
-                     range(num_residual_blocks)]
-        self.sequence = nn.Sequential(*sequence)
-
-    def forward(self, image):
-        return self.sequence(image)
-
-
-class Decoder(nn.Module):
-    def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
-                 res_norm_type="AdaIN", norm_type="LN", activation_type="ReLU", padding_mode='reflect'):
-        super().__init__()
-        self.residual_blocks = nn.ModuleList([
-            ResidualBlock(in_channels, in_channels, padding_mode, activation_type, norm_type=res_norm_type)
-            for _ in range(num_residual_blocks)
-        ])
-
-        sequence = list()
-        channels = in_channels
-        for i in range(num_up_sampling):
-            sequence.append(nn.Sequential(
-                nn.Upsample(scale_factor=2),
-                Conv2dBlock(channels, channels // 2,
-                            kernel_size=5, stride=1, padding=2, padding_mode=padding_mode,
-                            activation_type=activation_type, norm_type=norm_type),
-            ))
-            channels = channels // 2
-        sequence.append(Conv2dBlock(channels, out_channels,
-                                    kernel_size=7, stride=1, padding=3, padding_mode="reflect",
-                                    activation_type="Tanh", norm_type="NONE"))
-
-        self.up_sequence = nn.Sequential(*sequence)
-
-    def forward(self, x, style):
-        as_param_style = torch.chunk(style, 2 * len(self.residual_blocks), dim=1)
-        # set style for decoder
-        for i, blk in enumerate(self.residual_blocks):
-            blk.conv1.normalization.set_style(as_param_style[2 * i])
-            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
-            x = blk(x)
-        return self.up_sequence(x)
-
-
 class MLPFusion(nn.Module):
     def __init__(self, in_features, out_features, base_features, n_blocks, activation_type="ReLU", norm_type="NONE"):
         super().__init__()
@@ -119,10 +49,13 @@ class Generator(nn.Module):
                  encoder_num_residual_blocks=4, decoder_num_residual_blocks=4,
                  padding_mode='reflect', activation_type="ReLU"):
         super().__init__()
-        self.content_encoder = ContentEncoder(
-            in_channels, num_content_down_sampling, encoder_num_residual_blocks,
-            base_channels, max_down_sampling_multiple,
-            padding_mode, activation_type, norm_type="IN")
+        self.content_encoder = Encoder(
+            in_channels, base_channels, num_content_down_sampling, encoder_num_residual_blocks,
+            max_down_sampling_multiple=max_down_sampling_multiple,
+            padding_mode=padding_mode, activation_type=activation_type,
+            down_conv_norm_type="IN", down_conv_kernel_size=4,
+            res_norm_type="IN"
+        )
         self.style_encoder = StyleEncoder(in_channels, style_dim, num_style_down_sampling,
                                           base_channels, max_down_sampling_multiple,
                                           padding_mode, activation_type,
@@ -134,15 +67,21 @@ class Generator(nn.Module):
                                  num_mlp_base_feature, num_mlp_blocks,
                                  activation_type, norm_type="NONE")
-        self.decoder = Decoder(content_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks,
-                               res_norm_type="AdaIN", norm_type="LN", activation_type=activation_type,
-                               padding_mode=padding_mode)
+        self.decoder = Decoder(content_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks,
+                               activation_type=activation_type, padding_mode=padding_mode,
+                               up_conv_kernel_size=5, up_conv_norm_type="LN",
+                               res_norm_type="AdaIN")
 
     def encode(self, x):
         return self.content_encoder(x), self.style_encoder(x)
 
     def decode(self, content, style):
-        self.decoder(content, self.fusion(style))
+        as_param_style = torch.chunk(self.fusion(style), 2 * len(self.decoder.residual_blocks), dim=1)
+        # set style for decoder
+        for i, blk in enumerate(self.decoder.residual_blocks):
+            blk.conv1.normalization.set_style(as_param_style[2 * i])
+            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
+        return self.decoder(content)
 
     def forward(self, x):
         content, style = self.encode(x)
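Reviewer note, not part of the patch: decode() now chunks the fused style vector into 2 * num_residual_blocks pieces and hands one piece to each AdaIN layer before running the decoder. The AdaIN class itself lives in model/base/module.py and is not shown in this diff, so the stand-in below only illustrates the assumed set_style contract, where one chunk packs gamma and beta for one layer:

import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaIN2dSketch(nn.Module):
    # Illustrative stand-in for the repo's AdaIN normalization (not shown in this diff).
    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features
        self.style = None  # (batch, 2 * num_features): gamma and beta packed together

    def set_style(self, style):
        self.style = style

    def forward(self, x):
        gamma = self.style[:, :self.num_features].view(-1, self.num_features, 1, 1)
        beta = self.style[:, self.num_features:].view(-1, self.num_features, 1, 1)
        # normalize per sample and channel, then rescale with the injected style
        return gamma * F.instance_norm(x) + beta

norm = AdaIN2dSketch(num_features=8)
norm.set_style(torch.randn(2, 16))  # one chunk of as_param_style
print(norm(torch.randn(2, 8, 4, 4)).shape)  # torch.Size([2, 8, 4, 4])

Because the residual blocks now run inside Decoder.forward, set_style must be called before self.decoder(content); decode() above does exactly that, and the added return fixes the old code path that silently returned None.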
diff --git a/model/image_translation/UGATIT.py b/model/image_translation/UGATIT.py
index d288af7..7dd3d43 100644
--- a/model/image_translation/UGATIT.py
+++ b/model/image_translation/UGATIT.py
@@ -2,7 +2,8 @@ import torch
 import torch.nn as nn
 
 from model import MODEL
-from model.base.module import Conv2dBlock, ResidualBlock, LinearBlock
+from model.base.module import Conv2dBlock, LinearBlock
+from model.image_translation.CycleGAN import Encoder, Decoder
 
 
 class RhoClipper(object):
@@ -46,27 +47,11 @@ class Generator(nn.Module):
 
         self.light = light
 
-        sequence = [Conv2dBlock(
-            in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
-            activation_type=activation_type, norm_type=norm_type
-        )]
         n_down_sampling = 2
-        for i in range(n_down_sampling):
-            mult = 2 ** i
-            sequence.append(Conv2dBlock(
-                base_channels * mult, base_channels * mult * 2,
-                kernel_size=3, stride=2, padding=1, padding_mode=padding_mode,
-                activation_type=activation_type, norm_type=norm_type
-            ))
-
+        self.encoder = Encoder(in_channels, base_channels, n_down_sampling, num_blocks,
+                               padding_mode=padding_mode, activation_type=activation_type,
+                               down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type)
         mult = 2 ** n_down_sampling
-        sequence += [
-            ResidualBlock(base_channels * mult, base_channels * mult, padding_mode, activation_type=activation_type,
-                          norm_type=norm_type)
-            for _ in range(num_blocks)]
-        self.encoder = nn.Sequential(*sequence)
-
         self.cam = CAMClassifier(base_channels * mult, activation_type)
 
         # Gamma, Beta block
@@ -85,25 +70,12 @@ class Generator(nn.Module):
         self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
         self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
 
-        # Up-Sampling Bottleneck
-        self.up_bottleneck = nn.ModuleList(
-            [ResidualBlock(base_channels * mult, base_channels * mult, padding_mode,
-                           activation_type, norm_type="AdaILN") for _ in range(num_blocks)])
-
-        sequence = list()
-        channels = base_channels * mult
-        for i in range(n_down_sampling):
-            sequence.append(nn.Sequential(
-                nn.Upsample(scale_factor=2),
-                Conv2dBlock(channels, channels // 2,
-                            kernel_size=3, stride=1, padding=1, bias=False, padding_mode=padding_mode,
-                            activation_type=activation_type, norm_type="ILN"),
-            ))
-            channels = channels // 2
-        sequence.append(Conv2dBlock(channels, out_channels,
-                                    kernel_size=7, stride=1, padding=3, padding_mode="reflect",
-                                    activation_type="Tanh", norm_type="NONE"))
-        self.decoder = nn.Sequential(*sequence)
+        self.decoder = Decoder(
+            base_channels * mult, out_channels, n_down_sampling, num_blocks,
+            activation_type=activation_type, padding_mode=padding_mode,
+            up_conv_kernel_size=3, up_conv_norm_type="ILN",
+            res_norm_type="AdaILN"
+        )
 
     def forward(self, x):
         x = self.encoder(x)
@@ -119,10 +91,9 @@ class Generator(nn.Module):
         x_ = self.fc(x.view(x.shape[0], -1))
         gamma, beta = self.gamma(x_), self.beta(x_)
 
-        for blk in self.up_bottleneck:
+        for blk in self.decoder.residual_blocks:
             blk.conv1.normalization.set_condition(gamma, beta)
             blk.conv2.normalization.set_condition(gamma, beta)
-            x = blk(x)
 
         return self.decoder(x), cam_logit, heatmap
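Reviewer note, not part of the patch: keeping `n_down_sampling = 2` as an unchanged line (rather than deleting it, as the original patch did) avoids a NameError, since both the Encoder call and `mult = 2 ** n_down_sampling` still reference it. Note also that UGATIT conditions every residual block in the shared Decoder with the same (gamma, beta) pair, unlike MUNIT's per-layer chunks, and the blocks now execute inside Decoder.forward, which is why the old `x = blk(x)` line is dropped. AdaILN is defined in model/base/module.py and is not shown here; the stand-in below only sketches the assumed set_condition contract and the instance-norm/layer-norm blend:

import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaILNSketch(nn.Module):
    # Illustrative stand-in for the repo's AdaILN (not shown in this diff).
    def __init__(self, num_features, rho=0.9):
        super().__init__()
        # rho blends instance-norm and layer-norm statistics; the real module's
        # rho is kept in [0, 1] by RhoClipper above.
        self.rho = nn.Parameter(torch.full((1, num_features, 1, 1), rho))
        self.gamma = None
        self.beta = None

    def set_condition(self, gamma, beta):
        self.gamma, self.beta = gamma, beta

    def forward(self, x):
        out_in = F.instance_norm(x)             # per-(sample, channel) statistics
        out_ln = F.layer_norm(x, x.shape[1:])   # per-sample statistics over C, H, W
        out = self.rho * out_in + (1 - self.rho) * out_ln
        return (self.gamma.view(x.size(0), -1, 1, 1) * out
                + self.beta.view(x.size(0), -1, 1, 1))

norm = AdaILNSketch(num_features=256)
norm.set_condition(torch.randn(2, 256), torch.randn(2, 256))
print(norm(torch.randn(2, 256, 64, 64)).shape)  # torch.Size([2, 256, 64, 64])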