from collections import OrderedDict
from functools import partial
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.base.module import ResidualBlock, Conv2dBlock, LinearBlock
from model import MODEL


class StyleEncoder(nn.Module):
    """Convolutional encoder that maps an image to two ``style_dim`` vectors.

    The network is a stride-1 stem convolution followed by ``num_conv``
    stride-2 convolutions whose width doubles each step and is capped at
    ``2 ** 3`` times ``base_channels``. The flattened feature map is fed to
    two independent linear heads, ``fc_avg`` and ``fc_var``.

    Args:
        in_channels: number of channels of the input image.
        style_dim: size of each of the two output vectors.
        num_conv: number of stride-2 convolutions after the stem.
        end_size: ``(height, width)`` of the feature map entering the linear
            heads. The input resolution must be chosen by the caller so that
            the stride-2 stack actually produces this size.
        base_channels: number of channels produced by the stem convolution.
        norm_type: forwarded to every ``Conv2dBlock``.
        padding_mode: forwarded to every ``Conv2dBlock``.
        activation_type: forwarded to every ``Conv2dBlock``.
    """

    def __init__(self, in_channels, style_dim, num_conv, end_size=(4, 4), base_channels=64,
                 norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        sequence = [Conv2dBlock(
            in_channels, base_channels, kernel_size=3, stride=1, padding=1, padding_mode=padding_mode,
            activation_type=activation_type, norm_type=norm_type
        )]
        # The stem emits `base_channels` channels, so the running width
        # multiplier starts at 1. (Starting at 0 made the first strided conv
        # expect zero input channels.)
        multiple_now = 1
        max_multiple = 3
        for i in range(1, num_conv + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** max_multiple)
            sequence.append(Conv2dBlock(
                multiple_prev * base_channels, multiple_now * base_channels,
                kernel_size=3, stride=2, padding=1, padding_mode=padding_mode,
                activation_type=activation_type, norm_type=norm_type
            ))
        self.sequence = nn.Sequential(*sequence)
        # Size the heads from the width actually reached. This equals
        # base_channels * 2 ** max_multiple whenever num_conv >= max_multiple.
        flat_features = base_channels * multiple_now * end_size[0] * end_size[1]
        self.fc_avg = nn.Linear(flat_features, style_dim)
        self.fc_var = nn.Linear(flat_features, style_dim)

    def forward(self, x):
        """Encode ``x`` and return the tuple ``(fc_avg(features), fc_var(features))``.

        NOTE(review): the head names suggest a mean and a (log-)variance for a
        VAE-style latent; confirm against the training code.
        """
        x = self.sequence(x)
        # Flatten everything but the batch dimension for the linear heads.
        x = x.view(x.size(0), -1)
        return self.fc_avg(x), self.fc_var(x)


class ImprovedSPADEGenerator(nn.Module):
    """Multi-resolution SPADE-style generator (construction only).

    ``__init__`` builds three groups of sub-modules:

    * ``style_converter`` (only when ``have_style_input``): two
      ``LinearBlock`` layers mapping ``style_dim`` to ``2 * style_dim``.
    * ``sequence``: a head block followed by residual/upsample blocks whose
      widths halve from ``16 * base_channels``.
    * ``out_modules``: an ``out_1`` image head plus, for output sizes above
      256, extra residual/upsample blocks each followed by an ``out_k`` head.

    NOTE: ``forward`` is not implemented in this file and returns ``None``.

    Args:
        in_channels: channels of the conditioning input; also used as
            ``condition_in_channels`` of the SPADE normalization.
        out_channels: channels produced by every ``out_*`` head.
        output_size: one of 128, 256, 512, 1024.
        have_style_input: when true, build ``style_converter`` and use
            ``"AdaIN"`` normalization in the plain convolutions; otherwise
            those convolutions use ``"NONE"``.
        style_dim: dimensionality of the style vector.
        start_size: accepted for interface compatibility; not used here.
        base_channels: base width multiplier of the network.
        padding_mode: forwarded to the convolution / residual blocks.
        activation_type: forwarded to the convolution / residual blocks.
        pre_activation: forwarded to the plain ``Conv2dBlock`` layers.
    """

    def __init__(self, in_channels, out_channels, output_size, have_style_input, style_dim, start_size=(4, 4),
                 base_channels=64, padding_mode='reflect', activation_type="LeakyReLU", pre_activation=False):
        super().__init__()

        assert output_size in (128, 256, 512, 1024)
        self.output_size = output_size
        # log2 is exact for the powers of two accepted above, unlike
        # math.log(x, 2), which goes through a floating-point division.
        log_size = int(math.log2(self.output_size))

        kernel_size = 3

        if have_style_input:
            self.style_converter = nn.Sequential(
                LinearBlock(style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
                LinearBlock(2 * style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
            )

        # Factory for the plain (non-residual) convolutions.
        base_conv = partial(
            Conv2dBlock,
            pre_activation=pre_activation,
            activation_type=activation_type,
            norm_type="AdaIN" if have_style_input else "NONE",
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2,
            padding_mode=padding_mode
        )

        # Factory for the SPADE-normalized residual blocks.
        base_residual_block = partial(
            ResidualBlock,
            padding_mode=padding_mode,
            activation_type=activation_type,
            norm_type="SPADE",
            pre_activation=True,
            additional_norm_kwargs=dict(
                condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
                activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
            )
        )

        sequence = OrderedDict()
        channels = (2 ** 4) * base_channels
        sequence["block_head"] = nn.Sequential(OrderedDict([
            ("conv_input", base_conv(in_channels=in_channels, out_channels=channels)),
            ("conv_style", base_conv(in_channels=channels, out_channels=channels)),
            ("res_a", base_residual_block(in_channels=channels, out_channels=channels)),
            ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
            ("up", nn.Upsample(scale_factor=2, mode='nearest'))
        ]))

        # Two blocks (i = 4, 3) for output size 128; three blocks (i = 4, 3, 2)
        # for 256 and above, because the log is clamped to 8.
        for i in range(4, 9 - min(log_size, 8), -1):
            channels = (2 ** (i - 1)) * base_channels
            sequence[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
                ("res_a", base_residual_block(in_channels=channels * 2, out_channels=channels)),
                ("conv_style", base_conv(in_channels=channels, out_channels=channels)),
                ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
                ("up", nn.Upsample(scale_factor=2, mode='nearest'))
            ]))
        self.sequence = nn.Sequential(sequence)
        # channels = 2*base_channels when output size is 256, 512, 1024
        # channels = 4*base_channels when output size is 128
        out_modules = OrderedDict()
        out_modules["out_1"] = nn.Sequential(
            Conv2dBlock(
                channels,
                out_channels,
                kernel_size=5,
                stride=1,
                padding=2,
                pre_activation=pre_activation,
                padding_mode=padding_mode,
                activation_type=activation_type,
                norm_type="NONE"
            ),
            nn.Tanh()
        )
        # One extra block + head for 512, two for 1024, none for 128 / 256
        # (the range is empty when log_size <= 8).
        for i in range(log_size - 8):
            channels = channels // 2
            out_modules[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
                ("res_a", base_residual_block(in_channels=2 * channels, out_channels=channels)),
                ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
                ("up", nn.Upsample(scale_factor=2, mode='nearest'))
            ]))
            out_modules[f"out_{i + 2}"] = nn.Sequential(
                Conv2dBlock(
                    channels,
                    out_channels,
                    kernel_size=5,
                    stride=1,
                    padding=2,
                    pre_activation=pre_activation,
                    padding_mode=padding_mode,
                    activation_type=activation_type,
                    norm_type="NONE"
                ),
                nn.Tanh()
            )
        self.out_modules = nn.ModuleDict(out_modules)

    def forward(self, seg, style=None):
        """Placeholder forward pass; currently does nothing and returns ``None``.

        TODO: implement the forward computation for this generator.
        """
        pass


