import torch.nn as nn

from model import MODEL
from model.base.module import Conv2dBlock, ResidualBlock


class Encoder(nn.Module):
    """Downsampling encoder: a 7x7 stem, `num_conv` stride-2 convs, then `num_res` residual blocks."""

    def __init__(self, in_channels, base_channels, num_conv, num_res,
                 max_down_sampling_multiple=2, padding_mode='reflect',
                 activation_type="ReLU", down_conv_norm_type="IN",
                 down_conv_kernel_size=3, res_norm_type="IN",
                 pre_activation=False):
        super().__init__()
        # 7x7 stem conv keeps the spatial size (stride 1, padding 3).
        sequence = [Conv2dBlock(
            in_channels, base_channels, kernel_size=7, stride=1, padding=3,
            padding_mode=padding_mode, activation_type=activation_type,
            norm_type=down_conv_norm_type
        )]
        # Each stride-2 conv halves the resolution and doubles the channel
        # multiple, capped at 2 ** max_down_sampling_multiple.
        multiple_now = 1
        for i in range(1, num_conv + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** max_down_sampling_multiple)
            sequence.append(Conv2dBlock(
                multiple_prev * base_channels, multiple_now * base_channels,
                kernel_size=down_conv_kernel_size, stride=2, padding=1,
                padding_mode="zeros", activation_type=activation_type,
                norm_type=down_conv_norm_type
            ))
        self.out_channels = multiple_now * base_channels
        sequence += [
            ResidualBlock(
                self.out_channels, padding_mode=padding_mode,
                activation_type=activation_type, norm_type=res_norm_type,
                pre_activation=pre_activation
            ) for _ in range(num_res)
        ]
        self.sequence = nn.Sequential(*sequence)

    def forward(self, x):
        return self.sequence(x)


class Decoder(nn.Module):
    """Upsampling decoder: optional residual blocks, `num_up_sampling` x2 upsampling stages, then a 7x7 Tanh head."""

    def __init__(self, in_channels, out_channels, num_up_sampling,
                 num_residual_blocks, activation_type="ReLU",
                 padding_mode='reflect', up_conv_kernel_size=5,
                 up_conv_norm_type="LN", res_norm_type="AdaIN",
                 pre_activation=False, use_transpose_conv=False):
        super().__init__()
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(
                in_channels, padding_mode=padding_mode,
                activation_type=activation_type, norm_type=res_norm_type,
                pre_activation=pre_activation
            ) for _ in range(num_residual_blocks)
        ])
        sequence = []
        channels = in_channels
        padding = (up_conv_kernel_size - 1) // 2
        for _ in range(num_up_sampling):
            if use_transpose_conv:
                # With stride 2 and padding (k - 1) // 2, output_padding=1
                # makes the transposed conv exactly double the spatial size
                # for any odd kernel.
                sequence.append(Conv2dBlock(
                    channels, channels // 2, kernel_size=up_conv_kernel_size,
                    stride=2, padding=padding, output_padding=1,
                    padding_mode=padding_mode, activation_type=activation_type,
                    norm_type=up_conv_norm_type, use_transpose_conv=True
                ))
            else:
                # Nearest-neighbour upsampling followed by a stride-1 conv
                # avoids the checkerboard artifacts of transposed convs.
                sequence.append(nn.Sequential(
                    nn.Upsample(scale_factor=2),
                    Conv2dBlock(channels, channels // 2,
                                kernel_size=up_conv_kernel_size, stride=1,
                                padding=padding, padding_mode=padding_mode,
                                activation_type=activation_type,
                                norm_type=up_conv_norm_type),
                ))
            channels = channels // 2
        # 7x7 output head, Tanh-activated and un-normalized, maps to image range [-1, 1].
        sequence.append(Conv2dBlock(
            channels, out_channels, kernel_size=7, stride=1, padding=3,
            padding_mode=padding_mode, activation_type="Tanh",
            norm_type="NONE"))
        self.up_sequence = nn.Sequential(*sequence)

    def forward(self, x):
        for blk in self.residual_blocks:
            x = blk(x)
        return self.up_sequence(x)


@MODEL.register_module("CycleGAN-Generator")
class Generator(nn.Module):
    """ResNet-style CycleGAN generator: 2x downsampling encoder, `num_blocks` residual blocks, 2x upsampling decoder."""

    def __init__(self, in_channels, out_channels, base_channels=64,
                 num_blocks=9, activation_type="ReLU",
                 padding_mode='reflect', norm_type="IN",
                 pre_activation=False, use_transpose_conv=True):
        super().__init__()
        self.encoder = Encoder(in_channels, base_channels, num_conv=2,
                               num_res=num_blocks, padding_mode=padding_mode,
                               activation_type=activation_type,
                               down_conv_norm_type=norm_type,
                               res_norm_type=norm_type,
                               pre_activation=pre_activation)
        self.decoder = Decoder(self.encoder.out_channels, out_channels,
                               num_up_sampling=2, num_residual_blocks=0,
                               padding_mode=padding_mode,
                               activation_type=activation_type,
                               up_conv_kernel_size=3,
                               up_conv_norm_type=norm_type,
                               pre_activation=pre_activation,
                               use_transpose_conv=use_transpose_conv)

    def forward(self, x):
        return self.decoder(self.encoder(x))
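# Receptive-field sketch for the default PatchDiscriminator below
# (num_conv=4, kernel_size=4, strides 2/2/2/1 plus the final stride-1 conv),
# assuming Conv2dBlock wraps a plain nn.Conv2d. Working backwards with
# rf = rf * stride + (kernel - stride) per layer:
#   final conv (k4, s1): 1 -> 4
#   block 3    (k4, s1): 4 -> 7
#   block 2    (k4, s2): 7 -> 16
#   block 1    (k4, s2): 16 -> 34
#   stem       (k4, s2): 34 -> 70
# i.e. each output logit scores a 70x70 patch of the input image, the
# classic 70x70 PatchGAN configuration.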
@MODEL.register_module("PatchDiscriminator")
class PatchDiscriminator(nn.Module):
    """PatchGAN discriminator: a stack of 4x4 convs ending in a 1-channel patch score map."""

    def __init__(self, in_channels, base_channels=64, num_conv=4,
                 need_intermediate_feature=False, norm_type="IN",
                 padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        self.need_intermediate_feature = need_intermediate_feature
        kernel_size = 4
        padding = (kernel_size - 1) // 2
        sequence = [Conv2dBlock(
            in_channels, base_channels, kernel_size=kernel_size, stride=2,
            padding=padding, padding_mode=padding_mode,
            activation_type=activation_type, norm_type=norm_type
        )]
        # Double the channel multiple per block (capped at 8x); the last
        # block uses stride 1 to keep a larger output map.
        multiple_now = 1
        for i in range(1, num_conv):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 3)
            stride = 1 if i == num_conv - 1 else 2
            sequence.append(Conv2dBlock(
                multiple_prev * base_channels, multiple_now * base_channels,
                kernel_size=kernel_size, stride=stride, padding=padding,
                padding_mode=padding_mode, activation_type=activation_type,
                norm_type=norm_type
            ))
        # Final 1-channel conv produces the per-patch real/fake logits.
        sequence.append(nn.Conv2d(
            base_channels * multiple_now, 1, kernel_size, stride=1,
            padding=padding, padding_mode=padding_mode))
        if self.need_intermediate_feature:
            # Keep layers separate so forward() can return every feature map.
            self.sequence = nn.ModuleList(sequence)
        else:
            self.sequence = nn.Sequential(*sequence)

    def forward(self, x):
        if self.need_intermediate_feature:
            intermediate_feature = []
            for layer in self.sequence:
                x = layer(x)
                intermediate_feature.append(x)
            return tuple(intermediate_feature)
        return self.sequence(x)


if __name__ == '__main__':
    g = Generator(in_channels=3, out_channels=3)
    print(g)
    pd = PatchDiscriminator(in_channels=3, base_channels=64, num_conv=4)
    print(pd)
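    # Shape smoke test -- a minimal sketch, assuming Conv2dBlock and
    # ResidualBlock follow standard conv-block spatial arithmetic (the stem
    # and residual blocks preserve size, stride-2 blocks halve it, and the
    # upsampling stages double it, as configured above).
    import torch

    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        y = g(x)
        score = pd(x)
    print(y.shape)      # expected: torch.Size([1, 3, 256, 256])
    print(score.shape)  # expected: torch.Size([1, 1, 30, 30]) patch map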