import torch
import torch.nn as nn
import torch.nn.functional as F

from model import MODEL
from model.base.module import ResidualBlock, Conv2dBlock


class Interpolation(nn.Module):
    """nn.Module wrapper around ``F.interpolate`` so that resampling can be
    placed inside an ``nn.Sequential``.

    Args:
        scale_factor: multiplier for spatial size (e.g. 0.5 halves, 2 doubles).
        mode: interpolation algorithm passed through to ``F.interpolate``.
        align_corners: passed through unchanged; must stay ``None`` for
            'nearest' mode (``F.interpolate`` rejects it otherwise).
    """

    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        # recompute_scale_factor=False derives the output size directly from
        # the configured scale_factor instead of re-deriving the factor from
        # a computed output size.
        return F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
            recompute_scale_factor=False,
        )

    def __repr__(self):
        return f"Interpolation(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"


@MODEL.register_module("TSIT-Generator")
class Generator(nn.Module):
    """TSIT-style generator: a downsampling content stream feeds multi-scale
    features into an upsampling synthesis stream via FADE normalization.

    The content image ``x`` is encoded through ``num_blocks`` downsampling
    residual blocks; each intermediate feature map conditions the matching
    resolution of the upsampling stream (channel counts are mirrored, capped
    at ``base_channels * 16``). Synthesis starts from noise ``z`` shaped like
    the deepest content feature.

    Args:
        in_channels: channels of the input content image.
        out_channels: channels of the generated image.
        base_channels: channel width of the first/last stage (default 64).
        num_blocks: number of down/up stages (default 7; 256 -> 2 spatially).
        padding_mode: convolution padding mode for both streams.
        activation_type: activation used inside the conv/residual blocks.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=7,
                 padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        self.input_layer = Conv2dBlock(
            in_channels,
            base_channels,
            kernel_size=7,
            stride=1,
            padding=3,
            padding_mode=padding_mode,
            activation_type=activation_type,
            norm_type="IN",
        )

        # Content (downsampling) stream: halve resolution, grow channels
        # (1, 2, 4, 8, 16, 16, 16) x base_channels, capped at 2**4.
        multiple_now = 1
        stream_sequence = []
        for i in range(1, num_blocks + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 4)
            stream_sequence.append(nn.Sequential(
                Interpolation(scale_factor=0.5, mode="nearest"),
                ResidualBlock(
                    multiple_prev * base_channels,
                    out_channels=multiple_now * base_channels,
                    padding_mode=padding_mode,
                    activation_type=activation_type,
                    norm_type="IN"),
            ))
        # ModuleList (not Sequential): forward() needs each stage's output
        # individually to collect the multi-scale content features.
        self.down_sequence = nn.ModuleList(stream_sequence)

        # Synthesis (upsampling) stream: mirror of the down stream. Each
        # stage is FADE-normalized ResidualBlock, then nearest-neighbor 2x
        # upsample. condition_in_channels equals the channel count of the
        # content feature popped at that stage (verified: both follow the
        # same capped 2**i progression in reverse).
        sequence = []
        multiple_now = 16
        for i in range(num_blocks - 1, -1, -1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 4)
            sequence.append(nn.Sequential(
                ResidualBlock(
                    base_channels * multiple_prev,
                    out_channels=base_channels * multiple_now,
                    padding_mode=padding_mode,
                    activation_type=activation_type,
                    norm_type="FADE",
                    pre_activation=True,
                    additional_norm_kwargs=dict(
                        condition_in_channels=base_channels * multiple_prev,
                        base_norm_type="BN",
                        padding_mode="zeros",
                        gamma_bias=0.0,
                    )
                ),
                Interpolation(scale_factor=2, mode="nearest"),
            ))
        self.up_sequence = nn.Sequential(*sequence)

        self.output_layer = Conv2dBlock(
            base_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode=padding_mode,
            activation_type="Tanh",
            norm_type="NONE",
        )

    def forward(self, x, z=None):
        """Generate an image conditioned on content image ``x``.

        Args:
            x: content image, [batch, in_channels, H, W].
            z: optional noise tensor shaped like the deepest content feature;
               sampled from a standard normal when omitted.

        Returns:
            Generated image, [batch, out_channels, H, W] (Tanh output).
        """
        c = self.input_layer(x)

        # Collect one content feature per downsampling stage.
        contents = []
        for blk in self.down_sequence:
            c = blk(c)
            contents.append(c)

        if z is None:
            # for image 256x256, z size: [batch_size, 1024, 2, 2]
            z = torch.randn(size=contents[-1].size(), device=contents[-1].device)

        # Wire each up-stage's FADE layers to the matching content feature
        # (deepest first, hence pop from the end). set_feature stores the
        # conditioning tensor on the normalization module so the subsequent
        # up_sequence(z) call can consume it.
        # NOTE(review): set_feature / learn_skip_connection semantics come
        # from ResidualBlock in model.base.module — confirm there.
        for blk in self.up_sequence:
            res = blk[0]
            c = contents.pop()
            res.conv1.normalization.set_feature(c)
            res.conv2.normalization.set_feature(c)
            if res.learn_skip_connection:
                res.res_conv.normalization.set_feature(c)

        return self.output_layer(self.up_sequence(z))


if __name__ == '__main__':
    g = Generator(3, 3).cuda()
    img = torch.randn(2, 3, 256, 256).cuda()
    print(g(img).size())