89 lines
3.8 KiB
Python
89 lines
3.8 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from model import MODEL
|
|
from model.base.module import Conv2dBlock, ResidualBlock, ReverseResidualBlock
|
|
|
|
|
|
class Interpolation(nn.Module):
    """Module wrapper around ``torch.nn.functional.interpolate``.

    Lets a fixed resampling step (up- or down-sampling, depending on
    ``scale_factor``) be placed inside containers such as ``nn.Sequential``.

    Args:
        scale_factor: Multiplier for the spatial size, forwarded unchanged to
            ``F.interpolate``.
        mode: Interpolation algorithm name, forwarded unchanged.
        align_corners: Forwarded unchanged to ``F.interpolate``.
    """

    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resample ``x`` with the configured scale factor and mode."""
        return F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
            recompute_scale_factor=False,
        )

    def __repr__(self):
        # Report the real class name: this module is used for both up- and
        # down-sampling, so the previous hard-coded "DownSampling" label was
        # misleading in printed model summaries.
        return (
            f"{type(self).__name__}(scale_factor={self.scale_factor}, "
            f"mode={self.mode}, align_corners={self.align_corners})"
        )
|
|
|
|
|
|
@MODEL.register_module("TSIT-Generator")
class Generator(nn.Module):
    """Generator registered under the name ``"TSIT-Generator"``.

    The network has two parts:

    * a *content stream* that repeatedly halves the resolution of the content
      image and records the feature map produced at every level, and
    * a *generator* that starts from a noise tensor ``z`` and repeatedly
      doubles its resolution, with each residual block receiving the content
      feature of the matching level through ``set_feature`` on its
      normalization layers.

    Channel widths are ``base_channels * min(2 ** level, 16)``.

    Args:
        content_in_channels: Number of channels of the content image.
        out_channels: Number of channels of the generated image.
        base_channels: Channel width at the highest resolution.
        num_blocks: Number of down-sampling (and up-sampling) stages.
        padding_mode: Padding mode forwarded to every convolutional block.
        activation_type: Activation name forwarded to the content stream and
            the first convolution.
    """

    def __init__(self, content_in_channels=3, out_channels=3, base_channels=64, num_blocks=7,
                 padding_mode="reflect", activation_type="ReLU"):
        super().__init__()
        self.num_blocks = num_blocks
        self.base_channels = base_channels

        # NOTE: submodules are created in the same order as before so that
        # parameter registration / initialisation order is unchanged.
        self.content_stream = self.build_stream(padding_mode, activation_type)
        self.start_conv = Conv2dBlock(
            content_in_channels, base_channels, activation_type=activation_type,
            norm_type="IN", kernel_size=7, padding_mode=padding_mode, padding=3,
        )

        stages = []
        # Width multiplier of the deepest level, capped at 2 ** 4.
        multiple_now = min(2 ** self.num_blocks, 2 ** 4)
        for level in range(self.num_blocks - 1, -1, -1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** level, 2 ** 4)
            stages.append(nn.Sequential(
                ReverseResidualBlock(
                    multiple_prev * base_channels, multiple_now * base_channels,
                    padding_mode=padding_mode, norm_type="FADE",
                    additional_norm_kwargs=dict(
                        # The conditioning feature has as many channels as the
                        # block's input (see the content stream widths).
                        condition_in_channels=multiple_prev * base_channels,
                        base_norm_type="BN",
                        padding_mode=padding_mode,
                    ),
                ),
                Interpolation(2, mode="nearest"),
            ))
        self.generator = nn.Sequential(*stages)
        self.end_conv = Conv2dBlock(
            base_channels, out_channels, activation_type="Tanh",
            kernel_size=7, padding_mode=padding_mode, padding=3,
        )

    def build_stream(self, padding_mode, activation_type):
        """Build the down-sampling content stream.

        Returns:
            ``nn.ModuleList`` with ``self.num_blocks`` stages; each stage
            halves the resolution (nearest interpolation with scale 0.5) and
            then applies a ``ResidualBlock`` with instance normalization.
        """
        multiple_now = 1
        stream_sequence = []
        for level in range(1, self.num_blocks + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** level, 2 ** 4)
            stream_sequence.append(nn.Sequential(
                Interpolation(scale_factor=0.5, mode="nearest"),
                ResidualBlock(
                    multiple_prev * self.base_channels, multiple_now * self.base_channels,
                    padding_mode=padding_mode, activation_type=activation_type, norm_type="IN",
                ),
            ))
        return nn.ModuleList(stream_sequence)

    def forward(self, content, z=None):
        """Generate an image conditioned on ``content``.

        Args:
            content: Content image tensor fed to the content stream.
            z: Optional input for the generator. When omitted, standard-normal
                noise with the shape, dtype and device of the deepest content
                feature is drawn.

        Returns:
            The output of the final convolution block (``Tanh`` activation).
        """
        c = self.start_conv(content)
        content_features = []
        for stage in self.content_stream:
            c = stage(c)
            content_features.append(c)

        if z is None:
            deepest = content_features[-1]
            # Match dtype as well as device so the noise is compatible with
            # the features when the model runs in half/double precision.
            z = torch.randn(size=deepest.size(), dtype=deepest.dtype, device=deepest.device)

        # Generator stage i is conditioned on the content feature of the
        # matching level, i.e. the features are consumed deepest-first.
        for stage, feature in zip(self.generator, reversed(content_features)):
            res_block = stage[0]
            res_block.conv1.normalization.set_feature(feature)
            res_block.conv2.normalization.set_feature(feature)
            if res_block.learn_skip_connection:
                res_block.res_conv.normalization.set_feature(feature)
        return self.end_conv(self.generator(z))
|