add flag to switch to norm-activ-conv

This commit is contained in:
Ray Wong 2020-10-11 19:02:42 +08:00
parent 9c08b4cd09
commit 06b2abd19a
4 changed files with 70 additions and 61 deletions

View File

@ -52,7 +52,8 @@ class LinearBlock(nn.Module):
class Conv2dBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, bias=None,
activation_type="ReLU", norm_type="NONE", **conv_kwargs):
activation_type="ReLU", norm_type="NONE",
additional_norm_kwargs=None, **conv_kwargs):
super().__init__()
self.norm_type = norm_type
self.activation_type = activation_type
@ -61,40 +62,13 @@ class Conv2dBlock(nn.Module):
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
self.normalization = _normalization(norm_type, out_channels)
self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
self.activation = _activation(activation_type)
def forward(self, x):
return self.activation(self.normalization(self.convolution(x)))
class ResidualBlock(nn.Module):
def __init__(self, num_channels, out_channels=None, padding_mode='reflect',
activation_type="ReLU", norm_type="IN", out_activation_type=None):
super().__init__()
self.norm_type = norm_type
if out_channels is None:
out_channels = num_channels
if out_activation_type is None:
out_activation_type = "NONE"
self.learn_skip_connection = num_channels != out_channels
self.conv1 = Conv2dBlock(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
norm_type=norm_type, activation_type=activation_type)
self.conv2 = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
norm_type=norm_type, activation_type=out_activation_type)
if self.learn_skip_connection:
self.res_conv = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
norm_type=norm_type, activation_type=out_activation_type)
def forward(self, x):
res = x if not self.learn_skip_connection else self.res_conv(x)
return self.conv2(self.conv1(x)) + res
class ReverseConv2dBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int,
activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
@ -107,19 +81,44 @@ class ReverseConv2dBlock(nn.Module):
return self.convolution(self.activation(self.normalization(x)))
class ReverseResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, padding_mode="reflect",
norm_type="IN", additional_norm_kwargs=None, activation_type="ReLU"):
class ResidualBlock(nn.Module):
def __init__(self, in_channels,
padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
out_channels=None, out_activation_type=None):
"""
Residual Conv Block
:param in_channels:
:param out_channels:
:param padding_mode:
:param activation_type:
:param norm_type:
:param out_activation_type:
:param pre_activation: full pre-activation mode from https://arxiv.org/pdf/1603.05027v3.pdf, figure 4
"""
super().__init__()
self.norm_type = norm_type
if out_channels is None:
out_channels = in_channels
if out_activation_type is None:
# if not specify `out_activation_type`, using default `out_activation_type`
# `out_activation_type` default mode:
# "NONE" for not full pre-activation
# `norm_type` for full pre-activation
out_activation_type = "NONE" if not pre_activation else norm_type
self.learn_skip_connection = in_channels != out_channels
self.conv1 = ReverseConv2dBlock(in_channels, in_channels, activation_type, norm_type, additional_norm_kwargs,
kernel_size=3, padding=1, padding_mode=padding_mode)
self.conv2 = ReverseConv2dBlock(in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
kernel_size=3, padding=1, padding_mode=padding_mode)
conv_block = ReverseConv2dBlock if pre_activation else Conv2dBlock
self.conv1 = conv_block(in_channels, in_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
norm_type=norm_type, activation_type=activation_type)
self.conv2 = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
norm_type=norm_type, activation_type=out_activation_type)
if self.learn_skip_connection:
self.res_conv = ReverseConv2dBlock(
in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
kernel_size=3, padding=1, padding_mode=padding_mode)
self.res_conv = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
norm_type=norm_type, activation_type=out_activation_type)
def forward(self, x):
res = x if not self.learn_skip_connection else self.res_conv(x)

View File

