raycv/model/GAN/TSIT.py
2020-09-25 18:31:12 +08:00

184 lines
7.5 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
from model import MODEL
from model.normalization import select_norm_layer
class ResBlock(nn.Module):
    """Residual block with a two-conv main path and a 1x1-conv shortcut.

    Main path: conv3x3(in->in) -> norm -> LeakyReLU(0.2)
               -> conv3x3(in->out) -> norm -> LeakyReLU(0.2)
    Shortcut:  conv1x1(in->out) -> norm -> LeakyReLU(0.2)
    ``forward`` returns the sum of both paths.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: channel count of the output tensor.
        padding_mode: forwarded to every ``nn.Conv2d`` built by this block.
        norm_type: key handed to ``select_norm_layer`` to choose the norm layer.
        use_bias: conv bias flag; ``None`` means "bias only when norm_type == 'IN'".
        use_spectral: wrap every conv in spectral normalization.
    """

    def __init__(self, in_channels, out_channels, padding_mode='zeros', norm_type="IN", use_bias=None,
                 use_spectral=True):
        super().__init__()
        self.padding_mode = padding_mode
        self.use_spectral = use_spectral
        # Without an explicit choice, enable conv bias only for IN (per the original
        # author's note, that norm carries no affine parameters of its own).
        self.use_bias = (norm_type == "IN") if use_bias is None else use_bias
        norm_layer = select_norm_layer(norm_type)

        def conv_norm_act(c_in, c_out, **conv_kwargs):
            # conv -> norm -> activation, created in that order so parameter
            # initialisation order (and state_dict indices) stay stable.
            return [
                self.conv_block(c_in, c_out, **conv_kwargs),
                norm_layer(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        self.main = nn.Sequential(
            *conv_norm_act(in_channels, in_channels),
            *conv_norm_act(in_channels, out_channels),
        )
        self.skip = nn.Sequential(*conv_norm_act(in_channels, out_channels, padding=0, kernel_size=1))

    def conv_block(self, in_channels, out_channels, kernel_size=3, padding=1):
        """Build one conv using this block's padding mode, bias flag and spectral-norm setting."""
        layer = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding,
                          padding_mode=self.padding_mode, bias=self.use_bias)
        if not self.use_spectral:
            return layer
        return nn.utils.spectral_norm(layer)

    def forward(self, x):
        main_out = self.main(x)
        return main_out + self.skip(x)
class Interpolation(nn.Module):
    """Resize a tensor with ``F.interpolate`` using a fixed scale factor.

    The class serves both to downsample (e.g. ``scale_factor=0.5``) and to
    upsample (e.g. ``scale_factor=2``).

    Args:
        scale_factor: multiplier applied to the spatial dimensions.
        mode: interpolation algorithm passed to ``F.interpolate``.
        align_corners: passed to ``F.interpolate``; only valid for the
            linear-family modes, so leave it ``None`` for ``'nearest'``.
    """

    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        # recompute_scale_factor=False: use scale_factor as given instead of
        # re-deriving it from the computed output size.
        return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode,
                             align_corners=self.align_corners, recompute_scale_factor=False)

    def extra_repr(self):
        # nn.Module.__repr__ wraps this with the real class name. The previous
        # hand-written __repr__ labelled every instance "DownSampling", which is
        # wrong for upsampling instances.
        return f"scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners}"
