import torch.nn as nn

from model.normalization import select_norm_layer
from model.registry import MODEL
from .base import ResidualBlock


@MODEL.register_module("CyCle-Generator")
class Generator(nn.Module):
    """ResNet-style generator: conv stem, down-sampling encoder, residual
    middle, up-sampling decoder, conv head.

    Pipeline:
        start_conv     7x7 conv (stride 1) -> norm -> ReLU, mapping
                       ``in_channels`` to ``base_channels``.
        encoder        ``num_down_sampling`` stages of stride-2 3x3 conv ->
                       norm -> ReLU, each doubling the channel count.
        resnet_middle  ``num_blocks`` ``ResidualBlock`` modules operating at
                       ``base_channels * 2 ** num_down_sampling`` channels.
        decoder        ``num_down_sampling`` stages of stride-2 3x3 transposed
                       conv -> norm -> ReLU, each halving the channel count
                       and doubling height and width.
        end_conv       7x7 conv to ``out_channels`` followed by ``tanh``, so
                       every output value lies in [-1, 1].

    Each encoder stage maps a spatial size S to ceil(S / 2). Each decoder
    stage maps S to 2 * S. The output therefore has the input's height and
    width only when both are divisible by ``2 ** num_down_sampling``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the generated tensor.
        base_channels: Channel width produced by ``start_conv``.
        num_blocks: Number of residual blocks in the middle; must be >= 0.
        padding_mode: ``padding_mode`` used by the two 7x7 convolutions and
            forwarded to ``ResidualBlock``. The 3x3 encoder convolutions use
            the default zero padding.
        norm_type: Normalization key passed to ``select_norm_layer`` and
            forwarded to ``ResidualBlock``.
        num_down_sampling: Number of down-sampling (and matching up-sampling)
            stages; must be >= 0. Defaults to 2.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9,
                 padding_mode='reflect', norm_type="IN", num_down_sampling=2):
        super().__init__()
        assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
        assert num_down_sampling >= 0, (
            f'Number of down-sampling stages must be non-negative, but got {num_down_sampling}.')

        norm_layer = select_norm_layer(norm_type)
        # Conv bias is enabled only for instance norm. This appears to follow
        # the reference CycleGAN ResnetGenerator convention.
        # TODO confirm against the layers select_norm_layer returns.
        use_bias = norm_type == "IN"

        self.start_conv = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
                      bias=use_bias),
            norm_layer(num_features=base_channels),
            nn.ReLU(inplace=True)
        )

        # Down sampling: stage i maps base * 2**i -> base * 2**(i + 1) channels.
        submodules = []
        for i in range(num_down_sampling):
            stage_in = base_channels * 2 ** i
            stage_out = stage_in * 2
            submodules += [
                nn.Conv2d(in_channels=stage_in, out_channels=stage_out, kernel_size=3, stride=2, padding=1,
                          bias=use_bias),
                norm_layer(num_features=stage_out),
                nn.ReLU(inplace=True)
            ]
        self.encoder = nn.Sequential(*submodules)

        # The encoder doubles channels once per stage, so the bottleneck width
        # is base * 2 ** num_down_sampling. Note: this is 2 ** n, not n ** 2;
        # the two only coincide when n == 2.
        res_block_channels = base_channels * 2 ** num_down_sampling
        self.resnet_middle = nn.Sequential(
            *[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in range(num_blocks)])

        # Up sampling: mirrors the encoder, halving channels at every stage so
        # the last stage returns to base_channels.
        submodules = []
        for i in range(num_down_sampling):
            stage_in = base_channels * 2 ** (num_down_sampling - i)
            stage_out = stage_in // 2
            submodules += [
                # output_padding=1 makes each stage produce exactly 2x the
                # input height/width (with kernel 3, stride 2, padding 1).
                nn.ConvTranspose2d(stage_in, stage_out, kernel_size=3, stride=2, padding=1, output_padding=1,
                                   bias=use_bias),
                norm_layer(num_features=stage_out),
                nn.ReLU(inplace=True),
            ]
        self.decoder = nn.Sequential(*submodules)

        self.end_conv = nn.Sequential(
            nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
            nn.Tanh()
        )

    def forward(self, x):
        """Run stem -> encoder -> residual middle -> decoder -> head.

        Args:
            x: Image tensor with ``in_channels`` channels (presumably shaped
                (N, C, H, W); TODO confirm with callers).

        Returns:
            Tensor with ``out_channels`` channels and values in [-1, 1].
        """
        x = self.encoder(self.start_conv(x))
        x = self.resnet_middle(x)
        return self.end_conv(self.decoder(x))