add patch d

budui 2020-10-13 10:31:17 +08:00
parent 611901cbdf
commit 0927fa3de5


@@ -1,5 +1,6 @@
 import torch.nn as nn
+from model import MODEL
 from model.base.module import Conv2dBlock, ResidualBlock
@@ -43,7 +44,7 @@ class Decoder(nn.Module):
     def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
                  activation_type="ReLU", padding_mode='reflect',
                  up_conv_kernel_size=5, up_conv_norm_type="LN",
-                 res_norm_type="AdaIN", pre_activation=False):
+                 res_norm_type="AdaIN", pre_activation=False, use_transpose_conv=False):
         super().__init__()
         self.residual_blocks = nn.ModuleList([
             ResidualBlock(
@@ -57,13 +58,23 @@ class Decoder(nn.Module):
         sequence = list()
         channels = in_channels
+        padding = (up_conv_kernel_size - 1) // 2
         for i in range(num_up_sampling):
-            sequence.append(nn.Sequential(
-                nn.Upsample(scale_factor=2),
-                Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
-                            padding=int(up_conv_kernel_size / 2), padding_mode=padding_mode,
-                            activation_type=activation_type, norm_type=up_conv_norm_type),
-            ))
+            if use_transpose_conv:
+                sequence.append(Conv2dBlock(
+                    channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
+                    padding=padding, output_padding=padding,
+                    padding_mode=padding_mode,
+                    activation_type=activation_type, norm_type=up_conv_norm_type,
+                    use_transpose_conv=True
+                ))
+            else:
+                sequence.append(nn.Sequential(
+                    nn.Upsample(scale_factor=2),
+                    Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
+                                padding=padding, padding_mode=padding_mode,
+                                activation_type=activation_type, norm_type=up_conv_norm_type),
+                ))
             channels = channels // 2
         sequence.append(Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3,
                                     padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"))
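
The new use_transpose_conv branch swaps the Upsample + Conv2dBlock pair for a single transposed convolution inside Conv2dBlock. Conv2dBlock's transpose handling is not part of this diff, so as a point of reference, here is a minimal sketch of how the two upsampling styles reach the same 2x output size with plain PyTorch layers (the kernel size, channel counts, and the output_padding=1 choice are illustrative, not taken from Conv2dBlock):

import torch
import torch.nn as nn

x = torch.randn(1, 64, 32, 32)
k = 5
padding = (k - 1) // 2

# learned upsampling: (H - 1) * 2 - 2 * padding + (k - 1) + output_padding + 1 = 2 * H
up_transpose = nn.ConvTranspose2d(64, 32, kernel_size=k, stride=2,
                                  padding=padding, output_padding=1)

# interpolation followed by a size-preserving convolution
up_interp = nn.Sequential(
    nn.Upsample(scale_factor=2),
    nn.Conv2d(64, 32, kernel_size=k, stride=1, padding=padding),
)

print(up_transpose(x).shape, up_interp(x).shape)  # both: torch.Size([1, 32, 64, 64])
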
@@ -74,3 +85,61 @@ class Decoder(nn.Module):
         for i, blk in enumerate(self.residual_blocks):
             x = blk(x)
         return self.up_sequence(x)
+
+
+@MODEL.register_module("CycleGAN-Generator")
+class Generator(nn.Module):
+    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, activation_type="ReLU",
+                 padding_mode='reflect', norm_type="IN", pre_activation=True, use_transpose_conv=True):
+        super().__init__()
+        self.encoder = Encoder(in_channels, base_channels, num_conv=2, num_res=num_blocks,
+                               padding_mode=padding_mode, activation_type=activation_type,
+                               down_conv_norm_type=norm_type, res_norm_type=norm_type, pre_activation=pre_activation)
+        self.decoder = Decoder(self.encoder.out_channels, out_channels, num_up_sampling=2, num_residual_blocks=0,
+                               padding_mode=padding_mode, activation_type=activation_type,
+                               up_conv_kernel_size=3, up_conv_norm_type=norm_type,
+                               pre_activation=pre_activation, use_transpose_conv=use_transpose_conv)
+
+    def forward(self, x):
+        return self.decoder(self.encoder(x))
+
+
+@MODEL.register_module("PatchDiscriminator")
+class PatchDiscriminator(nn.Module):
+    def __init__(self, in_channels, base_channels=64, num_conv=4, need_intermediate_feature=False,
+                 norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
+        super().__init__()
+        self.need_intermediate_feature = need_intermediate_feature
+        kernel_size = 4
+        padding = (kernel_size - 1) // 2
+        sequence = [Conv2dBlock(
+            in_channels, base_channels, kernel_size=kernel_size, stride=2, padding=padding, padding_mode=padding_mode,
+            activation_type=activation_type, norm_type=norm_type
+        )]
+        multiple_now = 1
+        for i in range(1, num_conv + 1):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 2 ** 3)
+            stride = 1 if i == num_conv - 1 else 2
+            sequence.append(Conv2dBlock(
+                multiple_prev * base_channels, multiple_now * base_channels,
+                kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=padding_mode,
+                activation_type=activation_type, norm_type=norm_type
+            ))
+        sequence.append(nn.Conv2d(
+            base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding, padding_mode=padding_mode))
+        if self.need_intermediate_feature:
+            self.sequence = nn.ModuleList(sequence)
+        else:
+            self.sequence = nn.Sequential(*sequence)
+
+    def forward(self, x):
+        if self.need_intermediate_feature:
+            intermediate_feature = []
+            for layer in self.sequence:
+                x = layer(x)
+                intermediate_feature.append(x)
+            return tuple(intermediate_feature)
+        else:
+            return self.sequence(x)
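
For context, a minimal smoke-test sketch of the two newly registered modules, assuming the Encoder, Conv2dBlock, and ResidualBlock imports behave as used above; the 256x256 input size and the channel counts are illustrative, not taken from this commit:

import torch

net_g = Generator(in_channels=3, out_channels=3, base_channels=64, num_blocks=9)
net_d = PatchDiscriminator(in_channels=3, base_channels=64, num_conv=4)

x = torch.randn(1, 3, 256, 256)
fake = net_g(x)      # expected to keep the input resolution: (1, 3, 256, 256)
score = net_d(fake)  # patch-level real/fake map; spatial size depends on num_conv
print(fake.shape, score.shape)
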