@ -7,7 +7,7 @@ class Encoder(nn.Module):
def __init__(self, in_channels, base_channels, num_conv, num_res, max_down_sampling_multiple=2,
padding_mode='reflect', activation_type="ReLU",
down_conv_norm_type="IN", down_conv_kernel_size=3,
res_norm_type="IN"):
res_norm_type="IN", pre_activation=False):
super().__init__()
sequence = [Conv2dBlock(
@ -25,8 +25,13 @@ class Encoder(nn.Module):
))
self.out_channels = multiple_now * base_channels
sequence += [
ResidualBlock(self.out_channels, self.out_channels, padding_mode, activation_type, norm_type=res_norm_type)
for _ in range(num_res)
ResidualBlock(
self.out_channels,
padding_mode=padding_mode,
activation_type=activation_type,
norm_type=res_norm_type,
pre_activation=pre_activation
) for _ in range(num_res)
]
self.sequence = nn.Sequential(*sequence)
@ -38,11 +43,16 @@ class Decoder(nn.Module):
def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
activation_type="ReLU", padding_mode='reflect',
up_conv_kernel_size=5, up_conv_norm_type="LN",
res_norm_type="AdaIN"):
res_norm_type="AdaIN", pre_activation=False):
super().__init__()
self.residual_blocks = nn.ModuleList([
ResidualBlock(in_channels, in_channels, padding_mode, activation_type, norm_type=res_norm_type)
for _ in range(num_residual_blocks)
ResidualBlock(
in_channels,
padding_mode=padding_mode,
activation_type=activation_type,
norm_type=res_norm_type,
pre_activation=pre_activation
) for _ in range(num_residual_blocks)
])
sequence = list()
@ -50,15 +60,13 @@ class Decoder(nn.Module):
for i in range(num_up_sampling):
sequence.append(nn.Sequential(
nn.Upsample(scale_factor=2),
Conv2dBlock(channels, channels // 2,
kernel_size=up_conv_kernel_size, stride=1,
Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
padding=int(up_conv_kernel_size / 2), padding_mode=padding_mode,
activation_type=activation_type, norm_type=up_conv_norm_type),
))
channels = channels // 2
sequence.append(Conv2dBlock(channels, out_channels,
kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
activation_type="Tanh", norm_type="NONE"))
sequence.append(Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3,
padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"))
self.up_sequence = nn.Sequential(*sequence)

View File

@ -8,12 +8,13 @@ from model.image_translation.CycleGAN import Encoder, Decoder
class StyleEncoder(nn.Module):
def __init__(self, in_channels, out_dim, num_conv, base_channels=64,
max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE",
pre_activation=False):
super().__init__()
self.down_encoder = Encoder(
in_channels, base_channels, num_conv, num_res=0, max_down_sampling_multiple=max_down_sampling_multiple,
padding_mode=padding_mode, activation_type=activation_type,
down_conv_norm_type=norm_type, down_conv_kernel_size=4,
down_conv_norm_type=norm_type, down_conv_kernel_size=4, pre_activation=pre_activation,
)
sequence = list()
sequence.append(nn.AdaptiveAvgPool2d(1))
@ -47,19 +48,19 @@ class Generator(nn.Module):
num_mlp_base_feature=256, num_mlp_blocks=3,
max_down_sampling_multiple=2, num_content_down_sampling=2, num_style_down_sampling=2,
encoder_num_residual_blocks=4, decoder_num_residual_blocks=4,
padding_mode='reflect', activation_type="ReLU"):
padding_mode='reflect', activation_type="ReLU", pre_activation=False):
super().__init__()
self.content_encoder = Encoder(
in_channels, base_channels, num_content_down_sampling, encoder_num_residual_blocks,
max_down_sampling_multiple=num_content_down_sampling,
padding_mode=padding_mode, activation_type=activation_type,
down_conv_norm_type="IN", down_conv_kernel_size=4,
res_norm_type="IN"
res_norm_type="IN", pre_activation=pre_activation
)
self.style_encoder = StyleEncoder(in_channels, style_dim, num_style_down_sampling, base_channels,
max_down_sampling_multiple, padding_mode, activation_type,
norm_type="NONE")
norm_type="NONE", pre_activation=pre_activation)
content_channels = base_channels * (2 ** max_down_sampling_multiple)
@ -70,7 +71,7 @@ class Generator(nn.Module):
self.decoder = Decoder(in_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks,
activation_type=activation_type, padding_mode=padding_mode,
up_conv_kernel_size=5, up_conv_norm_type="LN",
res_norm_type="AdaIN")
res_norm_type="AdaIN", pre_activation=pre_activation)
def encode(self, x):
return self.content_encoder(x), self.style_encoder(x)

View File

@ -26,8 +26,8 @@ class CAMClassifier(nn.Module):
self.avg_fc = nn.Linear(in_channels, 1, bias=False)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.max_fc = nn.Linear(in_channels, 1, bias=False)
self.fusion_conv = Conv2dBlock(in_channels * 2, in_channels, activation_type=activation_type,
norm_type="NONE", kernel_size=1, stride=1, bias=True)
self.fusion_conv = Conv2dBlock(in_channels * 2, in_channels, kernel_size=1, stride=1, bias=True,
activation_type=activation_type, norm_type="NONE")
def forward(self, x):
avg_logit = self.avg_fc(self.avg_pool(x).view(x.size(0), -1))
@ -42,7 +42,7 @@ class CAMClassifier(nn.Module):
@MODEL.register_module("UGATIT-Generator")
class Generator(nn.Module):
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False,
activation_type="ReLU", norm_type="IN", padding_mode='reflect'):
activation_type="ReLU", norm_type="IN", padding_mode='reflect', pre_activation=False):
super(Generator, self).__init__()
self.light = light
@ -50,7 +50,8 @@ class Generator(nn.Module):
n_down_sampling = 2
self.encoder = Encoder(in_channels, base_channels, n_down_sampling, num_blocks,
padding_mode=padding_mode, activation_type=activation_type,
down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type)
down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type,
pre_activation=pre_activation)
mult = 2 ** n_down_sampling
self.cam = CAMClassifier(base_channels * mult, activation_type)
@ -74,7 +75,7 @@ class Generator(nn.Module):
base_channels * mult, out_channels, n_down_sampling, num_blocks,
activation_type=activation_type, padding_mode=padding_mode,
up_conv_kernel_size=3, up_conv_norm_type="ILN",
res_norm_type="AdaILN"
res_norm_type="AdaILN", pre_activation=pre_activation
)
def forward(self, x):