raycv/model/image_translation/CycleGAN.py

77 lines
3.1 KiB
Python

import torch.nn as nn
from model.base.module import Conv2dBlock, ResidualBlock
class Encoder(nn.Module):
    """Convolutional encoder: a 7x7 stem, stride-2 convolutions, then residual blocks.

    The layers are assembled into a single ``nn.Sequential`` stored as
    ``self.sequence``:

    1. One ``Conv2dBlock`` mapping ``in_channels`` to ``base_channels``
       (kernel 7, stride 1, padding 3).
    2. ``num_conv`` ``Conv2dBlock`` layers with stride 2 and padding 1. The
       i-th of them (1-based) outputs
       ``min(2 ** i, 2 ** max_down_sampling_multiple) * base_channels``
       channels, so the width doubles per layer until it reaches the cap.
    3. ``num_res`` ``ResidualBlock`` layers operating at ``self.out_channels``.

    Args:
        in_channels: Channel count of the input fed to the first block.
        base_channels: Output channel count of the first block; later widths
            are multiples of it.
        num_conv: Number of stride-2 ``Conv2dBlock`` layers.
        num_res: Number of trailing ``ResidualBlock`` layers.
        max_down_sampling_multiple: Exponent capping the width multiplier at
            ``2 ** max_down_sampling_multiple``.
        padding_mode: Forwarded to every ``Conv2dBlock`` and ``ResidualBlock``.
        activation_type: Forwarded to every ``Conv2dBlock`` and ``ResidualBlock``.
        down_conv_norm_type: ``norm_type`` for the stem and the stride-2 blocks.
        down_conv_kernel_size: Kernel size of the stride-2 blocks. Their
            padding stays 1 whatever value is given here.
        res_norm_type: ``norm_type`` for the residual blocks.
        pre_activation: Forwarded to every ``ResidualBlock``.

    Attributes:
        out_channels: Channel count produced by the last stride-2 block (or by
            the stem when ``num_conv`` is 0); the residual blocks are built
            with this width.

    NOTE(review): how the ``norm_type`` / ``activation_type`` strings are
    interpreted is decided inside ``Conv2dBlock`` / ``ResidualBlock``, which
    are defined elsewhere -- confirm there.
    """

    def __init__(self, in_channels, base_channels, num_conv, num_res, max_down_sampling_multiple=2,
                 padding_mode='reflect', activation_type="ReLU",
                 down_conv_norm_type="IN", down_conv_kernel_size=3,
                 res_norm_type="IN", pre_activation=False):
        super().__init__()

        # Keyword arguments shared by the stem and by every stride-2 block.
        conv_options = dict(
            padding_mode=padding_mode,
            activation_type=activation_type,
            norm_type=down_conv_norm_type,
        )

        layers = [
            Conv2dBlock(in_channels, base_channels, kernel_size=7, stride=1, padding=3, **conv_options)
        ]

        # width_factor is the channel multiplier (relative to base_channels)
        # of the most recently added layer.
        width_factor = 1
        for level in range(1, num_conv + 1):
            next_factor = min(2 ** level, 2 ** max_down_sampling_multiple)
            layers.append(Conv2dBlock(
                width_factor * base_channels, next_factor * base_channels,
                kernel_size=down_conv_kernel_size, stride=2, padding=1, **conv_options
            ))
            width_factor = next_factor

        self.out_channels = width_factor * base_channels

        layers.extend(
            ResidualBlock(
                self.out_channels,
                padding_mode=padding_mode,
                activation_type=activation_type,
                norm_type=res_norm_type,
                pre_activation=pre_activation,
            )
            for _ in range(num_res)
        )

        self.sequence = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the whole layer stack and return the result."""
        encoded = self.sequence(x)
        return encoded
class Decoder(nn.Module):
    """Decoder: residual blocks followed by up-sampling stages and a Tanh output conv.

    Two sub-modules are built:

    * ``self.residual_blocks`` -- an ``nn.ModuleList`` of
      ``num_residual_blocks`` ``ResidualBlock`` layers at ``in_channels``;
      ``forward`` applies them one after another.
    * ``self.up_sequence`` -- an ``nn.Sequential`` holding ``num_up_sampling``
      stages, each ``nn.Sequential(nn.Upsample(scale_factor=2), Conv2dBlock)``
      where the ``Conv2dBlock`` halves the channel count (integer division),
      followed by a final ``Conv2dBlock`` (kernel 7, stride 1, padding 3,
      ``activation_type="Tanh"``, ``norm_type="NONE"``) that maps to
      ``out_channels``.

    Args:
        in_channels: Channel count of the input and of every residual block.
        out_channels: Channel count produced by the final ``Conv2dBlock``.
        num_up_sampling: Number of upsample-plus-conv stages.
        num_residual_blocks: Number of ``ResidualBlock`` layers.
        activation_type: Forwarded to the residual blocks and the up-sampling
            conv blocks (the final block always uses ``"Tanh"``).
        padding_mode: Forwarded to every ``Conv2dBlock`` and ``ResidualBlock``.
        up_conv_kernel_size: Kernel size of the up-sampling conv blocks; their
            padding is ``int(up_conv_kernel_size / 2)``.
        up_conv_norm_type: ``norm_type`` for the up-sampling conv blocks.
        res_norm_type: ``norm_type`` for the residual blocks.
        pre_activation: Forwarded to every ``ResidualBlock``.

    NOTE(review): the residual blocks are kept in a ``ModuleList`` rather than
    folded into the ``Sequential`` -- presumably so callers can reach the
    individual blocks (e.g. for the default ``"AdaIN"`` norm); confirm against
    the code that uses this class.
    """

    def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
                 activation_type="ReLU", padding_mode='reflect',
                 up_conv_kernel_size=5, up_conv_norm_type="LN",
                 res_norm_type="AdaIN", pre_activation=False):
        super().__init__()

        self.residual_blocks = nn.ModuleList()
        for _ in range(num_residual_blocks):
            self.residual_blocks.append(ResidualBlock(
                in_channels,
                padding_mode=padding_mode,
                activation_type=activation_type,
                norm_type=res_norm_type,
                pre_activation=pre_activation,
            ))

        stages = []
        width = in_channels
        for _ in range(num_up_sampling):
            halved = width // 2
            upsample = nn.Upsample(scale_factor=2)
            refine = Conv2dBlock(
                width, halved,
                kernel_size=up_conv_kernel_size, stride=1,
                padding=int(up_conv_kernel_size / 2), padding_mode=padding_mode,
                activation_type=activation_type, norm_type=up_conv_norm_type,
            )
            # Each stage stays its own nested Sequential so parameter names
            # keep the ``up_sequence.<stage>.1.*`` layout.
            stages.append(nn.Sequential(upsample, refine))
            width = halved

        # Output projection: fixed Tanh activation and no normalisation.
        stages.append(Conv2dBlock(
            width, out_channels, kernel_size=7, stride=1, padding=3,
            padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE",
        ))
        self.up_sequence = nn.Sequential(*stages)

    def forward(self, x):
        """Apply every residual block in order, then the up-sampling stack."""
        features = x
        for block in self.residual_blocks:
            features = block(features)
        return self.up_sequence(features)