class FADE(nn.Module):
    """Feature-adaptive denormalization.

    Instance-normalizes ``x`` and modulates it with a per-pixel scale and shift,
    each predicted from ``feature`` by its own 3x3 conv:
    ``alpha_conv(feature) * norm(x) + beta_conv(feature)``.

    Args:
        use_spectral: wrap the alpha/beta convs in spectral normalization.
        features_channels: channel count of the conditioning ``feature`` map.
        in_channels: channel count of ``x`` (and of the predicted scale/shift).
        affine, track_running_stats: accepted but currently ignored; the
            ``InstanceNorm2d`` is built with its library defaults.
    """

    def __init__(self, use_spectral, features_channels, in_channels, affine=False, track_running_stats=True):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features=in_channels)
        # kernel 3 / padding 1 / stride 1 keeps the spatial size of `feature`.
        modulation_kwargs = {"kernel_size": 3, "padding": 1, "padding_mode": "zeros"}
        self.alpha_conv = conv_block(use_spectral, features_channels, in_channels, **modulation_kwargs)
        self.beta_conv = conv_block(use_spectral, features_channels, in_channels, **modulation_kwargs)

    def forward(self, x, feature):
        # `feature` must be spatially compatible with `x`: the results are
        # combined elementwise.
        scale = self.alpha_conv(feature)
        shift = self.beta_conv(feature)
        normalized = self.norm(x)
        # NOTE(review): scales by `alpha` directly rather than `1 + alpha`; confirm intended.
        return scale * normalized + shift
