172 lines
8.4 KiB
Python
172 lines
8.4 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from torchvision.models import vgg19
|
|
|
|
from model.normalization import select_norm_layer
|
|
from model.registry import MODEL
|
|
from .MUNIT import ContentEncoder, Fusion, Decoder, StyleEncoder
|
|
from .base import ResBlock
|
|
|
|
|
|
class VGG19StyleEncoder(nn.Module):
    """Style encoder that mixes learned convolutions with intermediate VGG19 activations.

    The input image goes through two paths:

    * a truncated VGG19 feature extractor (``fixed_style_features``), and
    * a small learned conv stack (``conv0`` followed by the stages in ``conv``).

    Before each learned stage, the running feature map is concatenated along the
    channel axis with one tapped VGG19 activation. The output of the last stage is
    concatenated with the last tapped activation. The result is globally
    average-pooled and projected to ``style_dim`` with a 1x1 convolution.

    Args:
        in_channels: channels of the input image fed to ``conv0``. The same tensor
            is fed to VGG19, so this presumably has to be 3 — TODO confirm.
        base_channels: width of ``conv0``. Learned stage ``k`` (0-based) works on
            ``base_channels * 2**(k+1)`` channels. The concatenations in
            ``forward`` only line up when the running widths plus the tapped VGG19
            widths match these numbers. This is assumed to hold for the defaults
            (``base_channels=64``, ``vgg19_layers=(0, 5, 10, 19)``) — verify before
            changing either.
        style_dim: length of the style vector returned by ``forward``.
        padding_mode: padding mode of the learned convolutions.
        norm_type: name handed to ``select_norm_layer``. The resulting factory is
            called with the channel count of the layer it follows.
        vgg19_layers: increasing indices into ``vgg19().features`` whose outputs
            are tapped. The last index also decides where the extractor is
            truncated. ``forward`` feeds the first ``len(self.conv)`` taps into
            the conv stack and the last tap into the pooling step.
        fix_vgg19: if True, the VGG19 parameters get ``requires_grad=False``.
    """

    def __init__(self, in_channels, base_channels=64, style_dim=512, padding_mode='reflect', norm_type="NONE",
                 vgg19_layers=(0, 5, 10, 19), fix_vgg19=True):
        super().__init__()

        self.vgg19_layers = vgg19_layers
        # Keep VGG19 only up to (and including) the deepest tapped layer.
        self.vgg19 = vgg19(pretrained=True).features[:vgg19_layers[-1] + 1]
        self.vgg19.requires_grad_(not fix_vgg19)

        norm_layer = select_norm_layer(norm_type)

        self.conv0 = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
                      bias=True),
            norm_layer(base_channels),
            nn.ReLU(True),
        )
        # Stage for i in 1..3: the input is the running map concatenated with a VGG19
        # tap (base_channels * 2**i channels in total). The stage keeps that width
        # and, with kernel 4 / stride 2 / padding 1, halves even spatial sizes.
        self.conv = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(base_channels * (2 ** i), base_channels * (2 ** i), kernel_size=4, stride=2, padding=1,
                          padding_mode=padding_mode, bias=True),
                # The norm must match the conv's actual output width. It used to be
                # built with `base_channels`, which is wrong for i >= 1.
                norm_layer(base_channels * (2 ** i)),
                nn.ReLU(True),
            ) for i in range(1, 4)
        ])
        self.pool = nn.AdaptiveAvgPool2d(1)
        # The last stage's output concatenated with the last VGG19 tap is
        # expected to have base_channels * 2**4 channels.
        self.conv1x1 = nn.Conv2d(base_channels * (2 ** 4), style_dim, kernel_size=1, stride=1, padding=0)

    def fixed_style_features(self, x):
        """Run ``x`` through the truncated VGG19 and collect the tapped activations.

        Returns:
            A list with the output of every layer whose index is in
            ``self.vgg19_layers``, in layer order.
        """
        features = []
        for idx, layer in enumerate(self.vgg19):
            x = layer(x)
            if idx in self.vgg19_layers:
                features.append(x)
        return features

    def forward(self, x):
        """Encode image batch ``x`` into a style tensor of shape ``(batch, style_dim)``."""
        fsf = self.fixed_style_features(x)
        x = self.conv0(x)
        for i, stage in enumerate(self.conv):
            # Inject the i-th VGG19 activation before each learned stage.
            x = stage(torch.cat([x, fsf[i]], dim=1))
        x = self.pool(torch.cat([x, fsf[-1]], dim=1))
        x = self.conv1x1(x)
        # After 1x1 pooling the map is (batch, style_dim, 1, 1); flatten it.
        return x.view(x.size(0), -1)
