use base module rewrite TSIT
This commit is contained in:
parent
16f18ab2e2
commit
f67bcdf161
@ -3,44 +3,7 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from model import MODEL
|
from model import MODEL
|
||||||
from model.normalization import select_norm_layer
|
from model.base.module import Conv2dBlock, ResidualBlock, ReverseResidualBlock
|
||||||
|
|
||||||
|
|
||||||
class ResBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, padding_mode='zeros', norm_type="IN", use_bias=None,
|
|
||||||
use_spectral=True):
|
|
||||||
super().__init__()
|
|
||||||
self.padding_mode = padding_mode
|
|
||||||
self.use_bias = use_bias
|
|
||||||
self.use_spectral = use_spectral
|
|
||||||
if use_bias is None:
|
|
||||||
# Only for IN, use bias since it does not have affine parameters.
|
|
||||||
self.use_bias = norm_type == "IN"
|
|
||||||
norm_layer = select_norm_layer(norm_type)
|
|
||||||
self.main = nn.Sequential(
|
|
||||||
self.conv_block(in_channels, in_channels),
|
|
||||||
norm_layer(in_channels),
|
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
|
||||||
self.conv_block(in_channels, out_channels),
|
|
||||||
norm_layer(out_channels),
|
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
|
||||||
)
|
|
||||||
self.skip = nn.Sequential(
|
|
||||||
self.conv_block(in_channels, out_channels, padding=0, kernel_size=1),
|
|
||||||
norm_layer(out_channels),
|
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
def conv_block(self, in_channels, out_channels, kernel_size=3, padding=1):
|
|
||||||
conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding,
|
|
||||||
padding_mode=self.padding_mode, bias=self.use_bias)
|
|
||||||
if self.use_spectral:
|
|
||||||
return nn.utils.spectral_norm(conv)
|
|
||||||
else:
|
|
||||||
return conv
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.main(x) + self.skip(x)
|
|
||||||
|
|
||||||
|
|
||||||
class Interpolation(nn.Module):
|
class Interpolation(nn.Module):
|
||||||
@ -58,104 +21,41 @@ class Interpolation(nn.Module):
|
|||||||
return f"DownSampling(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"
|
return f"DownSampling(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"
|
||||||
|
|
||||||
|
|
||||||
class FADE(nn.Module):
|
|
||||||
def __init__(self, use_spectral, features_channels, in_channels, affine=False, track_running_stats=True):
|
|
||||||
super().__init__()
|
|
||||||
# self.norm = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
|
|
||||||
self.norm = nn.InstanceNorm2d(num_features=in_channels)
|
|
||||||
|
|
||||||
self.alpha_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
|
|
||||||
padding_mode="zeros")
|
|
||||||
self.beta_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
|
|
||||||
padding_mode="zeros")
|
|
||||||
|
|
||||||
def forward(self, x, feature):
|
|
||||||
alpha = self.alpha_conv(feature)
|
|
||||||
beta = self.beta_conv(feature)
|
|
||||||
x = self.norm(x)
|
|
||||||
return alpha * x + beta
|
|
||||||
|
|
||||||
|
|
||||||
class FADEResBlock(nn.Module):
|
|
||||||
def __init__(self, use_spectral, features_channels, in_channels, out_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.main = nn.Sequential(
|
|
||||||
FADE(use_spectral, features_channels, in_channels),
|
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
|
||||||
conv_block(use_spectral, in_channels, in_channels, kernel_size=3, padding=1),
|
|
||||||
FADE(use_spectral, features_channels, in_channels),
|
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
|
||||||
conv_block(use_spectral, in_channels, out_channels, kernel_size=3, padding=1),
|
|
||||||
)
|
|
||||||
self.skip = nn.Sequential(
|
|
||||||
FADE(use_spectral, features_channels, in_channels),
|
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
|
||||||
conv_block(use_spectral, in_channels, out_channels, kernel_size=1, padding=0),
|
|
||||||
)
|
|
||||||
self.up_sample = Interpolation(2, mode="nearest")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def forward_with_fade(module, x, feature):
|
|
||||||
for layer in module:
|
|
||||||
if layer.__class__.__name__ == "FADE":
|
|
||||||
x = layer(x, feature)
|
|
||||||
else:
|
|
||||||
x = layer(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward(self, x, feature):
|
|
||||||
out = self.forward_with_fade(self.main, x, feature) + self.forward_with_fade(self.main, x, feature)
|
|
||||||
return self.up_sample(out)
|
|
||||||
|
|
||||||
|
|
||||||
def conv_block(use_spectral, in_channels, out_channels, **kwargs):
|
|
||||||
conv = nn.Conv2d(in_channels, out_channels, **kwargs)
|
|
||||||
return nn.utils.spectral_norm(conv) if use_spectral else conv
|
|
||||||
|
|
||||||
|
|
||||||
@MODEL.register_module("TSIT-Generator")
|
@MODEL.register_module("TSIT-Generator")
|
||||||
class TSITGenerator(nn.Module):
|
class Generator(nn.Module):
|
||||||
def __init__(self, num_blocks=7, base_channels=64, content_in_channels=3, style_in_channels=3,
|
def __init__(self, content_in_channels=3, out_channels=3, base_channels=64, num_blocks=7,
|
||||||
out_channels=3, use_spectral=True, input_layer_type="conv1x1"):
|
padding_mode="reflect", activation_type="ReLU"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_blocks = num_blocks
|
self.num_blocks = num_blocks
|
||||||
self.base_channels = base_channels
|
self.base_channels = base_channels
|
||||||
self.use_spectral = use_spectral
|
|
||||||
|
|
||||||
self.content_input_layer = self.build_input_layer(content_in_channels, base_channels, input_layer_type)
|
self.content_stream = self.build_stream(padding_mode, activation_type)
|
||||||
self.content_stream = self.build_stream()
|
self.start_conv = Conv2dBlock(content_in_channels, base_channels, activation_type=activation_type,
|
||||||
self.generator = self.build_generator()
|
norm_type="IN", kernel_size=7, padding_mode=padding_mode, padding=3)
|
||||||
self.end_conv = nn.Sequential(
|
|
||||||
conv_block(use_spectral, base_channels, out_channels, kernel_size=7, padding=3, padding_mode="zeros"),
|
|
||||||
nn.Tanh()
|
|
||||||
)
|
|
||||||
|
|
||||||
def build_generator(self):
|
sequence = []
|
||||||
stream_sequence = []
|
|
||||||
multiple_now = min(2 ** self.num_blocks, 2 ** 4)
|
multiple_now = min(2 ** self.num_blocks, 2 ** 4)
|
||||||
for i in range(1, self.num_blocks + 1):
|
for i in range(1, self.num_blocks + 1):
|
||||||
m = self.num_blocks - i
|
m = self.num_blocks - i
|
||||||
multiple_prev = multiple_now
|
multiple_prev = multiple_now
|
||||||
multiple_now = min(2 ** m, 2 ** 4)
|
multiple_now = min(2 ** m, 2 ** 4)
|
||||||
stream_sequence.append(
|
sequence.append(nn.Sequential(
|
||||||
FADEResBlock(self.use_spectral, multiple_prev * self.base_channels, multiple_prev * self.base_channels,
|
ReverseResidualBlock(
|
||||||
multiple_now * self.base_channels))
|
multiple_prev * base_channels, multiple_now * base_channels,
|
||||||
return nn.ModuleList(stream_sequence)
|
padding_mode=padding_mode, norm_type="FADE",
|
||||||
|
additional_norm_kwargs=dict(
|
||||||
|
condition_in_channels=multiple_prev * base_channels,
|
||||||
|
base_norm_type="BN",
|
||||||
|
padding_mode=padding_mode
|
||||||
|
)
|
||||||
|
),
|
||||||
|
Interpolation(2, mode="nearest")
|
||||||
|
))
|
||||||
|
self.generator = nn.Sequential(*sequence)
|
||||||
|
self.end_conv = Conv2dBlock(base_channels, out_channels, activation_type="Tanh",
|
||||||
|
kernel_size=7, padding_mode=padding_mode, padding=3)
|
||||||
|
|
||||||
def build_input_layer(self, in_channels, out_channels, input_layer_type="conv7x7"):
|
def build_stream(self, padding_mode, activation_type):
|
||||||
if input_layer_type == "conv7x7":
|
|
||||||
return nn.Sequential(
|
|
||||||
conv_block(self.use_spectral, in_channels, out_channels, kernel_size=7, stride=1,
|
|
||||||
padding_mode="zeros", padding=3, bias=True),
|
|
||||||
select_norm_layer("IN")(out_channels),
|
|
||||||
nn.ReLU(inplace=True)
|
|
||||||
)
|
|
||||||
elif input_layer_type == "conv1x1":
|
|
||||||
return conv_block(self.use_spectral, in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
|
||||||
else:
|
|
||||||
raise NotImplemented
|
|
||||||
|
|
||||||
def build_stream(self):
|
|
||||||
multiple_now = 1
|
multiple_now = 1
|
||||||
stream_sequence = []
|
stream_sequence = []
|
||||||
for i in range(1, self.num_blocks + 1):
|
for i in range(1, self.num_blocks + 1):
|
||||||
@ -163,21 +63,26 @@ class TSITGenerator(nn.Module):
|
|||||||
multiple_now = min(2 ** i, 2 ** 4)
|
multiple_now = min(2 ** i, 2 ** 4)
|
||||||
stream_sequence.append(nn.Sequential(
|
stream_sequence.append(nn.Sequential(
|
||||||
Interpolation(scale_factor=0.5, mode="nearest"),
|
Interpolation(scale_factor=0.5, mode="nearest"),
|
||||||
ResBlock(multiple_prev * self.base_channels, multiple_now * self.base_channels,
|
ResidualBlock(
|
||||||
use_spectral=self.use_spectral)
|
multiple_prev * self.base_channels, multiple_now * self.base_channels,
|
||||||
|
padding_mode=padding_mode, activation_type=activation_type, norm_type="IN")
|
||||||
))
|
))
|
||||||
return nn.ModuleList(stream_sequence)
|
return nn.ModuleList(stream_sequence)
|
||||||
|
|
||||||
def forward(self, content_img):
|
def forward(self, content, z=None):
|
||||||
c = self.content_input_layer(content_img)
|
c = self.start_conv(content)
|
||||||
content_features = []
|
content_features = []
|
||||||
for i in range(self.num_blocks):
|
for i in range(self.num_blocks):
|
||||||
c = self.content_stream[i](c)
|
c = self.content_stream[i](c)
|
||||||
content_features.append(c)
|
content_features.append(c)
|
||||||
z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
|
if z is None:
|
||||||
|
z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
|
||||||
|
|
||||||
for i in range(self.num_blocks):
|
for i in range(self.num_blocks):
|
||||||
m = - i - 1
|
m = - i - 1
|
||||||
layer = self.generator[i]
|
res_block = self.generator[i][0]
|
||||||
z = layer(z, content_features[m])
|
res_block.conv1.normalization.set_feature(content_features[m])
|
||||||
return self.end_conv(z)
|
res_block.conv2.normalization.set_feature(content_features[m])
|
||||||
|
if res_block.learn_skip_connection:
|
||||||
|
res_block.res_conv.normalization.set_feature(content_features[m])
|
||||||
|
return self.end_conv(self.generator(z))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user