import torch
import torch.nn as nn

from model import MODEL
from model.GAN.base import Conv2dBlock, ResBlock
from model.normalization import select_norm_layer


class StyleEncoder(nn.Module):
    """Encode an image batch into one flat style vector per sample.

    The network is a 7x7 stride-1 ``Conv2dBlock``, followed by ``num_conv``
    4x4 stride-2 ``Conv2dBlock`` layers, a global average pool and a 1x1
    convolution that projects to ``out_dim`` channels.

    Args:
        in_channels: Number of channels of the input tensor.
        out_dim: Length of the produced style vector.
        num_conv: Number of stride-2 convolution blocks.
        base_channels: Output channels of the first convolution block.
        use_spectral_norm: Forwarded to every ``Conv2dBlock``.
        max_multiple: Channel widths are capped at
            ``base_channels * 2 ** max_multiple``.
        padding_mode: Forwarded to every ``Conv2dBlock``.
        activation_type: Forwarded to every ``Conv2dBlock``.
        norm_type: Forwarded to every ``Conv2dBlock``.
    """

    def __init__(self, in_channels, out_dim, num_conv, base_channels=64,
                 use_spectral_norm=False, max_multiple=2,
                 padding_mode='reflect', activation_type="ReLU",
                 norm_type="NONE"):
        super().__init__()
        # Options shared by every Conv2dBlock of this encoder.
        block_kwargs = dict(
            padding_mode=padding_mode,
            use_spectral_norm=use_spectral_norm,
            activation_type=activation_type,
            norm_type=norm_type,
        )

        sequence = [Conv2dBlock(in_channels, base_channels, kernel_size=7,
                                stride=1, padding=3, **block_kwargs)]

        multiple_now = 1
        for i in range(1, num_conv + 1):
            multiple_prev = multiple_now
            # Double the width at each step until the cap is reached.
            multiple_now = min(2 ** i, 2 ** max_multiple)
            sequence.append(Conv2dBlock(
                multiple_prev * base_channels, multiple_now * base_channels,
                kernel_size=4, stride=2, padding=1, **block_kwargs))

        sequence.append(nn.AdaptiveAvgPool2d(1))
        # conv1x1 works as fc when tensor's size is
        # (batch_size, channels, 1, 1), keep same with origin code
        sequence.append(nn.Conv2d(multiple_now * base_channels, out_dim,
                                  kernel_size=1, stride=1, padding=0))
        self.model = nn.Sequential(*sequence)

    def forward(self, x):
        """Return the style code of ``x`` flattened to ``(x.size(0), -1)``."""
        return self.model(x).view(x.size(0), -1)


class ContentEncoder(nn.Module):
    """Encode an image batch into a content feature map.

    The network is a 7x7 stride-1 ``Conv2dBlock``, followed by
    ``num_down_sampling`` 4x4 stride-2 ``Conv2dBlock`` layers that double the
    channel count each time, followed by ``num_res_blocks`` ``ResBlock``
    layers. The output therefore has
    ``base_channels * 2 ** num_down_sampling`` channels.

    Args:
        in_channels: Number of channels of the input tensor.
        num_down_sampling: Number of stride-2 convolution blocks.
        num_res_blocks: Number of trailing residual blocks.
        base_channels: Output channels of the first convolution block.
        use_spectral_norm: Forwarded to every sub-block.
        padding_mode: Forwarded to every sub-block.
        activation_type: Forwarded to every sub-block.
        norm_type: Forwarded to every sub-block.
    """

    def __init__(self, in_channels, num_down_sampling, num_res_blocks,
                 base_channels=64, use_spectral_norm=False,
                 padding_mode='reflect', activation_type="ReLU",
                 norm_type="IN"):
        super().__init__()
        # Options shared by every Conv2dBlock of this encoder.
        block_kwargs = dict(
            padding_mode=padding_mode,
            use_spectral_norm=use_spectral_norm,
            activation_type=activation_type,
            norm_type=norm_type,
        )

        sequence = [Conv2dBlock(in_channels, base_channels, kernel_size=7,
                                stride=1, padding=3, **block_kwargs)]

        for i in range(num_down_sampling):
            sequence.append(Conv2dBlock(
                base_channels * (2 ** i), base_channels * (2 ** (i + 1)),
                kernel_size=4, stride=2, padding=1, **block_kwargs))

        res_channels = base_channels * (2 ** num_down_sampling)
        sequence += [
            ResBlock(res_channels, use_spectral_norm, padding_mode,
                     norm_type, activation_type)
            for _ in range(num_res_blocks)
        ]
        self.sequence = nn.Sequential(*sequence)

    def forward(self, x):
        """Return the content feature map of ``x``."""
        return self.sequence(x)


