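"""MUNIT generator (Huang et al., "Multimodal Unsupervised Image-to-Image
Translation", ECCV 2018).

Each image is decomposed into a domain-invariant content code and a compact
style code; an MLP maps the style code to AdaIN parameters that modulate the
decoder's residual blocks, so content and style can be recombined freely.
"""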
import torch
import torch.nn as nn

from model import MODEL
from model.base.module import LinearBlock
from model.image_translation.CycleGAN import Encoder, Decoder


class StyleEncoder(nn.Module):
    """Encodes an image into a flat style code of length ``out_dim``."""

    def __init__(self, in_channels, out_dim, num_conv, base_channels=64,
                 max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
        super().__init__()
        self.down_encoder = Encoder(
            in_channels, base_channels, num_conv, num_res=0, max_down_sampling_multiple=max_down_sampling_multiple,
            padding_mode=padding_mode, activation_type=activation_type,
            down_conv_norm_type=norm_type, down_conv_kernel_size=4,
        )
        sequence = list()
        sequence.append(nn.AdaptiveAvgPool2d(1))
        # A 1x1 conv acts as a fully connected layer once the feature map is pooled
        # to (batch_size, channels, 1, 1); kept this way to match the original code.
        sequence.append(nn.Conv2d(self.down_encoder.out_channels, out_dim, kernel_size=1, stride=1, padding=0))
        self.sequence = nn.Sequential(*sequence)

    def forward(self, image):
        # Down-sample, globally pool, then project to a flat (batch_size, out_dim) code.
        features = self.down_encoder(image)
        return self.sequence(features).view(image.size(0), -1)


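# Illustrative usage sketch (not part of the original file); exact feature-map
# sizes depend on the repo's Encoder, but the output is always a flat style code:
#
#     enc = StyleEncoder(in_channels=3, out_dim=8, num_conv=2)
#     style = enc(torch.randn(4, 3, 256, 256))
#     assert style.shape == (4, 8)  # one 8-dim style code per image

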
class MLPFusion(nn.Module):
    """Plain MLP; here it maps a style code to the decoder's AdaIN parameters."""

    def __init__(self, in_features, out_features, base_features, n_blocks, activation_type="ReLU", norm_type="NONE"):
        super().__init__()
        sequence = [LinearBlock(in_features, base_features, activation_type=activation_type, norm_type=norm_type)]
        sequence += [
            LinearBlock(base_features, base_features, activation_type=activation_type, norm_type=norm_type)
            for _ in range(n_blocks - 2)
        ]
        sequence.append(LinearBlock(base_features, out_features, activation_type=activation_type, norm_type=norm_type))
        self.sequence = nn.Sequential(*sequence)

    def forward(self, x):
        return self.sequence(x)


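# Example wiring (hypothetical numbers): with style_dim=8, four decoder residual
# blocks, and 256-channel content features, MLPFusion maps an (N, 8) style code
# to an (N, 4 * 2 * 256 * 2) = (N, 4096) vector that Generator.decode() chunks
# into per-AdaIN-layer (scale, bias) parameters:
#
#     mlp = MLPFusion(8, 4 * 2 * 256 * 2, base_features=256, n_blocks=3)
#     params = mlp(torch.randn(4, 8))  # -> shape (4, 4096)

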
@MODEL.register_module("MUNIT-Generator")
class Generator(nn.Module):
    """MUNIT generator: encodes an image into content and style codes and decodes
    any (content, style) pair back into an image via AdaIN residual blocks."""

    def __init__(self, in_channels, out_channels, base_channels=64, style_dim=8,
                 num_mlp_base_feature=256, num_mlp_blocks=3,
                 max_down_sampling_multiple=2, num_content_down_sampling=2, num_style_down_sampling=2,
                 encoder_num_residual_blocks=4, decoder_num_residual_blocks=4,
                 padding_mode='reflect', activation_type="ReLU"):
        super().__init__()
        self.content_encoder = Encoder(
            in_channels, base_channels, num_content_down_sampling, encoder_num_residual_blocks,
            # Cap channel doubling with max_down_sampling_multiple so the encoder's
            # output width matches content_channels computed below.
            max_down_sampling_multiple=max_down_sampling_multiple,
            padding_mode=padding_mode, activation_type=activation_type,
            down_conv_norm_type="IN", down_conv_kernel_size=4,
            res_norm_type="IN"
        )

        self.style_encoder = StyleEncoder(in_channels, style_dim, num_style_down_sampling, base_channels,
                                          max_down_sampling_multiple, padding_mode, activation_type,
                                          norm_type="NONE")

        # Width of the content feature map, and therefore of the decoder's residual blocks.
        content_channels = base_channels * (2 ** max_down_sampling_multiple)

        # Every decoder residual block holds two AdaIN layers, and each AdaIN layer
        # needs a per-channel scale and bias (hence the two factors of 2).
        self.fusion = MLPFusion(style_dim, decoder_num_residual_blocks * 2 * content_channels * 2,
                                num_mlp_base_feature, num_mlp_blocks, activation_type,
                                norm_type="NONE")

        self.decoder = Decoder(in_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks,
                               activation_type=activation_type, padding_mode=padding_mode,
                               up_conv_kernel_size=5, up_conv_norm_type="LN",
                               res_norm_type="AdaIN")

    def encode(self, x):
        return self.content_encoder(x), self.style_encoder(x)

    def decode(self, content, style):
        # Split the fused style vector into one (scale, bias) chunk per AdaIN layer:
        # two AdaIN layers per residual block.
        as_param_style = torch.chunk(self.fusion(style), 2 * len(self.decoder.residual_blocks), dim=1)
        # Load the style parameters into the decoder's AdaIN layers.
        for i, blk in enumerate(self.decoder.residual_blocks):
            blk.conv1.normalization.set_style(as_param_style[2 * i])
            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
        return self.decoder(content)

    def forward(self, x):
        content, style = self.encode(x)
        return self.decode(content, style)
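
# Minimal smoke test (a sketch, not part of the original file). It assumes the
# CycleGAN Encoder/Decoder imports resolve and that 256x256 RGB inputs are
# compatible with the default stride/kernel settings.
if __name__ == "__main__":
    gen = Generator(in_channels=3, out_channels=3)
    a = torch.randn(1, 3, 256, 256)
    b = torch.randn(1, 3, 256, 256)

    reconstructed = gen(a)  # encode and decode the same image

    # Style transfer: recombine the content of `a` with the style of `b`.
    content_a, _ = gen.encode(a)
    _, style_b = gen.encode(b)
    swapped = gen.decode(content_a, style_b)
    print(reconstructed.shape, swapped.shape)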