From 06b2abd19a47c0981ffa7e33983e7f8f1d41268c Mon Sep 17 00:00:00 2001
From: Ray Wong
Date: Sun, 11 Oct 2020 19:02:42 +0800
Subject: [PATCH] add pre_activation flag to switch to norm-activ-conv ordering

---
 model/base/module.py                | 77 ++++++++++++++---------------
 model/image_translation/CycleGAN.py | 30 ++++++-----
 model/image_translation/MUNIT.py    | 13 ++---
 model/image_translation/UGATIT.py   | 11 +++--
 4 files changed, 70 insertions(+), 61 deletions(-)

diff --git a/model/base/module.py b/model/base/module.py
index c5502f5..9597a43 100644
--- a/model/base/module.py
+++ b/model/base/module.py
@@ -52,7 +52,8 @@ class LinearBlock(nn.Module):
 
 class Conv2dBlock(nn.Module):
     def __init__(self, in_channels: int, out_channels: int, bias=None,
-                 activation_type="ReLU", norm_type="NONE", **conv_kwargs):
+                 activation_type="ReLU", norm_type="NONE",
+                 additional_norm_kwargs=None, **conv_kwargs):
         super().__init__()
         self.norm_type = norm_type
         self.activation_type = activation_type
@@ -61,40 +62,13 @@ class Conv2dBlock(nn.Module):
         conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
 
         self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
-        self.normalization = _normalization(norm_type, out_channels)
+        self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
         self.activation = _activation(activation_type)
 
     def forward(self, x):
         return self.activation(self.normalization(self.convolution(x)))
 
 
-class ResidualBlock(nn.Module):
-    def __init__(self, num_channels, out_channels=None, padding_mode='reflect',
-                 activation_type="ReLU", norm_type="IN", out_activation_type=None):
-        super().__init__()
-        self.norm_type = norm_type
-
-        if out_channels is None:
-            out_channels = num_channels
-        if out_activation_type is None:
-            out_activation_type = "NONE"
-
-        self.learn_skip_connection = num_channels != out_channels
-
-        self.conv1 = Conv2dBlock(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
-                                 norm_type=norm_type, activation_type=activation_type)
-        self.conv2 = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
-                                 norm_type=norm_type, activation_type=out_activation_type)
-
-        if self.learn_skip_connection:
-            self.res_conv = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
-                                        norm_type=norm_type, activation_type=out_activation_type)
-
-    def forward(self, x):
-        res = x if not self.learn_skip_connection else self.res_conv(x)
-        return self.conv2(self.conv1(x)) + res
-
-
 class ReverseConv2dBlock(nn.Module):
     def __init__(self, in_channels: int, out_channels: int,
                  activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
@@ -107,19 +81,44 @@ class ReverseConv2dBlock(nn.Module):
         return self.convolution(self.activation(self.normalization(x)))
 
 
-class ReverseResidualBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, padding_mode="reflect",
-                 norm_type="IN", additional_norm_kwargs=None, activation_type="ReLU"):
+class ResidualBlock(nn.Module):
+    def __init__(self, in_channels,
+                 padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
+                 out_channels=None, out_activation_type=None):
+        """
+        Residual Conv Block
+        :param in_channels: number of input channels
+        :param out_channels: number of output channels, defaults to `in_channels`
+        :param padding_mode: padding mode of the conv blocks
+        :param activation_type: activation of the first conv block
+        :param norm_type: normalization of the conv blocks
+        :param out_activation_type: activation of the last conv block and the skip conv, see the default below
+        :param pre_activation: full pre-activation mode from https://arxiv.org/pdf/1603.05027v3.pdf, figure 4
+        """
         super().__init__()
+        self.norm_type = norm_type
+
+        if out_channels is None:
+            out_channels = in_channels
+        if out_activation_type is None:
+            # `out_activation_type` was not specified;
+            # fall back to the default:
+            #   "NONE"      when not in full pre-activation mode
+            #   `norm_type` when in full pre-activation mode
+            out_activation_type = "NONE" if not pre_activation else norm_type
+
         self.learn_skip_connection = in_channels != out_channels
-        self.conv1 = ReverseConv2dBlock(in_channels, in_channels, activation_type, norm_type, additional_norm_kwargs,
-                                        kernel_size=3, padding=1, padding_mode=padding_mode)
-        self.conv2 = ReverseConv2dBlock(in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
-                                        kernel_size=3, padding=1, padding_mode=padding_mode)
+
+        conv_block = ReverseConv2dBlock if pre_activation else Conv2dBlock
+
+        self.conv1 = conv_block(in_channels, in_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
+                                norm_type=norm_type, activation_type=activation_type)
+        self.conv2 = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
+                                norm_type=norm_type, activation_type=out_activation_type)
+
         if self.learn_skip_connection:
-            self.res_conv = ReverseConv2dBlock(
-                in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
-                kernel_size=3, padding=1, padding_mode=padding_mode)
+            self.res_conv = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
+                                       norm_type=norm_type, activation_type=out_activation_type)
 
     def forward(self, x):
         res = x if not self.learn_skip_connection else self.res_conv(x)
diff --git a/model/image_translation/CycleGAN.py b/model/image_translation/CycleGAN.py
index 10cd4b6..dd487cb 100644
--- a/model/image_translation/CycleGAN.py
+++ b/model/image_translation/CycleGAN.py
@@ -7,7 +7,7 @@ class Encoder(nn.Module):
     def __init__(self, in_channels, base_channels, num_conv, num_res, max_down_sampling_multiple=2,
                  padding_mode='reflect', activation_type="ReLU",
                  down_conv_norm_type="IN", down_conv_kernel_size=3,
-                 res_norm_type="IN"):
+                 res_norm_type="IN", pre_activation=False):
         super().__init__()
 
         sequence = [Conv2dBlock(
@@ -25,8 +25,13 @@ class Encoder(nn.Module):
             ))
         self.out_channels = multiple_now * base_channels
         sequence += [
-            ResidualBlock(self.out_channels, self.out_channels, padding_mode, activation_type, norm_type=res_norm_type)
-            for _ in range(num_res)
+            ResidualBlock(
+                self.out_channels,
+                padding_mode=padding_mode,
+                activation_type=activation_type,
+                norm_type=res_norm_type,
+                pre_activation=pre_activation
+            ) for _ in range(num_res)
         ]
 
         self.sequence = nn.Sequential(*sequence)
@@ -38,11 +43,16 @@ class Decoder(nn.Module):
     def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
                  activation_type="ReLU", padding_mode='reflect',
                  up_conv_kernel_size=5, up_conv_norm_type="LN",
-                 res_norm_type="AdaIN"):
+                 res_norm_type="AdaIN", pre_activation=False):
         super().__init__()
         self.residual_blocks = nn.ModuleList([
-            ResidualBlock(in_channels, in_channels, padding_mode, activation_type, norm_type=res_norm_type)
-            for _ in range(num_residual_blocks)
+            ResidualBlock(
+                in_channels,
+                padding_mode=padding_mode,
+                activation_type=activation_type,
+                norm_type=res_norm_type,
+                pre_activation=pre_activation
+            ) for _ in range(num_residual_blocks)
         ])
 
         sequence = list()
@@ -50,15 +60,13 @@ class Decoder(nn.Module):
         for i in range(num_up_sampling):
             sequence.append(nn.Sequential(
                 nn.Upsample(scale_factor=2),
-                Conv2dBlock(channels, channels // 2,
-                            kernel_size=up_conv_kernel_size, stride=1,
+                Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
                             padding=int(up_conv_kernel_size / 2), padding_mode=padding_mode,
                             activation_type=activation_type, norm_type=up_conv_norm_type),
             ))
             channels = channels // 2
-        sequence.append(Conv2dBlock(channels, out_channels,
-                                    kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
-                                    activation_type="Tanh", norm_type="NONE"))
+        sequence.append(Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3,
+                                    padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"))
 
         self.up_sequence = nn.Sequential(*sequence)
 
diff --git a/model/image_translation/MUNIT.py b/model/image_translation/MUNIT.py
index d655058..02a7fd6 100644
--- a/model/image_translation/MUNIT.py
+++ b/model/image_translation/MUNIT.py
@@ -8,12 +8,13 @@ from model.image_translation.CycleGAN import Encoder, Decoder
 
 class StyleEncoder(nn.Module):
     def __init__(self, in_channels, out_dim, num_conv, base_channels=64,
-                 max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
+                 max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE",
+                 pre_activation=False):
         super().__init__()
         self.down_encoder = Encoder(
             in_channels, base_channels, num_conv, num_res=0, max_down_sampling_multiple=max_down_sampling_multiple,
             padding_mode=padding_mode, activation_type=activation_type,
-            down_conv_norm_type=norm_type, down_conv_kernel_size=4,
+            down_conv_norm_type=norm_type, down_conv_kernel_size=4, pre_activation=pre_activation,
         )
         sequence = list()
         sequence.append(nn.AdaptiveAvgPool2d(1))
@@ -47,19 +48,19 @@ class Generator(nn.Module):
                  num_mlp_base_feature=256, num_mlp_blocks=3,
                  max_down_sampling_multiple=2, num_content_down_sampling=2, num_style_down_sampling=2,
                  encoder_num_residual_blocks=4, decoder_num_residual_blocks=4,
-                 padding_mode='reflect', activation_type="ReLU"):
+                 padding_mode='reflect', activation_type="ReLU", pre_activation=False):
         super().__init__()
 
         self.content_encoder = Encoder(
             in_channels, base_channels, num_content_down_sampling, encoder_num_residual_blocks,
             max_down_sampling_multiple=num_content_down_sampling, padding_mode=padding_mode,
             activation_type=activation_type, down_conv_norm_type="IN", down_conv_kernel_size=4,
-            res_norm_type="IN"
+            res_norm_type="IN", pre_activation=pre_activation
         )
         self.style_encoder = StyleEncoder(in_channels, style_dim, num_style_down_sampling, base_channels,
                                           max_down_sampling_multiple, padding_mode, activation_type,
-                                          norm_type="NONE")
+                                          norm_type="NONE", pre_activation=pre_activation)
 
         content_channels = base_channels * (2 ** max_down_sampling_multiple)
 
@@ -70,7 +71,7 @@ class Generator(nn.Module):
         self.decoder = Decoder(in_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks,
                                activation_type=activation_type, padding_mode=padding_mode,
                                up_conv_kernel_size=5, up_conv_norm_type="LN",
-                               res_norm_type="AdaIN")
+                               res_norm_type="AdaIN", pre_activation=pre_activation)
 
     def encode(self, x):
         return self.content_encoder(x), self.style_encoder(x)
diff --git a/model/image_translation/UGATIT.py b/model/image_translation/UGATIT.py
index 7dd3d43..9e4a7c1 100644
--- a/model/image_translation/UGATIT.py
+++ b/model/image_translation/UGATIT.py
@@ -26,8 +26,8 @@ class CAMClassifier(nn.Module):
         self.avg_fc = nn.Linear(in_channels, 1, bias=False)
         self.max_pool = nn.AdaptiveMaxPool2d(1)
         self.max_fc = nn.Linear(in_channels, 1, bias=False)
-        self.fusion_conv = Conv2dBlock(in_channels * 2, in_channels, activation_type=activation_type,
-                                       norm_type="NONE", kernel_size=1, stride=1, bias=True)
+        self.fusion_conv = Conv2dBlock(in_channels * 2, in_channels, kernel_size=1, stride=1, bias=True,
+                                       activation_type=activation_type, norm_type="NONE")
 
     def forward(self, x):
         avg_logit = self.avg_fc(self.avg_pool(x).view(x.size(0), -1))
@@ -42,7 +42,7 @@ class CAMClassifier(nn.Module):
 @MODEL.register_module("UGATIT-Generator")
 class Generator(nn.Module):
     def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False,
-                 activation_type="ReLU", norm_type="IN", padding_mode='reflect'):
+                 activation_type="ReLU", norm_type="IN", padding_mode='reflect', pre_activation=False):
         super(Generator, self).__init__()
         self.light = light
 
@@ -50,7 +50,8 @@ class Generator(nn.Module):
         n_down_sampling = 2
         self.encoder = Encoder(in_channels, base_channels, n_down_sampling, num_blocks,
                                padding_mode=padding_mode, activation_type=activation_type,
-                               down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type)
+                               down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type,
+                               pre_activation=pre_activation)
 
         mult = 2 ** n_down_sampling
         self.cam = CAMClassifier(base_channels * mult, activation_type)
@@ -74,7 +75,7 @@ class Generator(nn.Module):
             base_channels * mult, out_channels, n_down_sampling, num_blocks,
             activation_type=activation_type, padding_mode=padding_mode,
             up_conv_kernel_size=3, up_conv_norm_type="ILN",
-            res_norm_type="AdaILN"
+            res_norm_type="AdaILN", pre_activation=pre_activation
         )
 
     def forward(self, x):
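
Usage sketch (not part of the diff): a minimal example of the new `pre_activation` flag on the unified `ResidualBlock`, assuming the repository root is on PYTHONPATH and PyTorch is installed. `out_activation_type` is passed explicitly here so the sketch does not depend on the default chosen above.

    # Build the same residual block with both orderings and compare output shapes.
    import torch

    from model.base.module import ResidualBlock  # assumes the `model` package is importable

    # Default ordering: each conv block applies conv -> norm -> activation (Conv2dBlock).
    post_act = ResidualBlock(64, norm_type="IN", activation_type="ReLU")

    # Full pre-activation ordering: norm -> activation -> conv (ReverseConv2dBlock),
    # per https://arxiv.org/pdf/1603.05027v3.pdf, figure 4.
    pre_act = ResidualBlock(64, norm_type="IN", activation_type="ReLU",
                            pre_activation=True, out_activation_type="ReLU")

    x = torch.randn(2, 64, 32, 32)
    # Both blocks keep channel count and spatial size (kernel_size=3, padding=1).
    assert post_act(x).shape == pre_act(x).shape == x.shape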