class Decoder(nn.Module):
    """Decode a content feature map back into an image.

    ``num_res_blocks`` ``ResBlock`` layers (kept in ``self.res_blocks`` so
    callers can reach their normalization layers) are followed by
    ``num_up_sampling`` stages of ``nn.Upsample(scale_factor=2)`` plus a 5x5
    ``Conv2dBlock`` that halves the channel count, and finally a 7x7
    ``Conv2dBlock`` with ``"Tanh"`` activation producing ``out_channels``.

    Args:
        in_channels: Channels of the incoming content feature map.
        out_channels: Channels of the produced image.
        num_up_sampling: Number of upsample + convolution stages.
        num_res_blocks: Number of leading residual blocks.
        use_spectral_norm: Forwarded to every sub-block.
        res_norm_type: Normalization type used inside the residual blocks.
        norm_type: Normalization type used in the upsampling convolutions.
        activation_type: Activation used everywhere except the output layer.
        padding_mode: Padding mode of the residual and upsampling blocks.
    """

    def __init__(self, in_channels, out_channels, num_up_sampling,
                 num_res_blocks, use_spectral_norm=False,
                 res_norm_type="AdaIN", norm_type="LN",
                 activation_type="ReLU", padding_mode='reflect'):
        super().__init__()
        self.res_norm_type = res_norm_type
        self.res_blocks = nn.ModuleList([
            ResBlock(in_channels, use_spectral_norm, padding_mode,
                     res_norm_type, activation_type=activation_type)
            for _ in range(num_res_blocks)
        ])

        sequence = []
        channels = in_channels
        for _ in range(num_up_sampling):
            sequence.append(nn.Sequential(
                nn.Upsample(scale_factor=2),
                Conv2dBlock(channels, channels // 2, kernel_size=5, stride=1,
                            padding=2, padding_mode=padding_mode,
                            use_spectral_norm=use_spectral_norm,
                            activation_type=activation_type,
                            norm_type=norm_type),
            ))
            channels = channels // 2

        # NOTE(review): the output layer always uses "reflect" padding,
        # regardless of the ``padding_mode`` argument - confirm this is
        # intended before changing it.
        sequence.append(Conv2dBlock(
            channels, out_channels, kernel_size=7, stride=1, padding=3,
            padding_mode="reflect", use_spectral_norm=use_spectral_norm,
            activation_type="Tanh", norm_type="NONE"))
        self.sequence = nn.Sequential(*sequence)

    def forward(self, x):
        """Run ``x`` through the residual blocks, then the upsampling stack."""
        for blk in self.res_blocks:
            x = blk(x)
        return self.sequence(x)


class Fusion(nn.Module):
    """Fully connected network mapping ``in_features`` to ``out_features``.

    It consists of an input layer (``Linear`` + norm + ``ReLU``),
    ``n_blocks - 2`` hidden layers of the same form and width
    ``base_features``, and a plain ``Linear`` output layer. When
    ``n_blocks`` is 2 or less there are no hidden layers.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        base_features: Width of the input and hidden layers.
        n_blocks: Total layer count; ``n_blocks - 2`` hidden layers are built.
        norm_type: Name passed to ``select_norm_layer`` to obtain the
            normalization layer constructor.
    """

    def __init__(self, in_features, out_features, base_features, n_blocks,
                 norm_type="NONE"):
        super().__init__()
        norm_layer = select_norm_layer(norm_type)

        self.start_fc = nn.Sequential(
            nn.Linear(in_features, base_features),
            norm_layer(base_features),
            nn.ReLU(True),
        )
        self.fcs = nn.Sequential(*[
            nn.Sequential(
                nn.Linear(base_features, base_features),
                norm_layer(base_features),
                nn.ReLU(True),
            )
            for _ in range(n_blocks - 2)
        ])
        self.end_fc = nn.Sequential(
            nn.Linear(base_features, out_features),
        )

    def forward(self, x):
        """Apply the input, hidden and output layers in order."""
        x = self.start_fc(x)
        x = self.fcs(x)
        return self.end_fc(x)