class FADEResBlock(nn.Module):
    """Residual block conditioned on a feature map through FADE layers, then upsampled 2x.

    Main path: FADE -> LeakyReLU -> conv3x3(in->in) -> FADE -> LeakyReLU -> conv3x3(in->out)
    Skip path: FADE -> LeakyReLU -> conv1x1(in->out)
    The two paths are summed and upsampled by 2 with nearest-neighbour interpolation.

    Args:
        use_spectral: wrap convs in spectral normalization.
        features_channels: channel count of the conditioning ``feature`` map.
        in_channels: channel count of the input ``x``.
        out_channels: channel count of the output.
    """

    def __init__(self, use_spectral, features_channels, in_channels, out_channels):
        super().__init__()
        self.main = nn.Sequential(
            FADE(use_spectral, features_channels, in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            conv_block(use_spectral, in_channels, in_channels, kernel_size=3, padding=1),
            FADE(use_spectral, features_channels, in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            conv_block(use_spectral, in_channels, out_channels, kernel_size=3, padding=1),
        )
        self.skip = nn.Sequential(
            FADE(use_spectral, features_channels, in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            conv_block(use_spectral, in_channels, out_channels, kernel_size=1, padding=0),
        )
        self.up_sample = Interpolation(2, mode="nearest")

    @staticmethod
    def forward_with_fade(module, x, feature):
        """Run ``module``'s layers in order, passing ``feature`` only to FADE layers."""
        for layer in module:
            x = layer(x, feature) if isinstance(layer, FADE) else layer(x)
        return x

    def forward(self, x, feature):
        # Both paths start with FADE, which returns a new tensor before any
        # in-place LeakyReLU runs, so feeding the same `x` to both is safe.
        # Fix: the shortcut previously re-ran `self.main`, leaving `self.skip`
        # unused and doubling the main path.
        out = self.forward_with_fade(self.main, x, feature) + self.forward_with_fade(self.skip, x, feature)
        return self.up_sample(out)
def conv_block(use_spectral, in_channels, out_channels, **kwargs):
    """Create an ``nn.Conv2d``, optionally wrapped in spectral normalization.

    Args:
        use_spectral: when truthy, apply ``nn.utils.spectral_norm`` to the conv.
        in_channels: input channel count.
        out_channels: output channel count.
        **kwargs: forwarded unchanged to ``nn.Conv2d``.
    """
    layer = nn.Conv2d(in_channels, out_channels, **kwargs)
    if not use_spectral:
        return layer
    return nn.utils.spectral_norm(layer)
@MODEL.register_module("TSIT-Generator")
class TSITGenerator(nn.Module):
    """Generator with a content-encoder stream feeding a FADE-conditioned decoder.

    The content image is projected to ``base_channels`` and then passed through
    ``num_blocks`` encoder stages (0.5x nearest downsample + ``ResBlock``). Encoder
    stage ``i`` (1-based) outputs ``min(2**i, 16) * base_channels`` channels.
    Decoding starts from standard-normal noise shaped like the deepest content
    feature. Each decoder stage (``FADEResBlock``) is conditioned on the content
    feature of matching depth, deepest first. A 7x7 conv + tanh produces the output.

    Args:
        num_blocks: number of encoder stages (and of decoder stages).
        base_channels: channel width right after the input layer.
        content_in_channels: channels of the content image.
        style_in_channels: accepted but not used by this class.
        out_channels: channels of the generated image.
        use_spectral: wrap convs in spectral normalization.
        input_layer_type: ``"conv1x1"`` or ``"conv7x7"``.

    NOTE(review): presumably the input height/width should be divisible by
    ``2 ** num_blocks`` so decoder and encoder resolutions line up. Confirm.
    """

    def __init__(self, num_blocks=7, base_channels=64, content_in_channels=3, style_in_channels=3,
                 out_channels=3, use_spectral=True, input_layer_type="conv1x1"):
        super().__init__()
        self.num_blocks = num_blocks
        self.base_channels = base_channels
        self.use_spectral = use_spectral
        self.content_input_layer = self.build_input_layer(content_in_channels, base_channels, input_layer_type)
        self.content_stream = self.build_stream()
        self.generator = self.build_generator()
        self.end_conv = nn.Sequential(
            conv_block(use_spectral, base_channels, out_channels, kernel_size=7, padding=3, padding_mode="zeros"),
            nn.Tanh(),
        )

    def build_generator(self):
        """Build the decoder: ``num_blocks`` FADEResBlocks, widest (deepest) first."""
        stream_sequence = []
        multiple_now = min(2 ** self.num_blocks, 2 ** 4)
        for i in range(1, self.num_blocks + 1):
            m = self.num_blocks - i
            multiple_prev = multiple_now
            multiple_now = min(2 ** m, 2 ** 4)
            # The content feature used at this stage has multiple_prev * base_channels
            # channels (see build_stream), so it doubles as features_channels.
            stream_sequence.append(
                FADEResBlock(self.use_spectral, multiple_prev * self.base_channels,
                             multiple_prev * self.base_channels, multiple_now * self.base_channels))
        return nn.ModuleList(stream_sequence)

    def build_input_layer(self, in_channels, out_channels, input_layer_type="conv7x7"):
        """Build the layer projecting the content image to ``out_channels``.

        Raises:
            NotImplementedError: if ``input_layer_type`` is not recognised.
        """
        if input_layer_type == "conv7x7":
            return nn.Sequential(
                conv_block(self.use_spectral, in_channels, out_channels, kernel_size=7, stride=1,
                           padding_mode="zeros", padding=3, bias=True),
                select_norm_layer("IN")(out_channels),
                nn.ReLU(inplace=True),
            )
        if input_layer_type == "conv1x1":
            return conv_block(self.use_spectral, in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        # Previously `raise NotImplemented`, which raised a TypeError because
        # NotImplemented is not an exception.
        raise NotImplementedError(f"unsupported input_layer_type: {input_layer_type!r}")

    def build_stream(self):
        """Build the content encoder: each stage halves resolution, then applies a ResBlock."""
        multiple_now = 1
        stream_sequence = []
        for i in range(1, self.num_blocks + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 4)
            stream_sequence.append(nn.Sequential(
                Interpolation(scale_factor=0.5, mode="nearest"),
                ResBlock(multiple_prev * self.base_channels, multiple_now * self.base_channels,
                         use_spectral=self.use_spectral),
            ))
        return nn.ModuleList(stream_sequence)

    def forward(self, content_img):
        c = self.content_input_layer(content_img)
        content_features = []
        for stage in self.content_stream:
            c = stage(c)
            content_features.append(c)
        # randn_like matches the shape, device AND dtype of the deepest feature.
        # torch.randn(size=..., device=...) always used the default dtype, which
        # breaks models run in half precision.
        z = torch.randn_like(content_features[-1])
        # Decoder stage k is conditioned on the k-th deepest content feature.
        for layer, feature in zip(self.generator, reversed(content_features)):
            z = layer(z, feature)
        return self.end_conv(z)