move encoder, decoder to CycleGAN

This commit is contained in:
Ray Wong 2020-10-11 11:09:16 +08:00
parent 04c6366c07
commit 9c08b4cd09
4 changed files with 106 additions and 128 deletions

View File

@ -70,7 +70,7 @@ class Conv2dBlock(nn.Module):
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
def __init__(self, num_channels, out_channels=None, padding_mode='reflect', def __init__(self, num_channels, out_channels=None, padding_mode='reflect',
activation_type="ReLU", out_activation_type=None, norm_type="IN"): activation_type="ReLU", norm_type="IN", out_activation_type=None):
super().__init__() super().__init__()
self.norm_type = norm_type self.norm_type = norm_type

View File

@ -0,0 +1,68 @@
import torch.nn as nn
from model.base.module import Conv2dBlock, ResidualBlock
class Encoder(nn.Module):
    """Down-sampling convolutional encoder followed by residual blocks.

    Stack: a 7x7 stem convolution, ``num_conv`` stride-2 down-sampling
    convolutions whose channel multiplier doubles each step (capped at
    ``2 ** max_down_sampling_multiple``), then ``num_res`` residual blocks
    at the final width. The final width is exposed as ``self.out_channels``
    so callers can size subsequent layers.
    """

    def __init__(self, in_channels, base_channels, num_conv, num_res, max_down_sampling_multiple=2,
                 padding_mode='reflect', activation_type="ReLU",
                 down_conv_norm_type="IN", down_conv_kernel_size=3,
                 res_norm_type="IN"):
        super().__init__()
        # Stem: 7x7 conv, padding=3 keeps the spatial size unchanged.
        layers = [Conv2dBlock(
            in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
            activation_type=activation_type, norm_type=down_conv_norm_type
        )]
        width = 1
        for step in range(1, num_conv + 1):
            prev_width = width
            # Double the channel multiplier per step, capped by the configured maximum.
            width = min(2 ** step, 2 ** max_down_sampling_multiple)
            layers.append(Conv2dBlock(
                prev_width * base_channels, width * base_channels,
                kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode=padding_mode,
                activation_type=activation_type, norm_type=down_conv_norm_type
            ))
        # Channel count after the last down-sampling conv; part of the public interface.
        self.out_channels = width * base_channels
        layers.extend(
            ResidualBlock(self.out_channels, self.out_channels, padding_mode, activation_type,
                          norm_type=res_norm_type)
            for _ in range(num_res)
        )
        self.sequence = nn.Sequential(*layers)

    def forward(self, x):
        """Run the full down-sampling + residual stack on ``x``."""
        return self.sequence(x)
class Decoder(nn.Module):
    """Residual blocks followed by up-sampling convolutions to image space.

    ``self.residual_blocks`` is kept as an ``nn.ModuleList`` separate from
    ``self.up_sequence`` so callers can reach into each block (e.g. to set
    AdaIN/AdaILN style parameters on its normalization layers) before
    calling :meth:`forward`.
    """

    def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
                 activation_type="ReLU", padding_mode='reflect',
                 up_conv_kernel_size=5, up_conv_norm_type="LN",
                 res_norm_type="AdaIN"):
        super().__init__()
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(in_channels, in_channels, padding_mode, activation_type, norm_type=res_norm_type)
            for _ in range(num_residual_blocks)
        ])
        sequence = []
        channels = in_channels
        for _ in range(num_up_sampling):
            # x2 nearest-neighbour upsample, then a same-size conv that halves channels.
            # Integer division (//) replaces int(k / 2): identical for positive kernel
            # sizes and avoids the float round-trip.
            sequence.append(nn.Sequential(
                nn.Upsample(scale_factor=2),
                Conv2dBlock(channels, channels // 2,
                            kernel_size=up_conv_kernel_size, stride=1,
                            padding=up_conv_kernel_size // 2, padding_mode=padding_mode,
                            activation_type=activation_type, norm_type=up_conv_norm_type),
            ))
            channels = channels // 2
        # Final 7x7 projection to the output image; Tanh bounds values to [-1, 1].
        sequence.append(Conv2dBlock(channels, out_channels,
                                    kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
                                    activation_type="Tanh", norm_type="NONE"))
        self.up_sequence = nn.Sequential(*sequence)

    def forward(self, x):
        # Plain iteration: the enumerate index in the original was never used.
        for blk in self.residual_blocks:
            x = blk(x)
        return self.up_sequence(x)

View File

