add flag to switch to norm-activation-conv (full pre-activation) block ordering
Commit 06b2abd19a (parent 9c08b4cd09)
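The flag introduced by this commit lets residual blocks be assembled either from Conv2dBlock (conv -> norm -> activation, the previous behaviour) or from ReverseConv2dBlock (norm -> activation -> conv, the "full pre-activation" ordering referenced in the docstring below). A minimal self-contained sketch of the two orderings the flag toggles; the class names here are illustrative placeholders, not the repository's classes:

    import torch
    import torch.nn as nn

    class PostActSketch(nn.Module):
        # conv -> norm -> activation, like Conv2dBlock.forward in the diff below
        def __init__(self, in_ch, out_ch):
            super().__init__()
            self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
            self.norm = nn.InstanceNorm2d(out_ch)
            self.act = nn.ReLU()

        def forward(self, x):
            return self.act(self.norm(self.conv(x)))

    class PreActSketch(nn.Module):
        # norm -> activation -> conv, like ReverseConv2dBlock.forward in the diff below
        def __init__(self, in_ch, out_ch):
            super().__init__()
            self.norm = nn.InstanceNorm2d(in_ch)
            self.act = nn.ReLU()
            self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

        def forward(self, x):
            return self.conv(self.act(self.norm(x)))

    x = torch.randn(1, 8, 32, 32)
    assert PostActSketch(8, 8)(x).shape == PreActSketch(8, 8)(x).shape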
@@ -52,7 +52,8 @@ class LinearBlock(nn.Module):
 
 class Conv2dBlock(nn.Module):
     def __init__(self, in_channels: int, out_channels: int, bias=None,
-                 activation_type="ReLU", norm_type="NONE", **conv_kwargs):
+                 activation_type="ReLU", norm_type="NONE",
+                 additional_norm_kwargs=None, **conv_kwargs):
         super().__init__()
         self.norm_type = norm_type
         self.activation_type = activation_type
@@ -61,40 +62,13 @@ class Conv2dBlock(nn.Module):
         conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
 
         self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
-        self.normalization = _normalization(norm_type, out_channels)
+        self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
         self.activation = _activation(activation_type)
 
     def forward(self, x):
         return self.activation(self.normalization(self.convolution(x)))
 
 
-class ResidualBlock(nn.Module):
-    def __init__(self, num_channels, out_channels=None, padding_mode='reflect',
-                 activation_type="ReLU", norm_type="IN", out_activation_type=None):
-        super().__init__()
-        self.norm_type = norm_type
-
-        if out_channels is None:
-            out_channels = num_channels
-        if out_activation_type is None:
-            out_activation_type = "NONE"
-
-        self.learn_skip_connection = num_channels != out_channels
-
-        self.conv1 = Conv2dBlock(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
-                                 norm_type=norm_type, activation_type=activation_type)
-        self.conv2 = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
-                                 norm_type=norm_type, activation_type=out_activation_type)
-
-        if self.learn_skip_connection:
-            self.res_conv = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
-                                        norm_type=norm_type, activation_type=out_activation_type)
-
-    def forward(self, x):
-        res = x if not self.learn_skip_connection else self.res_conv(x)
-        return self.conv2(self.conv1(x)) + res
-
-
 class ReverseConv2dBlock(nn.Module):
     def __init__(self, in_channels: int, out_channels: int,
                  activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
@@ -107,19 +81,44 @@ class ReverseConv2dBlock(nn.Module):
         return self.convolution(self.activation(self.normalization(x)))
 
 
-class ReverseResidualBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, padding_mode="reflect",
-                 norm_type="IN", additional_norm_kwargs=None, activation_type="ReLU"):
+class ResidualBlock(nn.Module):
+    def __init__(self, in_channels,
+                 padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
+                 out_channels=None, out_activation_type=None):
+        """
+        Residual Conv Block
+        :param in_channels:
+        :param out_channels:
+        :param padding_mode:
+        :param activation_type:
+        :param norm_type:
+        :param out_activation_type:
+        :param pre_activation: full pre-activation mode from https://arxiv.org/pdf/1603.05027v3.pdf, figure 4
+        """
         super().__init__()
+        self.norm_type = norm_type
+
+        if out_channels is None:
+            out_channels = in_channels
+        if out_activation_type is None:
+            # if `out_activation_type` is not specified, use a mode-dependent default:
+            # `out_activation_type` defaults to
+            #   "NONE" when not using full pre-activation
+            #   `norm_type` when using full pre-activation
+            out_activation_type = "NONE" if not pre_activation else norm_type
+
         self.learn_skip_connection = in_channels != out_channels
-        self.conv1 = ReverseConv2dBlock(in_channels, in_channels, activation_type, norm_type, additional_norm_kwargs,
-                                        kernel_size=3, padding=1, padding_mode=padding_mode)
-        self.conv2 = ReverseConv2dBlock(in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
-                                        kernel_size=3, padding=1, padding_mode=padding_mode)
+
+        conv_block = ReverseConv2dBlock if pre_activation else Conv2dBlock
+        self.conv1 = conv_block(in_channels, in_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
+                                norm_type=norm_type, activation_type=activation_type)
+        self.conv2 = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
+                                norm_type=norm_type, activation_type=out_activation_type)
 
         if self.learn_skip_connection:
-            self.res_conv = ReverseConv2dBlock(
-                in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
-                kernel_size=3, padding=1, padding_mode=padding_mode)
+            self.res_conv = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
+                                       norm_type=norm_type, activation_type=out_activation_type)
 
     def forward(self, x):
         res = x if not self.learn_skip_connection else self.res_conv(x)
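With this hunk, ResidualBlock picks its sub-block type from the flag: Conv2dBlock when pre_activation=False (and the second conv's activation defaults to "NONE"), ReverseConv2dBlock when pre_activation=True (and the second conv's activation defaults to norm_type). A hedged usage sketch based on the signature above; the import path is not shown in this diff and the argument values are illustrative:

    # from <module containing the blocks> import ResidualBlock  # path not shown in the diff

    post_act = ResidualBlock(64, norm_type="IN", activation_type="ReLU",
                             pre_activation=False)  # conv1/conv2 are Conv2dBlock
    pre_act = ResidualBlock(64, norm_type="IN", activation_type="ReLU",
                            pre_activation=True)    # conv1/conv2 are ReverseConv2dBlock
    # both keep an identity skip connection, since in_channels == out_channels here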
@@ -7,7 +7,7 @@ class Encoder(nn.Module):
     def __init__(self, in_channels, base_channels, num_conv, num_res, max_down_sampling_multiple=2,
                  padding_mode='reflect', activation_type="ReLU",
                  down_conv_norm_type="IN", down_conv_kernel_size=3,
-                 res_norm_type="IN"):
+                 res_norm_type="IN", pre_activation=False):
         super().__init__()
 
         sequence = [Conv2dBlock(
@@ -25,8 +25,13 @@
         ))
         self.out_channels = multiple_now * base_channels
         sequence += [
-            ResidualBlock(self.out_channels, self.out_channels, padding_mode, activation_type, norm_type=res_norm_type)
-            for _ in range(num_res)
+            ResidualBlock(
+                self.out_channels,
+                padding_mode=padding_mode,
+                activation_type=activation_type,
+                norm_type=res_norm_type,
+                pre_activation=pre_activation
+            ) for _ in range(num_res)
         ]
         self.sequence = nn.Sequential(*sequence)
 
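The Encoder simply forwards the flag into its residual stack. A hedged construction sketch using the signature shown in this hunk and the import visible in a later hunk; the argument values are illustrative:

    from model.image_translation.CycleGAN import Encoder

    encoder = Encoder(in_channels=3, base_channels=64, num_conv=2, num_res=4,
                      down_conv_norm_type="IN", res_norm_type="IN",
                      pre_activation=True)  # every ResidualBlock in the stack uses norm -> act -> conv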
@@ -38,11 +43,16 @@ class Decoder(nn.Module):
     def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
                  activation_type="ReLU", padding_mode='reflect',
                  up_conv_kernel_size=5, up_conv_norm_type="LN",
-                 res_norm_type="AdaIN"):
+                 res_norm_type="AdaIN", pre_activation=False):
         super().__init__()
         self.residual_blocks = nn.ModuleList([
-            ResidualBlock(in_channels, in_channels, padding_mode, activation_type, norm_type=res_norm_type)
-            for _ in range(num_residual_blocks)
+            ResidualBlock(
+                in_channels,
+                padding_mode=padding_mode,
+                activation_type=activation_type,
+                norm_type=res_norm_type,
+                pre_activation=pre_activation
+            ) for _ in range(num_residual_blocks)
         ])
 
         sequence = list()
@@ -50,15 +60,13 @@
         for i in range(num_up_sampling):
             sequence.append(nn.Sequential(
                 nn.Upsample(scale_factor=2),
-                Conv2dBlock(channels, channels // 2,
-                            kernel_size=up_conv_kernel_size, stride=1,
+                Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
                             padding=int(up_conv_kernel_size / 2), padding_mode=padding_mode,
                             activation_type=activation_type, norm_type=up_conv_norm_type),
             ))
             channels = channels // 2
-        sequence.append(Conv2dBlock(channels, out_channels,
-                                    kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
-                                    activation_type="Tanh", norm_type="NONE"))
+        sequence.append(Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3,
+                                    padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"))
 
         self.up_sequence = nn.Sequential(*sequence)
 
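The Decoder change mirrors the Encoder one. A hedged sketch with the signature above; values are illustrative:

    from model.image_translation.CycleGAN import Decoder

    decoder = Decoder(in_channels=256, out_channels=3, num_up_sampling=2, num_residual_blocks=4,
                      up_conv_norm_type="LN", res_norm_type="AdaIN",
                      pre_activation=True)  # AdaIN residual blocks built in pre-activation order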
@@ -8,12 +8,13 @@ from model.image_translation.CycleGAN import Encoder, Decoder
 
 class StyleEncoder(nn.Module):
     def __init__(self, in_channels, out_dim, num_conv, base_channels=64,
-                 max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
+                 max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE",
+                 pre_activation=False):
         super().__init__()
         self.down_encoder = Encoder(
             in_channels, base_channels, num_conv, num_res=0, max_down_sampling_multiple=max_down_sampling_multiple,
             padding_mode=padding_mode, activation_type=activation_type,
-            down_conv_norm_type=norm_type, down_conv_kernel_size=4,
+            down_conv_norm_type=norm_type, down_conv_kernel_size=4, pre_activation=pre_activation,
         )
         sequence = list()
         sequence.append(nn.AdaptiveAvgPool2d(1))
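StyleEncoder only threads the flag through to its internal Encoder. A hedged sketch using the signature above; values are illustrative:

    style_encoder = StyleEncoder(in_channels=3, out_dim=8, num_conv=4,
                                 norm_type="NONE", pre_activation=True)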
@@ -47,19 +48,19 @@ class Generator(nn.Module):
                  num_mlp_base_feature=256, num_mlp_blocks=3,
                  max_down_sampling_multiple=2, num_content_down_sampling=2, num_style_down_sampling=2,
                  encoder_num_residual_blocks=4, decoder_num_residual_blocks=4,
-                 padding_mode='reflect', activation_type="ReLU"):
+                 padding_mode='reflect', activation_type="ReLU", pre_activation=False):
         super().__init__()
         self.content_encoder = Encoder(
             in_channels, base_channels, num_content_down_sampling, encoder_num_residual_blocks,
             max_down_sampling_multiple=num_content_down_sampling,
             padding_mode=padding_mode, activation_type=activation_type,
             down_conv_norm_type="IN", down_conv_kernel_size=4,
-            res_norm_type="IN"
+            res_norm_type="IN", pre_activation=pre_activation
         )
 
         self.style_encoder = StyleEncoder(in_channels, style_dim, num_style_down_sampling, base_channels,
                                           max_down_sampling_multiple, padding_mode, activation_type,
-                                          norm_type="NONE")
+                                          norm_type="NONE", pre_activation=pre_activation)
 
         content_channels = base_channels * (2 ** max_down_sampling_multiple)
 
@@ -70,7 +71,7 @@ class Generator(nn.Module):
         self.decoder = Decoder(in_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks,
                                activation_type=activation_type, padding_mode=padding_mode,
                                up_conv_kernel_size=5, up_conv_norm_type="LN",
-                               res_norm_type="AdaIN")
+                               res_norm_type="AdaIN", pre_activation=pre_activation)
 
     def encode(self, x):
         return self.content_encoder(x), self.style_encoder(x)
@@ -26,8 +26,8 @@ class CAMClassifier(nn.Module):
         self.avg_fc = nn.Linear(in_channels, 1, bias=False)
         self.max_pool = nn.AdaptiveMaxPool2d(1)
         self.max_fc = nn.Linear(in_channels, 1, bias=False)
-        self.fusion_conv = Conv2dBlock(in_channels * 2, in_channels, activation_type=activation_type,
-                                       norm_type="NONE", kernel_size=1, stride=1, bias=True)
+        self.fusion_conv = Conv2dBlock(in_channels * 2, in_channels, kernel_size=1, stride=1, bias=True,
+                                       activation_type=activation_type, norm_type="NONE")
 
     def forward(self, x):
         avg_logit = self.avg_fc(self.avg_pool(x).view(x.size(0), -1))
@@ -42,7 +42,7 @@ class CAMClassifier(nn.Module):
 @MODEL.register_module("UGATIT-Generator")
 class Generator(nn.Module):
     def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False,
-                 activation_type="ReLU", norm_type="IN", padding_mode='reflect'):
+                 activation_type="ReLU", norm_type="IN", padding_mode='reflect', pre_activation=False):
         super(Generator, self).__init__()
 
         self.light = light
@@ -50,7 +50,8 @@ class Generator(nn.Module):
         n_down_sampling = 2
         self.encoder = Encoder(in_channels, base_channels, n_down_sampling, num_blocks,
                                padding_mode=padding_mode, activation_type=activation_type,
-                               down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type)
+                               down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type,
+                               pre_activation=pre_activation)
         mult = 2 ** n_down_sampling
         self.cam = CAMClassifier(base_channels * mult, activation_type)
 
@@ -74,7 +75,7 @@ class Generator(nn.Module):
             base_channels * mult, out_channels, n_down_sampling, num_blocks,
             activation_type=activation_type, padding_mode=padding_mode,
             up_conv_kernel_size=3, up_conv_norm_type="ILN",
-            res_norm_type="AdaILN"
+            res_norm_type="AdaILN", pre_activation=pre_activation
         )
 
     def forward(self, x):
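For the UGATIT Generator the flag is likewise a constructor argument that is handed down to its Encoder and up-sampling decoder. A hedged usage sketch based on the signature in the hunks above; values are illustrative:

    # registered as "UGATIT-Generator" in the MODEL registry, per the decorator above
    generator = Generator(in_channels=3, out_channels=3, base_channels=64, num_blocks=6,
                          img_size=256, light=False, norm_type="IN",
                          pre_activation=True)  # pre-activation residual blocks throughout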