import torch.nn as nn

from model import MODEL
from model.base.module import Conv2dBlock, ResidualBlock


class Encoder(nn.Module):
    def __init__(self, in_channels, base_channels, num_conv, num_res, max_down_sampling_multiple=2,
                 padding_mode='reflect', activation_type="ReLU",
                 down_conv_norm_type="IN", down_conv_kernel_size=3,
                 res_norm_type="IN", pre_activation=False):
        super().__init__()
        # Stem: a stride-1 7x7 convolution lifts the input to base_channels
        # without changing the spatial size (padding = (7 - 1) // 2 = 3).
        sequence = [Conv2dBlock(
            in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
            activation_type=activation_type, norm_type=down_conv_norm_type
        )]
        # Each stride-2 convolution halves the spatial size and doubles the
        # channel count, capped at 2 ** max_down_sampling_multiple * base_channels.
        multiple_now = 1
        for i in range(1, num_conv + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** max_down_sampling_multiple)
            sequence.append(Conv2dBlock(
                multiple_prev * base_channels, multiple_now * base_channels,
                kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode="zeros",
                activation_type=activation_type, norm_type=down_conv_norm_type
            ))
        self.out_channels = multiple_now * base_channels
        # Bottleneck: num_res residual blocks at the lowest resolution.
        sequence += [
            ResidualBlock(
                self.out_channels,
                padding_mode=padding_mode,
                activation_type=activation_type,
                norm_type=res_norm_type,
                pre_activation=pre_activation
            ) for _ in range(num_res)
        ]
        self.sequence = nn.Sequential(*sequence)

    def forward(self, x):
        return self.sequence(x)
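
# Usage sketch (shapes are assumptions for a 256x256 input, following the
# stride-2 arithmetic above):
#   enc = Encoder(in_channels=3, base_channels=64, num_conv=2, num_res=9)
#   feat = enc(torch.randn(1, 3, 256, 256))  # -> (1, 256, 64, 64)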


class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
                 activation_type="ReLU", padding_mode='reflect',
                 up_conv_kernel_size=5, up_conv_norm_type="LN",
                 res_norm_type="AdaIN", pre_activation=False, use_transpose_conv=False):
        super().__init__()
        # Residual blocks are kept in a ModuleList and applied one by one in
        # forward, so norm layers such as AdaIN stay individually addressable.
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(
                in_channels,
                padding_mode=padding_mode,
                activation_type=activation_type,
                norm_type=res_norm_type,
                pre_activation=pre_activation
            ) for _ in range(num_residual_blocks)
        ])

        sequence = []
        channels = in_channels
        padding = (up_conv_kernel_size - 1) // 2
        for i in range(num_up_sampling):
            if use_transpose_conv:
                # A stride-2 transposed convolution with this padding doubles the
                # spatial size exactly when output_padding = stride - 1 = 1;
                # PyTorch also requires output_padding to be below the stride.
                sequence.append(Conv2dBlock(
                    channels, channels // 2, kernel_size=up_conv_kernel_size, stride=2,
                    padding=padding, output_padding=1,
                    padding_mode=padding_mode,
                    activation_type=activation_type, norm_type=up_conv_norm_type,
                    use_transpose_conv=True
                ))
            else:
                # Nearest-neighbor up-sampling plus a stride-1 convolution avoids
                # the checkerboard artifacts of transposed convolutions.
                sequence.append(nn.Sequential(
                    nn.Upsample(scale_factor=2),
                    Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
                                padding=padding, padding_mode=padding_mode,
                                activation_type=activation_type, norm_type=up_conv_norm_type),
                ))
            channels = channels // 2
        # Final projection to the output image; Tanh bounds values to [-1, 1].
        sequence.append(Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3,
                                    padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"))

        self.up_sequence = nn.Sequential(*sequence)

    def forward(self, x):
        for blk in self.residual_blocks:
            x = blk(x)
        return self.up_sequence(x)
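
# Usage sketch (shapes assume the Encoder example above feeds this Decoder):
#   dec = Decoder(in_channels=256, out_channels=3, num_up_sampling=2,
#                 num_residual_blocks=0, up_conv_kernel_size=3)
#   img = dec(feat)  # (1, 256, 64, 64) -> (1, 3, 256, 256), values in [-1, 1]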


@MODEL.register_module("CycleGAN-Generator")
class Generator(nn.Module):
    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, activation_type="ReLU",
                 padding_mode='reflect', norm_type="IN", pre_activation=False, use_transpose_conv=True):
        super().__init__()
        # All num_blocks residual blocks live in the encoder bottleneck, so the
        # decoder is built with num_residual_blocks=0 and only up-samples.
        self.encoder = Encoder(in_channels, base_channels, num_conv=2, num_res=num_blocks,
                               padding_mode=padding_mode, activation_type=activation_type,
                               down_conv_norm_type=norm_type, res_norm_type=norm_type,
                               pre_activation=pre_activation)
        self.decoder = Decoder(self.encoder.out_channels, out_channels, num_up_sampling=2,
                               num_residual_blocks=0,
                               padding_mode=padding_mode, activation_type=activation_type,
                               up_conv_kernel_size=3, up_conv_norm_type=norm_type,
                               pre_activation=pre_activation, use_transpose_conv=use_transpose_conv)

    def forward(self, x):
        return self.decoder(self.encoder(x))
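
# Shape sketch (derived from the defaults above): a (N, 3, H, W) input is
# encoded to (N, 4 * base_channels, H / 4, W / 4), refined by num_blocks
# residual blocks, then decoded back to (N, 3, H, W).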


@MODEL.register_module("PatchDiscriminator")
class PatchDiscriminator(nn.Module):
    def __init__(self, in_channels, base_channels=64, num_conv=4, need_intermediate_feature=False,
                 norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        self.need_intermediate_feature = need_intermediate_feature
        kernel_size = 4
        padding = (kernel_size - 1) // 2
        sequence = [Conv2dBlock(
            in_channels, base_channels, kernel_size=kernel_size, stride=2, padding=padding, padding_mode=padding_mode,
            activation_type=activation_type, norm_type=norm_type
        )]

        # Channels double each block (capped at 8 * base_channels); the last
        # conv block uses stride 1, as in the standard PatchGAN discriminator.
        multiple_now = 1
        for i in range(1, num_conv):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 3)
            stride = 1 if i == num_conv - 1 else 2
            sequence.append(Conv2dBlock(
                multiple_prev * base_channels, multiple_now * base_channels,
                kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=padding_mode,
                activation_type=activation_type, norm_type=norm_type
            ))
        # A final 1-channel convolution yields the per-patch real/fake logits.
        sequence.append(nn.Conv2d(
            base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding, padding_mode=padding_mode))
        if self.need_intermediate_feature:
            # Keep layers individually addressable so forward can also return
            # every intermediate feature map (e.g. for feature-matching losses).
            self.sequence = nn.ModuleList(sequence)
        else:
            self.sequence = nn.Sequential(*sequence)

    def forward(self, x):
        if self.need_intermediate_feature:
            intermediate_feature = []
            for layer in self.sequence:
                x = layer(x)
                intermediate_feature.append(x)
            return tuple(intermediate_feature)
        else:
            return self.sequence(x)
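
# Output sketch (assumes the defaults num_conv=4, kernel_size=4): a 256x256
# input yields a (N, 1, 30, 30) logit map, i.e. the classic 70x70 PatchGAN.
#   d = PatchDiscriminator(in_channels=3)
#   logits = d(torch.randn(1, 3, 256, 256))  # -> (1, 1, 30, 30)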


if __name__ == '__main__':
    g = Generator(in_channels=3, out_channels=3)
    print(g)
    pd = PatchDiscriminator(in_channels=3, base_channels=64, num_conv=4)
    print(pd)
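
    # Smoke test (a sketch: assumes Conv2dBlock/ResidualBlock implement the
    # usual conv arithmetic, so the shapes below are expectations, not guarantees).
    import torch
    x = torch.randn(1, 3, 256, 256)
    print(g(x).shape)   # expected: torch.Size([1, 3, 256, 256])
    print(pd(x).shape)  # expected: torch.Size([1, 1, 30, 30])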