import torch
import torch.nn as nn

from model import MODEL
from model.base.module import LinearBlock
from model.image_translation.CycleGAN import Encoder, Decoder


class StyleEncoder(nn.Module):
    """Extracts a low-dimensional style code from an image."""

    def __init__(self, in_channels, out_dim, num_conv, base_channels=64,
                 max_down_sampling_multiple=2, padding_mode="reflect",
                 activation_type="ReLU", norm_type="NONE", pre_activation=False):
        super().__init__()
        self.down_encoder = Encoder(
            in_channels,
            base_channels,
            num_conv,
            num_res=0,
            max_down_sampling_multiple=max_down_sampling_multiple,
            padding_mode=padding_mode,
            activation_type=activation_type,
            down_conv_norm_type=norm_type,
            down_conv_kernel_size=4,
            pre_activation=pre_activation,
        )
        sequence = [nn.AdaptiveAvgPool2d(1)]
        # A 1x1 conv acts as a fully connected layer once the tensor is
        # (batch_size, channels, 1, 1); kept to match the original MUNIT code.
        sequence.append(nn.Conv2d(self.down_encoder.out_channels, out_dim,
                                  kernel_size=1, stride=1, padding=0))
        self.sequence = nn.Sequential(*sequence)

    def forward(self, image):
        # Downsampling convs, then global pooling and the 1x1 projection,
        # flattened to (batch_size, out_dim).
        return self.sequence(self.down_encoder(image)).view(image.size(0), -1)


class MLPFusion(nn.Module):
    """MLP that maps a style code to the decoder's AdaIN parameters."""

    def __init__(self, in_features, out_features, base_features, n_blocks,
                 activation_type="ReLU", norm_type="NONE"):
        super().__init__()
        # n_blocks counts the input and output blocks, so there are
        # n_blocks - 2 hidden blocks (n_blocks is expected to be >= 2).
        sequence = [LinearBlock(in_features, base_features,
                                activation_type=activation_type, norm_type=norm_type)]
        sequence += [
            LinearBlock(base_features, base_features,
                        activation_type=activation_type, norm_type=norm_type)
            for _ in range(n_blocks - 2)
        ]
        sequence.append(LinearBlock(base_features, out_features,
                                    activation_type=activation_type, norm_type=norm_type))
        self.sequence = nn.Sequential(*sequence)

    def forward(self, x):
        return self.sequence(x)


@MODEL.register_module("MUNIT-Generator")
class Generator(nn.Module):
    """MUNIT generator: disentangles an image into content and style codes and
    decodes them back, injecting the style via AdaIN in the decoder."""

    def __init__(self, in_channels, out_channels, base_channels=64, style_dim=8,
                 num_mlp_base_feature=256, num_mlp_blocks=3,
                 max_down_sampling_multiple=2, num_content_down_sampling=2,
                 num_style_down_sampling=2, encoder_num_residual_blocks=4,
                 decoder_num_residual_blocks=4, padding_mode="reflect",
                 activation_type="ReLU", pre_activation=False):
        super().__init__()
        self.content_encoder = Encoder(
            in_channels,
            base_channels,
            num_content_down_sampling,
            encoder_num_residual_blocks,
            max_down_sampling_multiple=num_content_down_sampling,
            padding_mode=padding_mode,
            activation_type=activation_type,
            down_conv_norm_type="IN",
            down_conv_kernel_size=4,
            res_norm_type="IN",
            pre_activation=pre_activation,
        )
        self.style_encoder = StyleEncoder(
            in_channels, style_dim, num_style_down_sampling, base_channels,
            max_down_sampling_multiple, padding_mode, activation_type,
            norm_type="NONE", pre_activation=pre_activation,
        )
        content_channels = base_channels * (2 ** max_down_sampling_multiple)
        # Each decoder residual block has two AdaIN layers, and each AdaIN
        # layer needs a scale and a bias per channel, hence the output width.
        self.fusion = MLPFusion(
            style_dim,
            decoder_num_residual_blocks * 2 * content_channels * 2,
            num_mlp_base_feature,
            num_mlp_blocks,
            activation_type,
            norm_type="NONE",
        )
        self.decoder = Decoder(
            in_channels, out_channels, max_down_sampling_multiple,
            decoder_num_residual_blocks,
            activation_type=activation_type,
            padding_mode=padding_mode,
            up_conv_kernel_size=5,
            up_conv_norm_type="LN",
            res_norm_type="AdaIN",
            pre_activation=pre_activation,
        )

    def encode(self, x):
        return self.content_encoder(x), self.style_encoder(x)

    def decode(self, content, style):
        # Map the style code to AdaIN parameters: one chunk per conv in each
        # decoder residual block (two convs per block).
        as_param_style = torch.chunk(self.fusion(style),
                                     2 * len(self.decoder.residual_blocks), dim=1)
        # Install the style parameters on the decoder's AdaIN layers before decoding.
        for i, blk in enumerate(self.decoder.residual_blocks):
            blk.conv1.normalization.set_style(as_param_style[2 * i])
            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
        return self.decoder(content)

    def forward(self, x):
        content, style = self.encode(x)
        return self.decode(content, style)
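
# Minimal usage sketch (illustrative only; assumes the CycleGAN Encoder halves
# the spatial size per downsampling stage and the Decoder restores it, which
# this file does not itself guarantee):
#
#     gen = Generator(in_channels=3, out_channels=3)
#     a = torch.randn(1, 3, 256, 256)
#     b = torch.randn(1, 3, 256, 256)
#
#     # Reconstruction: re-combine an image's own content and style codes.
#     content_a, style_a = gen.encode(a)
#     recon_a = gen.decode(content_a, style_a)
#
#     # Translation: render the content of `a` in the style of `b`.
#     _, style_b = gen.encode(b)
#     a_as_b = gen.decode(content_a, style_b)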