import torch.nn as nn

from model.base.module import Conv2dBlock, ResidualBlock


class Encoder(nn.Module):
    """Down-sampling convolutional encoder.

    Structure: a 7x7 stem conv (stride 1), then ``num_conv`` strided convs
    that each halve the spatial size while (up to a cap) doubling channels,
    followed by ``num_res`` residual blocks at the bottleneck resolution.

    Args:
        in_channels: number of channels in the input image/tensor.
        base_channels: channel width after the stem conv.
        num_conv: number of stride-2 down-sampling conv blocks.
        num_res: number of residual blocks at the bottleneck.
        max_down_sampling_multiple: channel growth is capped at
            ``2 ** max_down_sampling_multiple`` times ``base_channels``.
        padding_mode: padding mode forwarded to every conv block.
        activation_type: activation forwarded to every conv/residual block.
        down_conv_norm_type: norm type for the stem and down-sampling convs.
        down_conv_kernel_size: kernel size of the down-sampling convs
            (expected odd so that ``kernel_size // 2`` padding halves exactly).
        res_norm_type: norm type for the residual blocks.
        pre_activation: whether residual blocks use pre-activation ordering.
    """

    def __init__(self, in_channels, base_channels, num_conv, num_res,
                 max_down_sampling_multiple=2, padding_mode='reflect',
                 activation_type="ReLU", down_conv_norm_type="IN",
                 down_conv_kernel_size=3, res_norm_type="IN",
                 pre_activation=False):
        super().__init__()
        # 7x7 stem keeps the spatial size (stride 1, padding 3).
        sequence = [Conv2dBlock(
            in_channels, base_channels, kernel_size=7, stride=1, padding=3,
            padding_mode=padding_mode, activation_type=activation_type,
            norm_type=down_conv_norm_type,
        )]
        multiple_now = 1
        for i in range(1, num_conv + 1):
            multiple_prev = multiple_now
            # Channel multiplier doubles each stage, capped so very deep
            # encoders do not blow up the channel count.
            multiple_now = min(2 ** i, 2 ** max_down_sampling_multiple)
            sequence.append(Conv2dBlock(
                multiple_prev * base_channels, multiple_now * base_channels,
                kernel_size=down_conv_kernel_size, stride=2,
                # FIX: padding was hard-coded to 1, which only halves the
                # spatial size correctly for kernel_size=3.  kernel_size // 2
                # generalizes to any odd kernel and equals 1 at the default,
                # so default behavior is unchanged.
                padding=down_conv_kernel_size // 2,
                padding_mode=padding_mode, activation_type=activation_type,
                norm_type=down_conv_norm_type,
            ))
        # Channel count at the bottleneck; exposed for downstream modules
        # (e.g. a matching Decoder).
        self.out_channels = multiple_now * base_channels
        sequence += [
            ResidualBlock(
                self.out_channels, padding_mode=padding_mode,
                activation_type=activation_type, norm_type=res_norm_type,
                pre_activation=pre_activation,
            )
            for _ in range(num_res)
        ]
        self.sequence = nn.Sequential(*sequence)

    def forward(self, x):
        """Encode ``x`` through stem, down-sampling convs and residual blocks."""
        return self.sequence(x)


class Decoder(nn.Module):
    """Up-sampling convolutional decoder (mirror of :class:`Encoder`).

    Structure: ``num_residual_blocks`` residual blocks (kept in a
    ``ModuleList`` and applied one by one in :meth:`forward` — presumably so
    AdaIN-style blocks can be targeted individually for style injection;
    TODO confirm against callers), then ``num_up_sampling`` stages of
    nearest-neighbor 2x upsample + conv that each halve the channel count,
    and a final 7x7 conv with Tanh producing ``out_channels``.

    Args:
        in_channels: channel count of the bottleneck input.
        out_channels: channel count of the produced image/tensor.
        num_up_sampling: number of 2x upsample + conv stages.
        num_residual_blocks: number of residual blocks before upsampling.
        activation_type: activation forwarded to conv/residual blocks.
        padding_mode: padding mode forwarded to every conv block.
        up_conv_kernel_size: kernel size of the upsampling-stage convs.
        up_conv_norm_type: norm type for the upsampling-stage convs.
        res_norm_type: norm type for the residual blocks (default "AdaIN").
        pre_activation: whether residual blocks use pre-activation ordering.
    """

    def __init__(self, in_channels, out_channels, num_up_sampling,
                 num_residual_blocks, activation_type="ReLU",
                 padding_mode='reflect', up_conv_kernel_size=5,
                 up_conv_norm_type="LN", res_norm_type="AdaIN",
                 pre_activation=False):
        super().__init__()
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(
                in_channels, padding_mode=padding_mode,
                activation_type=activation_type, norm_type=res_norm_type,
                pre_activation=pre_activation,
            )
            for _ in range(num_residual_blocks)
        ])
        sequence = []
        channels = in_channels
        for _ in range(num_up_sampling):
            sequence.append(nn.Sequential(
                # Default nn.Upsample mode is 'nearest'.
                nn.Upsample(scale_factor=2),
                Conv2dBlock(
                    channels, channels // 2,
                    kernel_size=up_conv_kernel_size, stride=1,
                    # Floor division replaces int(k / 2): identical for the
                    # positive int kernel sizes used here, and idiomatic.
                    padding=up_conv_kernel_size // 2,
                    padding_mode=padding_mode,
                    activation_type=activation_type,
                    norm_type=up_conv_norm_type,
                ),
            ))
            channels = channels // 2
        # Output head: 7x7 conv, no norm, Tanh to map into [-1, 1].
        sequence.append(Conv2dBlock(
            channels, out_channels, kernel_size=7, stride=1, padding=3,
            padding_mode=padding_mode, activation_type="Tanh",
            norm_type="NONE",
        ))
        self.up_sequence = nn.Sequential(*sequence)

    def forward(self, x):
        """Apply residual blocks, then the up-sampling stack."""
        # Index was unused in the original enumerate() loop; plain iteration.
        for block in self.residual_blocks:
            x = block(x)
        return self.up_sequence(x)