@ -2,99 +2,29 @@ import torch
import torch.nn as nn import torch.nn as nn
from model import MODEL from model import MODEL
from model.base.module import Conv2dBlock, ResidualBlock, LinearBlock from model.base.module import LinearBlock
from model.image_translation.CycleGAN import Encoder, Decoder
def _get_down_sampling_sequence(in_channels, base_channels, num_conv, max_down_sampling_multiple=2,
padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
sequence = [Conv2dBlock(
in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type
)]
multiple_now = 1
for i in range(1, num_conv + 1):
multiple_prev = multiple_now
multiple_now = min(2 ** i, 2 ** max_down_sampling_multiple)
sequence.append(Conv2dBlock(
multiple_prev * base_channels, multiple_now * base_channels,
kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type
))
return sequence, multiple_now * base_channels
class StyleEncoder(nn.Module): class StyleEncoder(nn.Module):
def __init__(self, in_channels, out_dim, num_conv, base_channels=64, def __init__(self, in_channels, out_dim, num_conv, base_channels=64,
max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"): max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
super().__init__() super().__init__()
self.down_encoder = Encoder(
sequence, last_channels = _get_down_sampling_sequence( in_channels, base_channels, num_conv, num_res=0, max_down_sampling_multiple=max_down_sampling_multiple,
in_channels, base_channels, num_conv, padding_mode=padding_mode, activation_type=activation_type,
max_down_sampling_multiple, padding_mode, activation_type, norm_type down_conv_norm_type=norm_type, down_conv_kernel_size=4,
) )
sequence = list()
sequence.append(nn.AdaptiveAvgPool2d(1)) sequence.append(nn.AdaptiveAvgPool2d(1))
# conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code # conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code
sequence.append(nn.Conv2d(last_channels, out_dim, kernel_size=1, stride=1, padding=0)) sequence.append(nn.Conv2d(self.down_encoder.out_channels, out_dim, kernel_size=1, stride=1, padding=0))
self.sequence = nn.Sequential(*sequence) self.sequence = nn.Sequential(*sequence)
def forward(self, image): def forward(self, image):
return self.sequence(image).view(image.size(0), -1) return self.sequence(image).view(image.size(0), -1)
class ContentEncoder(nn.Module):
def __init__(self, in_channels, num_down_sampling, num_residual_blocks, base_channels=64,
max_down_sampling_multiple=2,
padding_mode='reflect', activation_type="ReLU", norm_type="IN"):
super().__init__()
sequence, last_channels = _get_down_sampling_sequence(
in_channels, base_channels, num_down_sampling,
max_down_sampling_multiple, padding_mode, activation_type, norm_type
)
sequence += [ResidualBlock(last_channels, last_channels, padding_mode, activation_type, norm_type) for _ in
range(num_residual_blocks)]
self.sequence = nn.Sequential(*sequence)
def forward(self, image):
return self.sequence(image)
class Decoder(nn.Module):
def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
res_norm_type="AdaIN", norm_type="LN", activation_type="ReLU", padding_mode='reflect'):
super().__init__()
self.residual_blocks = nn.ModuleList([
ResidualBlock(in_channels, in_channels, padding_mode, activation_type, norm_type=res_norm_type)
for _ in range(num_residual_blocks)
])
sequence = list()
channels = in_channels
for i in range(num_up_sampling):
sequence.append(nn.Sequential(
nn.Upsample(scale_factor=2),
Conv2dBlock(channels, channels // 2,
kernel_size=5, stride=1, padding=2, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type),
))
channels = channels // 2
sequence.append(Conv2dBlock(channels, out_channels,
kernel_size=7, stride=1, padding=3, padding_mode="reflect",
activation_type="Tanh", norm_type="NONE"))
self.up_sequence = nn.Sequential(*sequence)
def forward(self, x, style):
as_param_style = torch.chunk(style, 2 * len(self.residual_blocks), dim=1)
# set style for decoder
for i, blk in enumerate(self.residual_blocks):
blk.conv1.normalization.set_style(as_param_style[2 * i])
blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
x = blk(x)
return self.up_sequence(x)
class MLPFusion(nn.Module): class MLPFusion(nn.Module):
def __init__(self, in_features, out_features, base_features, n_blocks, activation_type="ReLU", norm_type="NONE"): def __init__(self, in_features, out_features, base_features, n_blocks, activation_type="ReLU", norm_type="NONE"):
super().__init__() super().__init__()
@ -119,10 +49,13 @@ class Generator(nn.Module):
encoder_num_residual_blocks=4, decoder_num_residual_blocks=4, encoder_num_residual_blocks=4, decoder_num_residual_blocks=4,
padding_mode='reflect', activation_type="ReLU"): padding_mode='reflect', activation_type="ReLU"):
super().__init__() super().__init__()
self.content_encoder = ContentEncoder( self.content_encoder = Encoder(
in_channels, num_content_down_sampling, encoder_num_residual_blocks, in_channels, base_channels, num_content_down_sampling, encoder_num_residual_blocks,
base_channels, max_down_sampling_multiple, max_down_sampling_multiple=num_content_down_sampling,
padding_mode, activation_type, norm_type="IN") padding_mode=padding_mode, activation_type=activation_type,
down_conv_norm_type="IN", down_conv_kernel_size=4,
res_norm_type="IN"
)
self.style_encoder = StyleEncoder(in_channels, style_dim, num_style_down_sampling, base_channels, self.style_encoder = StyleEncoder(in_channels, style_dim, num_style_down_sampling, base_channels,
max_down_sampling_multiple, padding_mode, activation_type, max_down_sampling_multiple, padding_mode, activation_type,
@ -134,15 +67,21 @@ class Generator(nn.Module):
num_mlp_base_feature, num_mlp_blocks, activation_type, num_mlp_base_feature, num_mlp_blocks, activation_type,
norm_type="NONE") norm_type="NONE")
self.decoder = Decoder(content_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks, self.decoder = Decoder(in_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks,
res_norm_type="AdaIN", norm_type="LN", activation_type=activation_type, activation_type=activation_type, padding_mode=padding_mode,
padding_mode=padding_mode) up_conv_kernel_size=5, up_conv_norm_type="LN",
res_norm_type="AdaIN")
def encode(self, x): def encode(self, x):
return self.content_encoder(x), self.style_encoder(x) return self.content_encoder(x), self.style_encoder(x)
def decode(self, content, style): def decode(self, content, style):
self.decoder(content, self.fusion(style)) as_param_style = torch.chunk(self.fusion(style), 2 * len(self.decoder.residual_blocks), dim=1)
# set style for decoder
for i, blk in enumerate(self.decoder.residual_blocks):
blk.conv1.normalization.set_style(as_param_style[2 * i])
blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
self.decoder(content)
def forward(self, x): def forward(self, x):
content, style = self.encode(x) content, style = self.encode(x)

