add flag to switch to norm-activ-conv
This commit is contained in:
parent
9c08b4cd09
commit
06b2abd19a
@ -52,7 +52,8 @@ class LinearBlock(nn.Module):
|
||||
|
||||
class Conv2dBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, bias=None,
|
||||
activation_type="ReLU", norm_type="NONE", **conv_kwargs):
|
||||
activation_type="ReLU", norm_type="NONE",
|
||||
additional_norm_kwargs=None, **conv_kwargs):
|
||||
super().__init__()
|
||||
self.norm_type = norm_type
|
||||
self.activation_type = activation_type
|
||||
@ -61,40 +62,13 @@ class Conv2dBlock(nn.Module):
|
||||
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
|
||||
|
||||
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
|
||||
self.normalization = _normalization(norm_type, out_channels)
|
||||
self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
|
||||
self.activation = _activation(activation_type)
|
||||
|
||||
def forward(self, x):
|
||||
return self.activation(self.normalization(self.convolution(x)))
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, num_channels, out_channels=None, padding_mode='reflect',
|
||||
activation_type="ReLU", norm_type="IN", out_activation_type=None):
|
||||
super().__init__()
|
||||
self.norm_type = norm_type
|
||||
|
||||
if out_channels is None:
|
||||
out_channels = num_channels
|
||||
if out_activation_type is None:
|
||||
out_activation_type = "NONE"
|
||||
|
||||
self.learn_skip_connection = num_channels != out_channels
|
||||
|
||||
self.conv1 = Conv2dBlock(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
norm_type=norm_type, activation_type=activation_type)
|
||||
self.conv2 = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
norm_type=norm_type, activation_type=out_activation_type)
|
||||
|
||||
if self.learn_skip_connection:
|
||||
self.res_conv = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
norm_type=norm_type, activation_type=out_activation_type)
|
||||
|
||||
def forward(self, x):
|
||||
res = x if not self.learn_skip_connection else self.res_conv(x)
|
||||
return self.conv2(self.conv1(x)) + res
|
||||
|
||||
|
||||
class ReverseConv2dBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int,
|
||||
activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
|
||||
@ -107,19 +81,44 @@ class ReverseConv2dBlock(nn.Module):
|
||||
return self.convolution(self.activation(self.normalization(x)))
|
||||
|
||||
|
||||
class ReverseResidualBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, padding_mode="reflect",
|
||||
norm_type="IN", additional_norm_kwargs=None, activation_type="ReLU"):
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_channels,
|
||||
padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
|
||||
out_channels=None, out_activation_type=None):
|
||||
"""
|
||||
Residual Conv Block
|
||||
:param in_channels:
|
||||
:param out_channels:
|
||||
:param padding_mode:
|
||||
:param activation_type:
|
||||
:param norm_type:
|
||||
:param out_activation_type:
|
||||
:param pre_activation: full pre-activation mode from https://arxiv.org/pdf/1603.05027v3.pdf, figure 4
|
||||
"""
|
||||
super().__init__()
|
||||
self.norm_type = norm_type
|
||||
|
||||
if out_channels is None:
|
||||
out_channels = in_channels
|
||||
if out_activation_type is None:
|
||||
# if not specify `out_activation_type`, using default `out_activation_type`
|
||||
# `out_activation_type` default mode:
|
||||
# "NONE" for not full pre-activation
|
||||
# `norm_type` for full pre-activation
|
||||
out_activation_type = "NONE" if not pre_activation else norm_type
|
||||
|
||||
self.learn_skip_connection = in_channels != out_channels
|
||||
self.conv1 = ReverseConv2dBlock(in_channels, in_channels, activation_type, norm_type, additional_norm_kwargs,
|
||||
kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||
self.conv2 = ReverseConv2dBlock(in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
|
||||
kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||
|
||||
conv_block = ReverseConv2dBlock if pre_activation else Conv2dBlock
|
||||
|
||||
self.conv1 = conv_block(in_channels, in_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
norm_type=norm_type, activation_type=activation_type)
|
||||
self.conv2 = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
norm_type=norm_type, activation_type=out_activation_type)
|
||||
|
||||
if self.learn_skip_connection:
|
||||
self.res_conv = ReverseConv2dBlock(
|
||||
in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
|
||||
kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||
self.res_conv = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
norm_type=norm_type, activation_type=out_activation_type)
|
||||
|
||||
def forward(self, x):
|
||||
res = x if not self.learn_skip_connection else self.res_conv(x)
|
||||
|
||||
@ -7,7 +7,7 @@ class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, base_channels, num_conv, num_res, max_down_sampling_multiple=2,
|
||||
padding_mode='reflect', activation_type="ReLU",
|
||||
down_conv_norm_type="IN", down_conv_kernel_size=3,
|
||||
res_norm_type="IN"):
|
||||
res_norm_type="IN", pre_activation=False):
|
||||
super().__init__()
|
||||
|
||||
sequence = [Conv2dBlock(
|
||||
@ -25,8 +25,13 @@ class Encoder(nn.Module):
|
||||
))
|
||||
self.out_channels = multiple_now * base_channels
|
||||
sequence += [
|
||||
ResidualBlock(self.out_channels, self.out_channels, padding_mode, activation_type, norm_type=res_norm_type)
|
||||
for _ in range(num_res)
|
||||
ResidualBlock(
|
||||
self.out_channels,
|
||||
padding_mode=padding_mode,
|
||||
activation_type=activation_type,
|
||||
norm_type=res_norm_type,
|
||||
pre_activation=pre_activation
|
||||
) for _ in range(num_res)
|
||||
]
|
||||
self.sequence = nn.Sequential(*sequence)
|
||||
|
||||
@ -38,11 +43,16 @@ class Decoder(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
|
||||
activation_type="ReLU", padding_mode='reflect',
|
||||
up_conv_kernel_size=5, up_conv_norm_type="LN",
|
||||
res_norm_type="AdaIN"):
|
||||
res_norm_type="AdaIN", pre_activation=False):
|
||||
super().__init__()
|
||||
self.residual_blocks = nn.ModuleList([
|
||||
ResidualBlock(in_channels, in_channels, padding_mode, activation_type, norm_type=res_norm_type)
|
||||
for _ in range(num_residual_blocks)
|
||||
ResidualBlock(
|
||||
in_channels,
|
||||
padding_mode=padding_mode,
|
||||
activation_type=activation_type,
|
||||
norm_type=res_norm_type,
|
||||
pre_activation=pre_activation
|
||||
) for _ in range(num_residual_blocks)
|
||||
])
|
||||
|
||||
sequence = list()
|
||||
@ -50,15 +60,13 @@ class Decoder(nn.Module):
|
||||
for i in range(num_up_sampling):
|
||||
sequence.append(nn.Sequential(
|
||||
nn.Upsample(scale_factor=2),
|
||||
Conv2dBlock(channels, channels // 2,
|
||||
kernel_size=up_conv_kernel_size, stride=1,
|
||||
Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
|
||||
padding=int(up_conv_kernel_size / 2), padding_mode=padding_mode,
|
||||
activation_type=activation_type, norm_type=up_conv_norm_type),
|
||||
))
|
||||
channels = channels // 2
|
||||
sequence.append(Conv2dBlock(channels, out_channels,
|
||||
kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
|
||||
activation_type="Tanh", norm_type="NONE"))
|
||||
sequence.append(Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3,
|
||||
padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"))
|
||||
|
||||
self.up_sequence = nn.Sequential(*sequence)
|
||||
|
||||
|
||||
@ -8,12 +8,13 @@ from model.image_translation.CycleGAN import Encoder, Decoder
|
||||
|
||||
class StyleEncoder(nn.Module):
|
||||
def __init__(self, in_channels, out_dim, num_conv, base_channels=64,
|
||||
max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
|
||||
max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE",
|
||||
pre_activation=False):
|
||||
super().__init__()
|
||||
self.down_encoder = Encoder(
|
||||
in_channels, base_channels, num_conv, num_res=0, max_down_sampling_multiple=max_down_sampling_multiple,
|
||||
padding_mode=padding_mode, activation_type=activation_type,
|
||||
down_conv_norm_type=norm_type, down_conv_kernel_size=4,
|
||||
down_conv_norm_type=norm_type, down_conv_kernel_size=4, pre_activation=pre_activation,
|
||||
)
|
||||
sequence = list()
|
||||
sequence.append(nn.AdaptiveAvgPool2d(1))
|
||||
@ -47,19 +48,19 @@ class Generator(nn.Module):
|
||||
num_mlp_base_feature=256, num_mlp_blocks=3,
|
||||
max_down_sampling_multiple=2, num_content_down_sampling=2, num_style_down_sampling=2,
|
||||
encoder_num_residual_blocks=4, decoder_num_residual_blocks=4,
|
||||
padding_mode='reflect', activation_type="ReLU"):
|
||||
padding_mode='reflect', activation_type="ReLU", pre_activation=False):
|
||||
super().__init__()
|
||||
self.content_encoder = Encoder(
|
||||
in_channels, base_channels, num_content_down_sampling, encoder_num_residual_blocks,
|
||||
max_down_sampling_multiple=num_content_down_sampling,
|
||||
padding_mode=padding_mode, activation_type=activation_type,
|
||||
down_conv_norm_type="IN", down_conv_kernel_size=4,
|
||||
res_norm_type="IN"
|
||||
res_norm_type="IN", pre_activation=pre_activation
|
||||
)
|
||||
|
||||
self.style_encoder = StyleEncoder(in_channels, style_dim, num_style_down_sampling, base_channels,
|
||||
max_down_sampling_multiple, padding_mode, activation_type,
|
||||
norm_type="NONE")
|
||||
norm_type="NONE", pre_activation=pre_activation)
|
||||
|
||||
content_channels = base_channels * (2 ** max_down_sampling_multiple)
|
||||
|
||||
@ -70,7 +71,7 @@ class Generator(nn.Module):
|
||||
self.decoder = Decoder(in_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks,
|
||||
activation_type=activation_type, padding_mode=padding_mode,
|
||||
up_conv_kernel_size=5, up_conv_norm_type="LN",
|
||||
res_norm_type="AdaIN")
|
||||
res_norm_type="AdaIN", pre_activation=pre_activation)
|
||||
|
||||
def encode(self, x):
|
||||
return self.content_encoder(x), self.style_encoder(x)
|
||||
|
||||
@ -26,8 +26,8 @@ class CAMClassifier(nn.Module):
|
||||
self.avg_fc = nn.Linear(in_channels, 1, bias=False)
|
||||
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
||||
self.max_fc = nn.Linear(in_channels, 1, bias=False)
|
||||
self.fusion_conv = Conv2dBlock(in_channels * 2, in_channels, activation_type=activation_type,
|
||||
norm_type="NONE", kernel_size=1, stride=1, bias=True)
|
||||
self.fusion_conv = Conv2dBlock(in_channels * 2, in_channels, kernel_size=1, stride=1, bias=True,
|
||||
activation_type=activation_type, norm_type="NONE")
|
||||
|
||||
def forward(self, x):
|
||||
avg_logit = self.avg_fc(self.avg_pool(x).view(x.size(0), -1))
|
||||
@ -42,7 +42,7 @@ class CAMClassifier(nn.Module):
|
||||
@MODEL.register_module("UGATIT-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False,
|
||||
activation_type="ReLU", norm_type="IN", padding_mode='reflect'):
|
||||
activation_type="ReLU", norm_type="IN", padding_mode='reflect', pre_activation=False):
|
||||
super(Generator, self).__init__()
|
||||
|
||||
self.light = light
|
||||
@ -50,7 +50,8 @@ class Generator(nn.Module):
|
||||
n_down_sampling = 2
|
||||
self.encoder = Encoder(in_channels, base_channels, n_down_sampling, num_blocks,
|
||||
padding_mode=padding_mode, activation_type=activation_type,
|
||||
down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type)
|
||||
down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type,
|
||||
pre_activation=pre_activation)
|
||||
mult = 2 ** n_down_sampling
|
||||
self.cam = CAMClassifier(base_channels * mult, activation_type)
|
||||
|
||||
@ -74,7 +75,7 @@ class Generator(nn.Module):
|
||||
base_channels * mult, out_channels, n_down_sampling, num_blocks,
|
||||
activation_type=activation_type, padding_mode=padding_mode,
|
||||
up_conv_kernel_size=3, up_conv_norm_type="ILN",
|
||||
res_norm_type="AdaILN"
|
||||
res_norm_type="AdaILN", pre_activation=pre_activation
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user