import torch
import torch.nn as nn
import torch.nn.functional as F

from model import MODEL
from model.base.module import Conv2dBlock, ResidualBlock, ReverseResidualBlock


class Interpolation(nn.Module):
    """Thin module wrapper around F.interpolate so resampling can live inside nn.Sequential."""

    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode,
                             align_corners=self.align_corners, recompute_scale_factor=False)

    def __repr__(self):
        return (f"{self.__class__.__name__}(scale_factor={self.scale_factor}, "
                f"mode={self.mode}, align_corners={self.align_corners})")


@MODEL.register_module("TSIT-Generator")
class Generator(nn.Module):
    """TSIT-style generator: a downsampling content stream feeds FADE-normalized
    residual blocks that progressively upsample a noise map into an image."""

    def __init__(self, content_in_channels=3, out_channels=3, base_channels=64, num_blocks=7,
                 padding_mode="reflect", activation_type="ReLU"):
        super().__init__()
        self.num_blocks = num_blocks
        self.base_channels = base_channels

        # Content stream: extracts one content feature map per block, at progressively coarser scales.
        self.content_stream = self.build_stream(padding_mode, activation_type)
        self.start_conv = Conv2dBlock(content_in_channels, base_channels, activation_type=activation_type,
                                      norm_type="IN", kernel_size=7, padding_mode=padding_mode, padding=3)

        # Main generator: each stage is a FADE residual block followed by 2x nearest-neighbor upsampling.
        # Channel width shrinks from min(2 ** num_blocks, 16) * base_channels back down to base_channels.
        sequence = []
        multiple_now = min(2 ** self.num_blocks, 2 ** 4)
        for i in range(1, self.num_blocks + 1):
            m = self.num_blocks - i
            multiple_prev = multiple_now
            multiple_now = min(2 ** m, 2 ** 4)
            sequence.append(nn.Sequential(
                ReverseResidualBlock(
                    multiple_prev * base_channels, multiple_now * base_channels,
                    padding_mode=padding_mode, norm_type="FADE",
                    additional_norm_kwargs=dict(
                        condition_in_channels=multiple_prev * base_channels,
                        base_norm_type="BN",
                        padding_mode=padding_mode
                    )
                ),
                Interpolation(2, mode="nearest")
            ))
        self.generator = nn.Sequential(*sequence)
        self.end_conv = Conv2dBlock(base_channels, out_channels, activation_type="Tanh", kernel_size=7,
                                    padding_mode=padding_mode, padding=3)

    def build_stream(self, padding_mode, activation_type):
        # Each stage halves the resolution and widens the channels (capped at 16 * base_channels).
        multiple_now = 1
        stream_sequence = []
        for i in range(1, self.num_blocks + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 4)
            stream_sequence.append(nn.Sequential(
                Interpolation(scale_factor=0.5, mode="nearest"),
                ResidualBlock(
                    multiple_prev * self.base_channels, multiple_now * self.base_channels,
                    padding_mode=padding_mode, activation_type=activation_type, norm_type="IN")
            ))
        return nn.ModuleList(stream_sequence)

    def forward(self, content, z=None):
        # Run the content stream and keep every intermediate feature map.
        c = self.start_conv(content)
        content_features = []
        for i in range(self.num_blocks):
            c = self.content_stream[i](c)
            content_features.append(c)

        # Sample the latent noise map at the coarsest content resolution if none is given.
        if z is None:
            z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)

        # Hand each generator stage the content feature of matching scale via its FADE layers
        # (coarsest feature conditions the first stage, finest the last).
        for i in range(self.num_blocks):
            m = -i - 1
            res_block = self.generator[i][0]
            res_block.conv1.normalization.set_feature(content_features[m])
            res_block.conv2.normalization.set_feature(content_features[m])
            if res_block.learn_skip_connection:
                res_block.res_conv.normalization.set_feature(content_features[m])

        return self.end_conv(self.generator(z))
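

# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal smoke test, assuming the `model` package above is importable and that
# Conv2dBlock / ResidualBlock / ReverseResidualBlock accept the arguments used in
# this file. Input height and width must be divisible by 2 ** num_blocks (128 for
# the default num_blocks=7) so the content stream can downsample cleanly and the
# generator can upsample back to the input resolution.
if __name__ == "__main__":
    generator = Generator(content_in_channels=3, out_channels=3, base_channels=64, num_blocks=7)
    content = torch.randn(1, 3, 256, 256)  # dummy content image batch
    with torch.no_grad():
        fake = generator(content)  # z is sampled internally when omitted
    # Should print torch.Size([1, 3, 256, 256]) if the blocks preserve spatial size.
    print(fake.shape)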