|
|
|
|
|
|
@MODEL.register_module("TAFG-ResGenerator")
class ResGenerator(nn.Module):
    """Style-free generator: a ``ContentEncoder`` followed directly by a ``Decoder``.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
        use_spectral_norm: forwarded to both the encoder and the decoder.
        num_res_blocks: forwarded to ``ContentEncoder`` as ``num_res_blocks``.
        base_channels: only used to size the decoder input (``4 * base_channels``).
            NOTE(review): it is not forwarded to ``ContentEncoder``, so a value
            other than that encoder's own default presumably breaks the channel
            match — TODO confirm.

    The decoder gets ``0`` in the slot where the AdaIN generators in this file
    pass ``num_adain_blocks``, plus ``"IN"``. It presumably has no AdaIN blocks.
    Padding is fixed to ``"reflect"``.
    """

    def __init__(self, in_channels, out_channels=3, use_spectral_norm=False, num_res_blocks=8, base_channels=64):
        super().__init__()
        self.content_encoder = ContentEncoder(
            in_channels,
            2,
            num_res_blocks=num_res_blocks,
            use_spectral_norm=use_spectral_norm,
        )
        decoder_in_channels = base_channels * 4
        self.decoder = Decoder(
            decoder_in_channels, out_channels, 2, 0, use_spectral_norm, "IN",
            norm_type="LN", padding_mode="reflect",
        )

    def forward(self, x):
        """Encode ``x`` and decode it back to image space."""
        content = self.content_encoder(x)
        return self.decoder(content)
|
|
|
|
|
|
@MODEL.register_module("TAFG-SingleGenerator")
class SingleGenerator(nn.Module):
    """Single-domain AdaIN generator: content comes from one image, style from another.

    Pipeline: ``content_encoder(content_img)`` -> ``decoder``. The decoder's
    normalization layers receive parameters computed by
    ``style_converter(style_encoder(style_img))``.

    Args:
        style_in_channels: channels of the style image.
        content_in_channels: channels of the content image.
        out_channels: channels of the generated image.
        use_spectral_norm: forwarded to the sub-modules.
        style_encoder_type: ``"StyleEncoder"`` or ``"VGG19StyleEncoder"``.
        num_style_conv: forwarded to ``StyleEncoder``. Ignored by the VGG19 variant.
        style_dim: length of the style vector.
        num_adain_blocks: number of AdaIN blocks in the decoder. ``forward`` hands
            each block two style chunks, one for ``conv1`` and one for ``conv2``.
        num_res_blocks: forwarded to ``ContentEncoder`` as ``num_res_blocks``.
        base_channels: base width. The decoder is built for ``4 * base_channels``
            input channels. NOTE(review): ``base_channels`` is not forwarded to
            ``ContentEncoder`` — confirm its output width matches.
        padding_mode: padding mode of the style encoder and decoder.

    Raises:
        NotImplementedError: if ``style_encoder_type`` is not supported.
    """

    def __init__(self, style_in_channels, content_in_channels, out_channels=3, use_spectral_norm=False,
                 style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
                 num_res_blocks=8, base_channels=64, padding_mode="reflect"):
        super().__init__()
        self.num_adain_blocks = num_adain_blocks

        if style_encoder_type == "StyleEncoder":
            self.style_encoder = StyleEncoder(
                style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
                max_multiple=4, padding_mode=padding_mode, norm_type="NONE"
            )
        elif style_encoder_type == "VGG19StyleEncoder":
            self.style_encoder = VGG19StyleEncoder(
                style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE"
            )
        else:
            # `NotImplemented` is a non-callable constant; raising it produced a
            # confusing TypeError instead of this message.
            raise NotImplementedError(f"do not support {style_encoder_type}")

        resnet_channels = 2 ** 2 * base_channels
        # The converter output is split into num_adain_blocks * 2 chunks of
        # 2 * resnet_channels values each (presumably a scale and a bias per channel).
        self.style_converter = Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256,
                                      n_blocks=3, norm_type="NONE")
        self.content_encoder = ContentEncoder(content_in_channels, 2, num_res_blocks=num_res_blocks,
                                              use_spectral_norm=use_spectral_norm)
        self.decoder = Decoder(resnet_channels, out_channels, 2,
                               num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN",
                               padding_mode=padding_mode)

    def forward(self, content_img, style_img):
        """Generate an image with the content of ``content_img`` and the style of ``style_img``.

        Side effect: stores the computed style on the decoder's residual-block
        normalizations (via ``set_style``) before running the decoder.
        """
        content = self.content_encoder(content_img)
        style = self.style_encoder(style_img)
        as_param_style = torch.chunk(self.style_converter(style), self.num_adain_blocks * 2, dim=1)
        # Chunk 2*i goes to block i's first conv, chunk 2*i+1 to its second conv.
        for i, blk in enumerate(self.decoder.res_blocks):
            blk.conv1.normalization.set_style(as_param_style[2 * i])
            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
        return self.decoder(content)