View File

@ -2,7 +2,8 @@ import torch
import torch.nn as nn import torch.nn as nn
from model import MODEL from model import MODEL
from model.base.module import Conv2dBlock, ResidualBlock, LinearBlock from model.base.module import Conv2dBlock, LinearBlock
from model.image_translation.CycleGAN import Encoder, Decoder
class RhoClipper(object): class RhoClipper(object):
@ -46,27 +47,11 @@ class Generator(nn.Module):
self.light = light self.light = light
sequence = [Conv2dBlock(
in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type
)]
n_down_sampling = 2 n_down_sampling = 2
for i in range(n_down_sampling): self.encoder = Encoder(in_channels, base_channels, n_down_sampling, num_blocks,
mult = 2 ** i padding_mode=padding_mode, activation_type=activation_type,
sequence.append(Conv2dBlock( down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type)
base_channels * mult, base_channels * mult * 2,
kernel_size=3, stride=2, padding=1, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type
))
mult = 2 ** n_down_sampling mult = 2 ** n_down_sampling
sequence += [
ResidualBlock(base_channels * mult, base_channels * mult, padding_mode, activation_type=activation_type,
norm_type=norm_type)
for _ in range(num_blocks)]
self.encoder = nn.Sequential(*sequence)
self.cam = CAMClassifier(base_channels * mult, activation_type) self.cam = CAMClassifier(base_channels * mult, activation_type)
# Gamma, Beta block # Gamma, Beta block
@ -85,25 +70,12 @@ class Generator(nn.Module):
self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False) self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False) self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
# Up-Sampling Bottleneck self.decoder = Decoder(
self.up_bottleneck = nn.ModuleList( base_channels * mult, out_channels, n_down_sampling, num_blocks,
[ResidualBlock(base_channels * mult, base_channels * mult, padding_mode, activation_type=activation_type, padding_mode=padding_mode,
activation_type, norm_type="AdaILN") for _ in range(num_blocks)]) up_conv_kernel_size=3, up_conv_norm_type="ILN",
res_norm_type="AdaILN"
sequence = list() )
channels = base_channels * mult
for i in range(n_down_sampling):
sequence.append(nn.Sequential(
nn.Upsample(scale_factor=2),
Conv2dBlock(channels, channels // 2,
kernel_size=3, stride=1, padding=1, bias=False, padding_mode=padding_mode,
activation_type=activation_type, norm_type="ILN"),
))
channels = channels // 2
sequence.append(Conv2dBlock(channels, out_channels,
kernel_size=7, stride=1, padding=3, padding_mode="reflect",
activation_type="Tanh", norm_type="NONE"))
self.decoder = nn.Sequential(*sequence)
def forward(self, x): def forward(self, x):
x = self.encoder(x) x = self.encoder(x)
@ -119,10 +91,9 @@ class Generator(nn.Module):
x_ = self.fc(x.view(x.shape[0], -1)) x_ = self.fc(x.view(x.shape[0], -1))
gamma, beta = self.gamma(x_), self.beta(x_) gamma, beta = self.gamma(x_), self.beta(x_)
for blk in self.up_bottleneck: for blk in self.decoder.residual_blocks:
blk.conv1.normalization.set_condition(gamma, beta) blk.conv1.normalization.set_condition(gamma, beta)
blk.conv2.normalization.set_condition(gamma, beta) blk.conv2.normalization.set_condition(gamma, beta)
x = blk(x)
return self.decoder(x), cam_logit, heatmap return self.decoder(x), cam_logit, heatmap