from collections import OrderedDict
from functools import partial

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from model.base.module import ResidualBlock, Conv2dBlock, LinearBlock
from model import MODEL


class StyleEncoder(nn.Module):
    def __init__(self, in_channels, style_dim, num_conv, end_size=(4, 4), base_channels=64,
                 norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        # Stem: keep the spatial size, lift the input to base_channels.
        sequence = [Conv2dBlock(
            in_channels, base_channels, kernel_size=3, stride=1, padding=1, padding_mode=padding_mode,
            activation_type=activation_type, norm_type=norm_type
        )]
        # The stem already outputs 1 * base_channels, so the channel multiple starts
        # at 1 (starting at 0 would give the first downsampling block zero input channels).
        multiple_now = 1
        max_multiple = 3
        # Each block halves the spatial resolution and widens the channels up to 8 * base_channels.
        for i in range(1, num_conv + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** max_multiple)
            sequence.append(Conv2dBlock(
                multiple_prev * base_channels, multiple_now * base_channels,
                kernel_size=3, stride=2, padding=1, padding_mode=padding_mode,
                activation_type=activation_type, norm_type=norm_type
            ))
        self.sequence = nn.Sequential(*sequence)
        # VAE-style heads: mean and (log-)variance of the style code, computed from
        # the flattened end_size feature map.
        self.fc_avg = nn.Linear(base_channels * (2 ** max_multiple) * end_size[0] * end_size[1], style_dim)
        self.fc_var = nn.Linear(base_channels * (2 ** max_multiple) * end_size[0] * end_size[1], style_dim)

    def forward(self, x):
        x = self.sequence(x)
        x = x.view(x.size(0), -1)
        return self.fc_avg(x), self.fc_var(x)
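

# A minimal usage sketch (an assumption, not part of this file's call graph): the
# mean / variance pair produced by StyleEncoder would typically be turned into a
# sampled style code with the reparameterization trick, treating the second output
# as a log-variance, e.g.
#
#     mu, logvar = style_encoder(image)
#     style = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)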


class ImprovedSPADEGenerator(nn.Module):
    def __init__(self, in_channels, out_channels, output_size, have_style_input, style_dim, start_size=(4, 4),
                 base_channels=64, padding_mode='reflect', activation_type="LeakyReLU", pre_activation=False):
        super().__init__()

        assert output_size in (128, 256, 512, 1024)
        self.output_size = output_size

        kernel_size = 3

        if have_style_input:
            # Maps the style code to the statistics consumed by the AdaIN layers below.
            self.style_converter = nn.Sequential(
                LinearBlock(style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
                LinearBlock(2 * style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
            )

        base_conv = partial(
            Conv2dBlock,
            pre_activation=pre_activation, activation_type=activation_type,
            norm_type="AdaIN" if have_style_input else "NONE",
            kernel_size=kernel_size, padding=(kernel_size - 1) // 2, padding_mode=padding_mode
        )

        base_residual_block = partial(
            ResidualBlock,
            padding_mode=padding_mode,
            activation_type=activation_type,
            norm_type="SPADE",
            pre_activation=True,
            additional_norm_kwargs=dict(
                condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
                activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
            )
        )

        sequence = OrderedDict()
        channels = (2 ** 4) * base_channels
        sequence["block_head"] = nn.Sequential(OrderedDict([
            ("conv_input", base_conv(in_channels=in_channels, out_channels=channels)),
            ("conv_style", base_conv(in_channels=channels, out_channels=channels)),
            ("res_a", base_residual_block(in_channels=channels, out_channels=channels)),
            ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
            ("up", nn.Upsample(scale_factor=2, mode='nearest'))
        ]))

        # math.log2 is exact for the power-of-two sizes allowed by the assert above.
        for i in range(4, 9 - min(int(math.log2(self.output_size)), 8), -1):
            channels = (2 ** (i - 1)) * base_channels
            sequence[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
                ("res_a", base_residual_block(in_channels=channels * 2, out_channels=channels)),
                ("conv_style", base_conv(in_channels=channels, out_channels=channels)),
                ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
                ("up", nn.Upsample(scale_factor=2, mode='nearest'))
            ]))
        self.sequence = nn.Sequential(sequence)
        # channels = 2 * base_channels when output size is 256, 512 or 1024
        # channels = 4 * base_channels when output size is 128
        out_modules = OrderedDict()
        out_modules["out_1"] = nn.Sequential(
            Conv2dBlock(
                channels, out_channels, kernel_size=5, stride=1, padding=2,
                pre_activation=pre_activation,
                padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"
            ),
            nn.Tanh()
        )
        # One additional upsampling block and output head per factor of 2 above 256.
        for i in range(int(math.log2(self.output_size)) - 8):
            channels = channels // 2
            out_modules[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
                ("res_a", base_residual_block(in_channels=2 * channels, out_channels=channels)),
                ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
                ("up", nn.Upsample(scale_factor=2, mode='nearest'))
            ]))
            out_modules[f"out_{i + 2}"] = nn.Sequential(
                Conv2dBlock(
                    channels, out_channels, kernel_size=5, stride=1, padding=2,
                    pre_activation=pre_activation,
                    padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"
                ),
                nn.Tanh()
            )
        self.out_modules = nn.ModuleDict(out_modules)

    def forward(self, seg, style=None):
        # The forward pass of this generator is not implemented in this file.
        raise NotImplementedError


@MODEL.register_module()
class SPADEGenerator(nn.Module):
    def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
                 padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        self.sx, self.sy = start_size
        self.use_vae = use_vae
        self.num_z_dim = num_z_dim
        if use_vae:
            # VAE mode: project a latent vector to the coarsest feature map.
            self.input_converter = nn.Linear(num_z_dim, 16 * base_channels * self.sx * self.sy)
        else:
            # Otherwise derive the coarsest feature map from a downsampled segmentation map.
            self.input_converter = nn.Conv2d(in_channels, 16 * base_channels, kernel_size=3, padding=1)

        sequence = []

        multiple_now = 16
        for i in range(num_blocks - 1, -1, -1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 4)
            if i != num_blocks - 1:
                sequence.append(nn.Upsample(scale_factor=2))
            sequence.append(ResidualBlock(
                base_channels * multiple_prev,
                out_channels=base_channels * multiple_now,
                padding_mode=padding_mode,
                activation_type=activation_type,
                norm_type="SPADE",
                pre_activation=True,
                additional_norm_kwargs=dict(
                    condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
                    activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
                )
            ))
        self.sequence = nn.Sequential(*sequence)
        self.output_converter = Conv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
                                            padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE")

    def forward(self, seg, z=None):
        if self.use_vae:
            if z is None:
                z = torch.randn(seg.size(0), self.num_z_dim, device=seg.device)
            x = self.input_converter(z).view(seg.size(0), -1, self.sx, self.sy)
        else:
            x = self.input_converter(F.interpolate(seg, size=(self.sx, self.sy)))
        for blk in self.sequence:
            if isinstance(blk, ResidualBlock):
                # Feed the segmentation map, resized to the current feature resolution,
                # to every SPADE normalization layer in the residual block.
                downsampling_seg = F.interpolate(seg, size=x.size()[2:], mode='nearest')
                blk.conv1.normalization.set_condition_image(downsampling_seg)
                blk.conv2.normalization.set_condition_image(downsampling_seg)
                if blk.learn_skip_connection:
                    blk.res_conv.normalization.set_condition_image(downsampling_seg)
            x = blk(x)
        return self.output_converter(x)


if __name__ == '__main__':
    g = SPADEGenerator(3, 3, 7, False, 256)
    print(g)
    print(g(torch.randn(2, 3, 256, 256)).size())
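
    # A minimal additional sketch (assumption: the same constructor arguments remain
    # valid) exercising the VAE path, where the latent code z is sampled inside
    # forward() when it is not provided explicitly.
    g_vae = SPADEGenerator(3, 3, 7, True, 256)
    print(g_vae(torch.randn(2, 3, 256, 256)).size())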