63 lines
2.5 KiB
Python
63 lines
2.5 KiB
Python
import torch.nn as nn
|
|
|
|
from model.normalization import select_norm_layer
|
|
from model.registry import MODEL
|
|
from .base import ResidualBlock
|
|
|
|
|
|
@MODEL.register_module("CyCle-Generator")
class Generator(nn.Module):
    """Encoder / residual-middle / decoder image generator.

    Pipeline (applied in ``forward``):

    1. ``start_conv``: 7x7 stride-1 conv (using ``padding_mode``) to
       ``base_channels``, then normalization and ReLU.
    2. ``encoder``: ``num_down_sampling`` stride-2 3x3 convs, each doubling the
       channel count.
    3. ``resnet_middle``: ``num_blocks`` ``ResidualBlock`` modules operating at
       ``base_channels * 2 ** num_down_sampling`` channels.
    4. ``decoder``: ``num_down_sampling`` stride-2 3x3 transposed convs, each
       halving the channel count and doubling the spatial size.
    5. ``end_conv``: 7x7 conv to ``out_channels`` followed by ``Tanh``, so every
       output value lies in ``[-1, 1]``.

    A 3x3/stride-2/padding-1 conv maps a size ``s`` to ``ceil(s / 2)``, and the
    transposed conv (output_padding=1) maps ``s`` to ``2 * s``. Therefore, when
    the input height and width are divisible by ``2 ** num_down_sampling``, the
    output has the same spatial size as the input.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the generated tensor.
        base_channels: Channel width produced by ``start_conv``; doubled by each
            down-sampling stage.
        num_blocks: Number of residual blocks in the middle (must be >= 0).
        padding_mode: Padding mode of the two 7x7 convs; also forwarded to
            ``ResidualBlock``.
        norm_type: Key passed to ``select_norm_layer``. Conv biases before a
            norm layer are enabled only when this is ``"IN"``.
        num_down_sampling: Number of down-sampling (and matching up-sampling)
            stages (must be >= 0). Defaults to 2.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect',
                 norm_type="IN", num_down_sampling=2):
        super().__init__()
        assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
        assert num_down_sampling >= 0, \
            f'Number of down sampling layers must be non-negative, but got {num_down_sampling}.'

        norm_layer = select_norm_layer(norm_type)
        # NOTE(review): the bias is kept only for "IN". Presumably the instance norm
        # returned by select_norm_layer has no learnable shift, while other norms do
        # (which would make a preceding conv bias redundant). Confirm in select_norm_layer.
        use_bias = norm_type == "IN"

        self.start_conv = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
                      bias=use_bias),
            norm_layer(num_features=base_channels),
            nn.ReLU(inplace=True)
        )

        # down sampling: stage i maps base_channels * 2**i -> base_channels * 2**(i + 1)
        submodules = []
        for i in range(num_down_sampling):
            multiple = 2 ** i
            submodules += [
                nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
                          kernel_size=3, stride=2, padding=1, bias=use_bias),
                norm_layer(num_features=base_channels * multiple * 2),
                nn.ReLU(inplace=True)
            ]
        self.encoder = nn.Sequential(*submodules)

        # Width after the encoder. Each stage doubles it, so the factor is
        # 2 ** num_down_sampling (num_down_sampling ** 2 only coincides at 2).
        res_block_channels = 2 ** num_down_sampling * base_channels
        self.resnet_middle = nn.Sequential(
            *[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in range(num_blocks)])

        # up sampling: mirrors the encoder, stage i maps
        # base_channels * 2**(n - i) -> base_channels * 2**(n - i - 1), ending at base_channels
        submodules = []
        for i in range(num_down_sampling):
            multiple = 2 ** (num_down_sampling - i)
            submodules += [
                nn.ConvTranspose2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=3, stride=2,
                                   padding=1, output_padding=1, bias=use_bias),
                norm_layer(num_features=base_channels * multiple // 2),
                nn.ReLU(inplace=True),
            ]
        self.decoder = nn.Sequential(*submodules)

        # No norm layer follows this conv, so it keeps nn.Conv2d's default bias.
        self.end_conv = nn.Sequential(
            nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
            nn.Tanh()
        )

    def forward(self, x):
        """Map ``x`` (``in_channels`` channels) to a Tanh-bounded output with ``out_channels`` channels."""
        x = self.encoder(self.start_conv(x))
        x = self.resnet_middle(x)
        return self.end_conv(self.decoder(x))
|