@MODEL.register_module()
class SPADEGenerator(nn.Module):
    """Generator built from SPADE-normalized residual blocks.

    The starting feature map has ``16 * base_channels`` channels and spatial
    size ``start_size``. It comes either from a linear projection of a latent
    vector (``use_vae=True``) or from a 3x3 convolution over the conditioning
    image resized to ``start_size`` (``use_vae=False``). ``num_blocks``
    residual blocks follow, with a 2x upsample between consecutive blocks, and
    a final ``Conv2dBlock`` with ``"Tanh"`` activation produces the output.

    Args:
        in_channels: channels of the conditioning image ``seg``.
        out_channels: channels of the generated output.
        num_blocks: number of residual blocks; block ``i`` (counting down to
            0) outputs ``min(2 ** i, 16) * base_channels`` channels.
        use_vae: choose the latent-vector input path instead of the
            resized-conditioning path.
        num_z_dim: dimensionality of the latent vector ``z``.
        start_size: spatial size of the first feature map.
        base_channels: base width multiplier of the network.
        padding_mode: forwarded to the residual blocks and output conv.
        activation_type: forwarded to the residual blocks.
    """

    def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
                 padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        self.sx, self.sy = start_size
        self.use_vae = use_vae
        self.num_z_dim = num_z_dim
        if use_vae:
            self.input_converter = nn.Linear(num_z_dim, 16 * base_channels * self.sx * self.sy)
        else:
            self.input_converter = nn.Conv2d(in_channels, 16 * base_channels, kernel_size=3, padding=1)

        sequence = []

        multiple_now = 16
        for i in range(num_blocks - 1, -1, -1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 4)
            # Upsample between blocks, but not before the first one.
            if i != num_blocks - 1:
                sequence.append(nn.Upsample(scale_factor=2))
            sequence.append(ResidualBlock(
                base_channels * multiple_prev,
                out_channels=base_channels * multiple_now,
                padding_mode=padding_mode,
                activation_type=activation_type,
                norm_type="SPADE",
                pre_activation=True,
                additional_norm_kwargs=dict(
                    condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
                    activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
                )
            ))
        self.sequence = nn.Sequential(*sequence)
        self.output_converter = Conv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
                                            padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE")

    def forward(self, seg, z=None):
        """Generate an output conditioned on ``seg``.

        Args:
            seg: conditioning image; it is resized to the current feature-map
                size before every residual block.
            z: optional latent of width ``num_z_dim``. Only used when
                ``use_vae`` is true; sampled from a standard normal on
                ``seg``'s device when omitted.

        Returns:
            The result of ``output_converter`` applied to the final features.
        """
        if self.use_vae:
            if z is None:
                z = torch.randn(seg.size(0), self.num_z_dim, device=seg.device)
            x = self.input_converter(z).view(seg.size(0), -1, self.sx, self.sy)
        else:
            x = self.input_converter(F.interpolate(seg, size=(self.sx, self.sy)))

        for blk in self.sequence:
            if isinstance(blk, ResidualBlock):
                # The SPADE layers need the condition at the current feature
                # resolution, so resize and register it before running the block.
                downsampling_seg = F.interpolate(seg, size=x.size()[2:], mode='nearest')
                blk.conv1.normalization.set_condition_image(downsampling_seg)
                blk.conv2.normalization.set_condition_image(downsampling_seg)
                if blk.learn_skip_connection:
                    blk.res_conv.normalization.set_condition_image(downsampling_seg)
            # Upsample layers are applied as-is.
            x = blk(x)
        return self.output_converter(x)


if __name__ == '__main__':
    # Smoke test: build a 7-block generator without VAE input and run one batch.
    g = SPADEGenerator(3, 3, 7, False, 256)
    print(g)
    print(g(torch.randn(2, 3, 256, 256)).size())