|
|
|
|
|
|
@MODEL.register_module("TAFG-Generator")
class Generator(nn.Module):
    """Two-domain ("a"/"b") AdaIN generator with a shared content ResNet.

    Each domain key ("a", "b") has its own style encoder, style converter
    (``Fusion``), content encoder and AdaIN decoder. The residual blocks applied
    after the content encoders (``content_resnet``) are shared by both domains.

    Args:
        style_in_channels: channels of the style images (both domains).
        content_in_channels: channels of domain-"a" content images. The domain-"b"
            content encoder is built with 1 input channel.
        out_channels: channels of the generated image.
        use_spectral_norm: forwarded to the sub-modules.
        style_encoder_type: ``"StyleEncoder"`` or ``"VGG19StyleEncoder"``. With the
            VGG19 variant, domain "a" freezes VGG19 and domain "b" leaves it
            trainable (``fix_vgg19=False``).
        num_style_conv: forwarded to ``StyleEncoder``. Ignored by the VGG19 variant.
        style_dim: length of the style vector.
        num_adain_blocks: AdaIN blocks per decoder. ``decode`` hands each block two
            style chunks, one for ``conv1`` and one for ``conv2``.
        num_res_blocks: number of shared ``ResBlock``s in ``content_resnet``.
        base_channels: base width. The shared ResBlocks and the decoders are built
            for ``4 * base_channels`` channels.
        padding_mode: padding mode forwarded to the sub-modules.

    Raises:
        NotImplementedError: if ``style_encoder_type`` is not supported.
    """

    def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,
                 style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
                 num_res_blocks=8, base_channels=64, padding_mode="reflect"):
        super().__init__()
        self.num_adain_blocks = num_adain_blocks

        if style_encoder_type == "StyleEncoder":
            self.style_encoders = nn.ModuleDict({
                key: StyleEncoder(style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
                                  max_multiple=4, padding_mode=padding_mode, norm_type="NONE")
                for key in ("a", "b")
            })
        elif style_encoder_type == "VGG19StyleEncoder":
            self.style_encoders = nn.ModuleDict(dict(
                a=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
                                    norm_type="NONE"),
                # Domain "b" keeps its VGG19 backbone trainable.
                b=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
                                    norm_type="NONE", fix_vgg19=False),
            ))
        else:
            # `NotImplemented` is a non-callable constant; raising it produced a
            # confusing TypeError instead of this message.
            raise NotImplementedError(f"do not support {style_encoder_type}")

        resnet_channels = 2 ** 2 * base_channels
        # Split later into num_adain_blocks * 2 chunks of 2 * resnet_channels values
        # each (presumably a scale and a bias per channel).
        style_params_dim = num_adain_blocks * 2 * resnet_channels * 2
        self.style_converters = nn.ModuleDict({
            key: Fusion(style_dim, style_params_dim, base_features=256, n_blocks=3, norm_type="NONE")
            for key in ("a", "b")
        })

        # Per-domain encoders without residual blocks; residual processing is shared below.
        self.content_encoders = nn.ModuleDict({
            "a": ContentEncoder(content_in_channels, 2, num_res_blocks=0, use_spectral_norm=use_spectral_norm),
            "b": ContentEncoder(1, 2, num_res_blocks=0, use_spectral_norm=use_spectral_norm)
        })

        self.content_resnet = nn.Sequential(*[
            ResBlock(resnet_channels, use_spectral_norm, padding_mode, "IN")
            for _ in range(num_res_blocks)
        ])

        self.decoders = nn.ModuleDict({
            key: Decoder(resnet_channels, out_channels, 2,
                         num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode)
            for key in ("a", "b")
        })

    def encode(self, content_img, style_img, which_content, which_style):
        """Encode content and style with the encoders of the given domain keys.

        Returns:
            ``(content, style)``. ``content`` has also passed through the shared
            ``content_resnet``.
        """
        content = self.content_resnet(self.content_encoders[which_content](content_img))
        style = self.style_encoders[which_style](style_img)
        return content, style

    def decode(self, content, style, which):
        """Decode ``content`` with domain ``which``'s decoder, styled by ``style``.

        Side effect: stores the style parameters on that decoder's residual-block
        normalizations (via ``set_style``) before running it.
        """
        decoder = self.decoders[which]
        as_param_style = torch.chunk(self.style_converters[which](style), self.num_adain_blocks * 2, dim=1)
        # Chunk 2*i goes to block i's first conv, chunk 2*i+1 to its second conv.
        for i, blk in enumerate(decoder.res_blocks):
            blk.conv1.normalization.set_style(as_param_style[2 * i])
            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
        return decoder(content)

    def forward(self, content_img, style_img, which_content, which_style):
        """Translate ``content_img`` using ``style_img``, decoding in the style domain ``which_style``."""
        content, style = self.encode(content_img, style_img, which_content, which_style)
        return self.decode(content, style, which_style)
|