@MODEL.register_module("MUNIT-Generator")
class Generator(nn.Module):
    """MUNIT generator: content encoder, style encoder, fusion MLP, decoder.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
        base_channels: Base width of both encoders.
        num_sampling: Number of down-sampling stages in the content encoder
            and of up-sampling stages in the decoder.
        num_style_dim: Length of the style code.
        num_style_conv: Number of stride-2 blocks in the style encoder.
        num_content_res_blocks: Residual blocks in the content encoder.
        num_decoder_res_blocks: Residual blocks in the decoder.
        num_fusion_dim: Hidden width of the fusion MLP.
        num_fusion_blocks: ``n_blocks`` of the fusion MLP.
        use_spectral_norm: Forwarded to the encoders and the decoder.
        activation_type: Forwarded to the encoders and the decoder.
        padding_mode: Forwarded to the encoders and the decoder.
    """

    def __init__(self, in_channels, out_channels, base_channels, num_sampling,
                 num_style_dim, num_style_conv, num_content_res_blocks,
                 num_decoder_res_blocks, num_fusion_dim, num_fusion_blocks,
                 use_spectral_norm=False, activation_type="ReLU",
                 padding_mode='reflect'):
        super().__init__()
        self.num_decoder_res_blocks = num_decoder_res_blocks

        self.content_encoder = ContentEncoder(
            in_channels, num_sampling, num_content_res_blocks,
            base_channels=base_channels,
            use_spectral_norm=use_spectral_norm,
            padding_mode=padding_mode,
            activation_type=activation_type,
            norm_type="IN",
        )
        # Keyword arguments are required here: StyleEncoder has a
        # ``max_multiple`` parameter between ``use_spectral_norm`` and
        # ``padding_mode``, so passing the remaining options positionally
        # would bind ``padding_mode`` to ``max_multiple``.
        self.style_encoder = StyleEncoder(
            in_channels, num_style_dim, num_style_conv,
            base_channels=base_channels,
            use_spectral_norm=use_spectral_norm,
            padding_mode=padding_mode,
            activation_type=activation_type,
            norm_type="NONE",
        )

        # The content encoder doubles the channel count once per
        # down-sampling stage, so the decoder input width depends on
        # ``num_sampling``.
        content_channels = base_channels * (2 ** num_sampling)
        self.decoder = Decoder(
            content_channels, out_channels, num_sampling,
            num_decoder_res_blocks, use_spectral_norm, "AdaIN",
            norm_type="LN", activation_type=activation_type,
            padding_mode=padding_mode,
        )
        # ``decode`` splits the fusion output into two chunks per decoder
        # residual block, each of width ``2 * content_channels``
        # (presumably a scale and a shift per channel - confirm against the
        # AdaIN layer's ``set_style``).
        self.fusion = Fusion(
            num_style_dim,
            num_decoder_res_blocks * 2 * content_channels * 2,
            base_features=num_fusion_dim,
            n_blocks=num_fusion_blocks,
            norm_type="NONE",
        )

    def encode(self, x):
        """Return ``(content, style)`` codes computed from ``x``."""
        return self.content_encoder(x), self.style_encoder(x)

    def decode(self, content, style):
        """Generate an image from a content code and a style code.

        The style code is mapped through the fusion MLP, split along
        dimension 1 into two chunks per decoder residual block, and each
        chunk is handed to the ``normalization.set_style`` of the block's
        ``conv1`` / ``conv2`` before the decoder is run on ``content``.
        """
        as_param_style = torch.chunk(
            self.fusion(style), self.num_decoder_res_blocks * 2, dim=1)
        # set style for decoder
        for i, blk in enumerate(self.decoder.res_blocks):
            blk.conv1.normalization.set_style(as_param_style[2 * i])
            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
        return self.decoder(content)

    def forward(self, x):
        """Encode ``x`` and decode it again with its own style code."""
        content, style = self.encode(x)
        return self.decode(content, style)