raycv/model/image_translation/TSIT.py
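"""TSIT two-stream generator for image-to-image translation.

A content stream of downsampling residual blocks extracts multi-scale features
from the input image; those features condition a mirrored upsampling generator
stream through FADE normalization layers, which decodes a latent code ``z``
into the translated image. Presumably follows "TSIT: A Simple and Versatile
Framework for Image-to-Image Translation" (Jiang et al., ECCV 2020).
"""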
import torch.nn as nn
import torch.nn.functional as F
import torch

from model import MODEL
from model.base.module import ResidualBlock, Conv2dBlock


class Interpolation(nn.Module):
    """Thin nn.Module wrapper around F.interpolate so resampling can sit inside nn.Sequential."""

    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
        super(Interpolation, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners,
                             recompute_scale_factor=False)

    def __repr__(self):
        return f"Interpolation(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"

@MODEL.register_module("TSIT-Generator")
class Generator(nn.Module):
    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=7,
                 padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        # Stem: 7x7 convolution mapping the input image to base_channels feature maps.
        self.input_layer = Conv2dBlock(
            in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
            activation_type=activation_type, norm_type="IN",
        )
        # Content stream: num_blocks downsampling residual stages; the channel width doubles
        # at every stage and is capped at 16 * base_channels.
        multiple_now = 1
        stream_sequence = []
        for i in range(1, num_blocks + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 4)
            stream_sequence.append(nn.Sequential(
                Interpolation(scale_factor=0.5, mode="nearest"),
                ResidualBlock(
                    multiple_prev * base_channels, out_channels=multiple_now * base_channels,
                    padding_mode=padding_mode, activation_type=activation_type, norm_type="IN")
            ))
        self.down_sequence = nn.ModuleList(stream_sequence)
        # Generator stream: mirrors the content stream with FADE-normalized residual blocks,
        # each followed by 2x nearest-neighbour upsampling.
        sequence = []
        multiple_now = 16
        for i in range(num_blocks - 1, -1, -1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 4)
            sequence.append(nn.Sequential(
                ResidualBlock(
                    base_channels * multiple_prev,
                    out_channels=base_channels * multiple_now,
                    padding_mode=padding_mode,
                    activation_type=activation_type,
                    norm_type="FADE",
                    pre_activation=True,
                    additional_norm_kwargs=dict(
                        condition_in_channels=base_channels * multiple_prev, base_norm_type="BN",
                        padding_mode="zeros", gamma_bias=0.0
                    )
                ),
                Interpolation(scale_factor=2, mode="nearest")
            ))
        self.up_sequence = nn.Sequential(*sequence)
        # Output head: 3x3 convolution with Tanh mapping features back to image space.
        self.output_layer = Conv2dBlock(
            base_channels, out_channels, kernel_size=3, stride=1, padding=1,
            padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"
        )

    def forward(self, x, z=None):
        # Content stream: keep every intermediate feature map for FADE conditioning.
        c = self.input_layer(x)
        contents = []
        for blk in self.down_sequence:
            c = blk(c)
            contents.append(c)
        if z is None:
            # For a 256x256 input, z has size [batch_size, 1024, 2, 2],
            # matching the deepest content feature.
            z = torch.randn(size=contents[-1].size(), device=contents[-1].device)
        # Feed each FADE normalization layer the content feature of the matching resolution
        # (deepest feature first, since the generator stream upsamples step by step).
        for blk in self.up_sequence:
            res = blk[0]
            c = contents.pop()
            res.conv1.normalization.set_feature(c)
            res.conv2.normalization.set_feature(c)
            if res.learn_skip_connection:
                res.res_conv.normalization.set_feature(c)
        # Decode z through the generator stream now that all FADE layers are conditioned.
        return self.output_layer(self.up_sequence(z))

if __name__ == '__main__':
    g = Generator(3, 3).cuda()
    img = torch.randn(2, 3, 256, 256).cuda()
    print(g(img).size())  # expected: torch.Size([2, 3, 256, 256])