From 0927fa3de5909c224537c066a7145c2a5b75ebe9 Mon Sep 17 00:00:00 2001
From: budui
Date: Tue, 13 Oct 2020 10:31:17 +0800
Subject: [PATCH] add patch d

---
 model/image_translation/CycleGAN.py | 83 ++++++++++++++++++++++++++---
 1 file changed, 76 insertions(+), 7 deletions(-)

diff --git a/model/image_translation/CycleGAN.py b/model/image_translation/CycleGAN.py
index dd487cb..bf0484d 100644
--- a/model/image_translation/CycleGAN.py
+++ b/model/image_translation/CycleGAN.py
@@ -1,5 +1,6 @@
 import torch.nn as nn
 
+from model import MODEL
 from model.base.module import Conv2dBlock, ResidualBlock
 
 
@@ -43,7 +44,7 @@ class Decoder(nn.Module):
 
     def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks, activation_type="ReLU",
                  padding_mode='reflect', up_conv_kernel_size=5, up_conv_norm_type="LN",
-                 res_norm_type="AdaIN", pre_activation=False):
+                 res_norm_type="AdaIN", pre_activation=False, use_transpose_conv=False):
         super().__init__()
         self.residual_blocks = nn.ModuleList([
             ResidualBlock(
@@ -57,13 +58,23 @@ class Decoder(nn.Module):
 
         sequence = list()
         channels = in_channels
+        padding = (up_conv_kernel_size - 1) // 2
         for i in range(num_up_sampling):
-            sequence.append(nn.Sequential(
-                nn.Upsample(scale_factor=2),
-                Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
-                            padding=int(up_conv_kernel_size / 2), padding_mode=padding_mode,
-                            activation_type=activation_type, norm_type=up_conv_norm_type),
-            ))
+            if use_transpose_conv:
+                sequence.append(Conv2dBlock(
+                    channels, channels // 2, kernel_size=up_conv_kernel_size, stride=2,
+                    padding=padding, output_padding=padding,
+                    padding_mode=padding_mode,
+                    activation_type=activation_type, norm_type=up_conv_norm_type,
+                    use_transpose_conv=True
+                ))
+            else:
+                sequence.append(nn.Sequential(
+                    nn.Upsample(scale_factor=2),
+                    Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
+                                padding=padding, padding_mode=padding_mode,
+                                activation_type=activation_type, norm_type=up_conv_norm_type),
+                ))
             channels = channels // 2
         sequence.append(Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3,
                                     padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"))
@@ -74,3 +85,61 @@ class Decoder(nn.Module):
         for i, blk in enumerate(self.residual_blocks):
             x = blk(x)
         return self.up_sequence(x)
+
+
+@MODEL.register_module("CycleGAN-Generator")
+class Generator(nn.Module):
+    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, activation_type="ReLU",
+                 padding_mode='reflect', norm_type="IN", pre_activation=True, use_transpose_conv=True):
+        super().__init__()
+        self.encoder = Encoder(in_channels, base_channels, num_conv=2, num_res=num_blocks,
+                               padding_mode=padding_mode, activation_type=activation_type,
+                               down_conv_norm_type=norm_type, res_norm_type=norm_type, pre_activation=pre_activation)
+        self.decoder = Decoder(self.encoder.out_channels, out_channels, num_up_sampling=2, num_residual_blocks=0,
+                               padding_mode=padding_mode, activation_type=activation_type,
+                               up_conv_kernel_size=3, up_conv_norm_type=norm_type,
+                               pre_activation=pre_activation, use_transpose_conv=use_transpose_conv)
+
+    def forward(self, x):
+        return self.decoder(self.encoder(x))
+
+
+@MODEL.register_module("PatchDiscriminator")
+class PatchDiscriminator(nn.Module):
+    def __init__(self, in_channels, base_channels=64, num_conv=4, need_intermediate_feature=False,
+                 norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
+        super().__init__()
+        self.need_intermediate_feature = need_intermediate_feature
+        kernel_size = 4
+        padding = (kernel_size - 1) // 2
+        sequence = [Conv2dBlock(
+            in_channels, base_channels, kernel_size=kernel_size, stride=2, padding=padding, padding_mode=padding_mode,
+            activation_type=activation_type, norm_type=norm_type
+        )]
+
+        multiple_now = 1
+        for i in range(1, num_conv + 1):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 2 ** 3)
+            stride = 1 if i == num_conv else 2
+            sequence.append(Conv2dBlock(
+                multiple_prev * base_channels, multiple_now * base_channels,
+                kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=padding_mode,
+                activation_type=activation_type, norm_type=norm_type
+            ))
+        sequence.append(nn.Conv2d(
+            base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding, padding_mode=padding_mode))
+        if self.need_intermediate_feature:
+            self.sequence = nn.ModuleList(sequence)
+        else:
+            self.sequence = nn.Sequential(*sequence)
+
+    def forward(self, x):
+        if self.need_intermediate_feature:
+            intermediate_feature = []
+            for layer in self.sequence:
+                x = layer(x)
+                intermediate_feature.append(x)
+            return tuple(intermediate_feature)
+        else:
+            return self.sequence(x)
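
Review note (not part of the patch): a quick smoke test for the two new modules.
This is a minimal sketch, assuming the repo is importable from its root and that
the existing Encoder, Conv2dBlock, and ResidualBlock behave as standard
convolutional blocks; the patch-map size printed below depends on their
internals and on num_conv, so the shapes in the comments are illustrative,
not guaranteed.

    import torch

    from model.image_translation.CycleGAN import Generator, PatchDiscriminator

    x = torch.randn(2, 3, 256, 256)  # batch of two 3-channel 256x256 images

    # Generator: 2x-downsampling encoder, 9 residual blocks, 2x-upsampling
    # decoder, so the output should match the input's spatial size.
    g = Generator(in_channels=3, out_channels=3)
    y = g(x)
    print(y.shape)  # expected: torch.Size([2, 3, 256, 256])

    # PatchGAN discriminator: one logit per receptive-field patch.
    d = PatchDiscriminator(in_channels=3)
    print(d(y).shape)  # e.g. torch.Size([2, 1, h, w]) with h, w << 256

    # With need_intermediate_feature=True, forward returns a tuple of
    # per-layer activations (useful for feature-matching losses).
    d_feat = PatchDiscriminator(in_channels=3, need_intermediate_feature=True)
    feats = d_feat(y)
    print([f.shape for f in feats])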