# TSIT generator: content-stream encoder with a FADE-conditioned decoder.
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch
|
|
|
|
from model import MODEL
|
|
from model.base.module import ResidualBlock, Conv2dBlock
|
|
|
|
|
|
class Interpolation(nn.Module):
    """Functional ``F.interpolate`` wrapped as a module.

    Lets a fixed resize step be placed inside ``nn.Sequential`` /
    ``nn.ModuleList`` containers alongside learnable layers.
    """

    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
        super().__init__()
        # Stored verbatim and forwarded to F.interpolate on every call.
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        # recompute_scale_factor=False: apply the stored factor as-is
        # instead of re-deriving it from the computed output size.
        return F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
            recompute_scale_factor=False,
        )

    def __repr__(self):
        return f"Interpolation(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"
|
|
|
|
|
|
@MODEL.register_module("TSIT-Generator")
class Generator(nn.Module):
    """TSIT-style encoder/decoder generator.

    The content stream halves the spatial size ``num_blocks`` times while
    recording every intermediate feature map; the decoding stream takes a
    noise tensor shaped like the deepest feature and upsamples it back,
    with each FADE-normalized residual block conditioned on the recorded
    feature it is paired with in ``forward``.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=7,
                 padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()

        # Stem: 7x7 conv that keeps spatial size, instance-normalized.
        self.input_layer = Conv2dBlock(
            in_channels, base_channels, kernel_size=7, stride=1, padding=3,
            padding_mode=padding_mode, activation_type=activation_type, norm_type="IN",
        )

        # Channel multipliers double per stage, capped at 16 * base_channels.
        down_mults = [min(2 ** i, 16) for i in range(num_blocks + 1)]
        downs = []
        for prev_mult, mult in zip(down_mults[:-1], down_mults[1:]):
            downs.append(nn.Sequential(
                Interpolation(scale_factor=0.5, mode="nearest"),
                ResidualBlock(
                    prev_mult * base_channels, out_channels=mult * base_channels,
                    padding_mode=padding_mode, activation_type=activation_type,
                    norm_type="IN",
                ),
            ))
        self.down_sequence = nn.ModuleList(downs)

        # Decoder multipliers mirror the encoder, starting from the 16x cap.
        up_mults = [16] + [min(2 ** i, 16) for i in range(num_blocks - 1, -1, -1)]
        ups = []
        for prev_mult, mult in zip(up_mults[:-1], up_mults[1:]):
            ups.append(nn.Sequential(
                ResidualBlock(
                    base_channels * prev_mult,
                    out_channels=base_channels * mult,
                    padding_mode=padding_mode,
                    activation_type=activation_type,
                    norm_type="FADE",
                    pre_activation=True,
                    # FADE conditioning carries as many channels as the
                    # block's input (see set_feature wiring in forward).
                    additional_norm_kwargs=dict(
                        condition_in_channels=base_channels * prev_mult,
                        base_norm_type="BN", padding_mode="zeros", gamma_bias=0.0,
                    ),
                ),
                Interpolation(scale_factor=2, mode="nearest"),
            ))
        self.up_sequence = nn.Sequential(*ups)

        # Head: 3x3 conv down to out_channels, tanh-bounded, no normalization.
        self.output_layer = Conv2dBlock(
            base_channels, out_channels, kernel_size=3, stride=1, padding=1,
            padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"
        )

    def forward(self, x, z=None):
        """Generate an output image from content ``x`` and optional noise ``z``.

        When ``z`` is omitted it is sampled from a standard normal with the
        same shape (and on the same device) as the deepest content feature.
        """
        feats = []
        h = self.input_layer(x)
        for down in self.down_sequence:
            h = down(h)
            feats.append(h)

        if z is None:
            # for image 256x256, z size: [batch_size, 1024, 2, 2]
            z = torch.randn(size=feats[-1].size(), device=feats[-1].device)

        # Attach each decoder block's FADE layers to the matching content
        # feature, consuming the recorded features deepest-first.
        for up in self.up_sequence:
            block = up[0]
            cond = feats.pop()
            block.conv1.normalization.set_feature(cond)
            block.conv2.normalization.set_feature(cond)
            if block.learn_skip_connection:
                block.res_conv.normalization.set_feature(cond)

        return self.output_layer(self.up_sequence(z))
|
|
|
|
if __name__ == '__main__':
    # Quick GPU smoke test: a 2-image 256x256 batch through the generator.
    net = Generator(3, 3).cuda()
    batch = torch.randn(2, 3, 256, 256).cuda()
    print(net(batch).size())
|
|
|
|
|