move encoder, decoder to CycleGAN
This commit is contained in:
parent
04c6366c07
commit
9c08b4cd09
@ -70,7 +70,7 @@ class Conv2dBlock(nn.Module):
|
|||||||
|
|
||||||
class ResidualBlock(nn.Module):
|
class ResidualBlock(nn.Module):
|
||||||
def __init__(self, num_channels, out_channels=None, padding_mode='reflect',
|
def __init__(self, num_channels, out_channels=None, padding_mode='reflect',
|
||||||
activation_type="ReLU", out_activation_type=None, norm_type="IN"):
|
activation_type="ReLU", norm_type="IN", out_activation_type=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm_type = norm_type
|
self.norm_type = norm_type
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,68 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from model.base.module import Conv2dBlock, ResidualBlock
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
def __init__(self, in_channels, base_channels, num_conv, num_res, max_down_sampling_multiple=2,
|
||||||
|
padding_mode='reflect', activation_type="ReLU",
|
||||||
|
down_conv_norm_type="IN", down_conv_kernel_size=3,
|
||||||
|
res_norm_type="IN"):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
sequence = [Conv2dBlock(
|
||||||
|
in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
|
||||||
|
activation_type=activation_type, norm_type=down_conv_norm_type
|
||||||
|
)]
|
||||||
|
multiple_now = 1
|
||||||
|
for i in range(1, num_conv + 1):
|
||||||
|
multiple_prev = multiple_now
|
||||||
|
multiple_now = min(2 ** i, 2 ** max_down_sampling_multiple)
|
||||||
|
sequence.append(Conv2dBlock(
|
||||||
|
multiple_prev * base_channels, multiple_now * base_channels,
|
||||||
|
kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode=padding_mode,
|
||||||
|
activation_type=activation_type, norm_type=down_conv_norm_type
|
||||||
|
))
|
||||||
|
self.out_channels = multiple_now * base_channels
|
||||||
|
sequence += [
|
||||||
|
ResidualBlock(self.out_channels, self.out_channels, padding_mode, activation_type, norm_type=res_norm_type)
|
||||||
|
for _ in range(num_res)
|
||||||
|
]
|
||||||
|
self.sequence = nn.Sequential(*sequence)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.sequence(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
|
||||||
|
activation_type="ReLU", padding_mode='reflect',
|
||||||
|
up_conv_kernel_size=5, up_conv_norm_type="LN",
|
||||||
|
res_norm_type="AdaIN"):
|
||||||
|
super().__init__()
|
||||||
|
self.residual_blocks = nn.ModuleList([
|
||||||
|
ResidualBlock(in_channels, in_channels, padding_mode, activation_type, norm_type=res_norm_type)
|
||||||
|
for _ in range(num_residual_blocks)
|
||||||
|
])
|
||||||
|
|
||||||
|
sequence = list()
|
||||||
|
channels = in_channels
|
||||||
|
for i in range(num_up_sampling):
|
||||||
|
sequence.append(nn.Sequential(
|
||||||
|
nn.Upsample(scale_factor=2),
|
||||||
|
Conv2dBlock(channels, channels // 2,
|
||||||
|
kernel_size=up_conv_kernel_size, stride=1,
|
||||||
|
padding=int(up_conv_kernel_size / 2), padding_mode=padding_mode,
|
||||||
|
activation_type=activation_type, norm_type=up_conv_norm_type),
|
||||||
|
))
|
||||||
|
channels = channels // 2
|
||||||
|
sequence.append(Conv2dBlock(channels, out_channels,
|
||||||
|
kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
|
||||||
|
activation_type="Tanh", norm_type="NONE"))
|
||||||
|
|
||||||
|
self.up_sequence = nn.Sequential(*sequence)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for i, blk in enumerate(self.residual_blocks):
|
||||||
|
x = blk(x)
|
||||||
|
return self.up_sequence(x)
|
||||||
@ -2,99 +2,29 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from model import MODEL
|
from model import MODEL
|
||||||
from model.base.module import Conv2dBlock, ResidualBlock, LinearBlock
|
from model.base.module import LinearBlock
|
||||||
|
from model.image_translation.CycleGAN import Encoder, Decoder
|
||||||
|
|
||||||
def _get_down_sampling_sequence(in_channels, base_channels, num_conv, max_down_sampling_multiple=2,
|
|
||||||
padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
|
|
||||||
sequence = [Conv2dBlock(
|
|
||||||
in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
|
|
||||||
activation_type=activation_type, norm_type=norm_type
|
|
||||||
)]
|
|
||||||
multiple_now = 1
|
|
||||||
for i in range(1, num_conv + 1):
|
|
||||||
multiple_prev = multiple_now
|
|
||||||
multiple_now = min(2 ** i, 2 ** max_down_sampling_multiple)
|
|
||||||
sequence.append(Conv2dBlock(
|
|
||||||
multiple_prev * base_channels, multiple_now * base_channels,
|
|
||||||
kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
|
|
||||||
activation_type=activation_type, norm_type=norm_type
|
|
||||||
))
|
|
||||||
return sequence, multiple_now * base_channels
|
|
||||||
|
|
||||||
|
|
||||||
class StyleEncoder(nn.Module):
|
class StyleEncoder(nn.Module):
|
||||||
def __init__(self, in_channels, out_dim, num_conv, base_channels=64,
|
def __init__(self, in_channels, out_dim, num_conv, base_channels=64,
|
||||||
max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
|
max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.down_encoder = Encoder(
|
||||||
sequence, last_channels = _get_down_sampling_sequence(
|
in_channels, base_channels, num_conv, num_res=0, max_down_sampling_multiple=max_down_sampling_multiple,
|
||||||
in_channels, base_channels, num_conv,
|
padding_mode=padding_mode, activation_type=activation_type,
|
||||||
max_down_sampling_multiple, padding_mode, activation_type, norm_type
|
down_conv_norm_type=norm_type, down_conv_kernel_size=4,
|
||||||
)
|
)
|
||||||
|
sequence = list()
|
||||||
sequence.append(nn.AdaptiveAvgPool2d(1))
|
sequence.append(nn.AdaptiveAvgPool2d(1))
|
||||||
# conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code
|
# conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code
|
||||||
sequence.append(nn.Conv2d(last_channels, out_dim, kernel_size=1, stride=1, padding=0))
|
sequence.append(nn.Conv2d(self.down_encoder.out_channels, out_dim, kernel_size=1, stride=1, padding=0))
|
||||||
self.sequence = nn.Sequential(*sequence)
|
self.sequence = nn.Sequential(*sequence)
|
||||||
|
|
||||||
def forward(self, image):
|
def forward(self, image):
|
||||||
return self.sequence(image).view(image.size(0), -1)
|
return self.sequence(image).view(image.size(0), -1)
|
||||||
|
|
||||||
|
|
||||||
class ContentEncoder(nn.Module):
|
|
||||||
def __init__(self, in_channels, num_down_sampling, num_residual_blocks, base_channels=64,
|
|
||||||
max_down_sampling_multiple=2,
|
|
||||||
padding_mode='reflect', activation_type="ReLU", norm_type="IN"):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
sequence, last_channels = _get_down_sampling_sequence(
|
|
||||||
in_channels, base_channels, num_down_sampling,
|
|
||||||
max_down_sampling_multiple, padding_mode, activation_type, norm_type
|
|
||||||
)
|
|
||||||
|
|
||||||
sequence += [ResidualBlock(last_channels, last_channels, padding_mode, activation_type, norm_type) for _ in
|
|
||||||
range(num_residual_blocks)]
|
|
||||||
self.sequence = nn.Sequential(*sequence)
|
|
||||||
|
|
||||||
def forward(self, image):
|
|
||||||
return self.sequence(image)
|
|
||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
|
|
||||||
res_norm_type="AdaIN", norm_type="LN", activation_type="ReLU", padding_mode='reflect'):
|
|
||||||
super().__init__()
|
|
||||||
self.residual_blocks = nn.ModuleList([
|
|
||||||
ResidualBlock(in_channels, in_channels, padding_mode, activation_type, norm_type=res_norm_type)
|
|
||||||
for _ in range(num_residual_blocks)
|
|
||||||
])
|
|
||||||
|
|
||||||
sequence = list()
|
|
||||||
channels = in_channels
|
|
||||||
for i in range(num_up_sampling):
|
|
||||||
sequence.append(nn.Sequential(
|
|
||||||
nn.Upsample(scale_factor=2),
|
|
||||||
Conv2dBlock(channels, channels // 2,
|
|
||||||
kernel_size=5, stride=1, padding=2, padding_mode=padding_mode,
|
|
||||||
activation_type=activation_type, norm_type=norm_type),
|
|
||||||
))
|
|
||||||
channels = channels // 2
|
|
||||||
sequence.append(Conv2dBlock(channels, out_channels,
|
|
||||||
kernel_size=7, stride=1, padding=3, padding_mode="reflect",
|
|
||||||
activation_type="Tanh", norm_type="NONE"))
|
|
||||||
|
|
||||||
self.up_sequence = nn.Sequential(*sequence)
|
|
||||||
|
|
||||||
def forward(self, x, style):
|
|
||||||
as_param_style = torch.chunk(style, 2 * len(self.residual_blocks), dim=1)
|
|
||||||
# set style for decoder
|
|
||||||
for i, blk in enumerate(self.residual_blocks):
|
|
||||||
blk.conv1.normalization.set_style(as_param_style[2 * i])
|
|
||||||
blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
|
|
||||||
x = blk(x)
|
|
||||||
return self.up_sequence(x)
|
|
||||||
|
|
||||||
|
|
||||||
class MLPFusion(nn.Module):
|
class MLPFusion(nn.Module):
|
||||||
def __init__(self, in_features, out_features, base_features, n_blocks, activation_type="ReLU", norm_type="NONE"):
|
def __init__(self, in_features, out_features, base_features, n_blocks, activation_type="ReLU", norm_type="NONE"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -119,10 +49,13 @@ class Generator(nn.Module):
|
|||||||
encoder_num_residual_blocks=4, decoder_num_residual_blocks=4,
|
encoder_num_residual_blocks=4, decoder_num_residual_blocks=4,
|
||||||
padding_mode='reflect', activation_type="ReLU"):
|
padding_mode='reflect', activation_type="ReLU"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.content_encoder = ContentEncoder(
|
self.content_encoder = Encoder(
|
||||||
in_channels, num_content_down_sampling, encoder_num_residual_blocks,
|
in_channels, base_channels, num_content_down_sampling, encoder_num_residual_blocks,
|
||||||
base_channels, max_down_sampling_multiple,
|
max_down_sampling_multiple=num_content_down_sampling,
|
||||||
padding_mode, activation_type, norm_type="IN")
|
padding_mode=padding_mode, activation_type=activation_type,
|
||||||
|
down_conv_norm_type="IN", down_conv_kernel_size=4,
|
||||||
|
res_norm_type="IN"
|
||||||
|
)
|
||||||
|
|
||||||
self.style_encoder = StyleEncoder(in_channels, style_dim, num_style_down_sampling, base_channels,
|
self.style_encoder = StyleEncoder(in_channels, style_dim, num_style_down_sampling, base_channels,
|
||||||
max_down_sampling_multiple, padding_mode, activation_type,
|
max_down_sampling_multiple, padding_mode, activation_type,
|
||||||
@ -134,15 +67,21 @@ class Generator(nn.Module):
|
|||||||
num_mlp_base_feature, num_mlp_blocks, activation_type,
|
num_mlp_base_feature, num_mlp_blocks, activation_type,
|
||||||
norm_type="NONE")
|
norm_type="NONE")
|
||||||
|
|
||||||
self.decoder = Decoder(content_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks,
|
self.decoder = Decoder(in_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks,
|
||||||
res_norm_type="AdaIN", norm_type="LN", activation_type=activation_type,
|
activation_type=activation_type, padding_mode=padding_mode,
|
||||||
padding_mode=padding_mode)
|
up_conv_kernel_size=5, up_conv_norm_type="LN",
|
||||||
|
res_norm_type="AdaIN")
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
return self.content_encoder(x), self.style_encoder(x)
|
return self.content_encoder(x), self.style_encoder(x)
|
||||||
|
|
||||||
def decode(self, content, style):
|
def decode(self, content, style):
|
||||||
self.decoder(content, self.fusion(style))
|
as_param_style = torch.chunk(self.fusion(style), 2 * len(self.decoder.residual_blocks), dim=1)
|
||||||
|
# set style for decoder
|
||||||
|
for i, blk in enumerate(self.decoder.residual_blocks):
|
||||||
|
blk.conv1.normalization.set_style(as_param_style[2 * i])
|
||||||
|
blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
|
||||||
|
self.decoder(content)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
content, style = self.encode(x)
|
content, style = self.encode(x)
|
||||||
|
|||||||
@ -2,7 +2,8 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from model import MODEL
|
from model import MODEL
|
||||||
from model.base.module import Conv2dBlock, ResidualBlock, LinearBlock
|
from model.base.module import Conv2dBlock, LinearBlock
|
||||||
|
from model.image_translation.CycleGAN import Encoder, Decoder
|
||||||
|
|
||||||
|
|
||||||
class RhoClipper(object):
|
class RhoClipper(object):
|
||||||
@ -46,27 +47,11 @@ class Generator(nn.Module):
|
|||||||
|
|
||||||
self.light = light
|
self.light = light
|
||||||
|
|
||||||
sequence = [Conv2dBlock(
|
|
||||||
in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
|
|
||||||
activation_type=activation_type, norm_type=norm_type
|
|
||||||
)]
|
|
||||||
|
|
||||||
n_down_sampling = 2
|
n_down_sampling = 2
|
||||||
for i in range(n_down_sampling):
|
self.encoder = Encoder(in_channels, base_channels, n_down_sampling, num_blocks,
|
||||||
mult = 2 ** i
|
padding_mode=padding_mode, activation_type=activation_type,
|
||||||
sequence.append(Conv2dBlock(
|
down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type)
|
||||||
base_channels * mult, base_channels * mult * 2,
|
|
||||||
kernel_size=3, stride=2, padding=1, padding_mode=padding_mode,
|
|
||||||
activation_type=activation_type, norm_type=norm_type
|
|
||||||
))
|
|
||||||
|
|
||||||
mult = 2 ** n_down_sampling
|
mult = 2 ** n_down_sampling
|
||||||
sequence += [
|
|
||||||
ResidualBlock(base_channels * mult, base_channels * mult, padding_mode, activation_type=activation_type,
|
|
||||||
norm_type=norm_type)
|
|
||||||
for _ in range(num_blocks)]
|
|
||||||
self.encoder = nn.Sequential(*sequence)
|
|
||||||
|
|
||||||
self.cam = CAMClassifier(base_channels * mult, activation_type)
|
self.cam = CAMClassifier(base_channels * mult, activation_type)
|
||||||
|
|
||||||
# Gamma, Beta block
|
# Gamma, Beta block
|
||||||
@ -85,25 +70,12 @@ class Generator(nn.Module):
|
|||||||
self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
|
self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
|
||||||
self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
|
self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
|
||||||
|
|
||||||
# Up-Sampling Bottleneck
|
self.decoder = Decoder(
|
||||||
self.up_bottleneck = nn.ModuleList(
|
base_channels * mult, out_channels, n_down_sampling, num_blocks,
|
||||||
[ResidualBlock(base_channels * mult, base_channels * mult, padding_mode,
|
activation_type=activation_type, padding_mode=padding_mode,
|
||||||
activation_type, norm_type="AdaILN") for _ in range(num_blocks)])
|
up_conv_kernel_size=3, up_conv_norm_type="ILN",
|
||||||
|
res_norm_type="AdaILN"
|
||||||
sequence = list()
|
)
|
||||||
channels = base_channels * mult
|
|
||||||
for i in range(n_down_sampling):
|
|
||||||
sequence.append(nn.Sequential(
|
|
||||||
nn.Upsample(scale_factor=2),
|
|
||||||
Conv2dBlock(channels, channels // 2,
|
|
||||||
kernel_size=3, stride=1, padding=1, bias=False, padding_mode=padding_mode,
|
|
||||||
activation_type=activation_type, norm_type="ILN"),
|
|
||||||
))
|
|
||||||
channels = channels // 2
|
|
||||||
sequence.append(Conv2dBlock(channels, out_channels,
|
|
||||||
kernel_size=7, stride=1, padding=3, padding_mode="reflect",
|
|
||||||
activation_type="Tanh", norm_type="NONE"))
|
|
||||||
self.decoder = nn.Sequential(*sequence)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.encoder(x)
|
x = self.encoder(x)
|
||||||
@ -119,10 +91,9 @@ class Generator(nn.Module):
|
|||||||
x_ = self.fc(x.view(x.shape[0], -1))
|
x_ = self.fc(x.view(x.shape[0], -1))
|
||||||
gamma, beta = self.gamma(x_), self.beta(x_)
|
gamma, beta = self.gamma(x_), self.beta(x_)
|
||||||
|
|
||||||
for blk in self.up_bottleneck:
|
for blk in self.decoder.residual_blocks:
|
||||||
blk.conv1.normalization.set_condition(gamma, beta)
|
blk.conv1.normalization.set_condition(gamma, beta)
|
||||||
blk.conv2.normalization.set_condition(gamma, beta)
|
blk.conv2.normalization.set_condition(gamma, beta)
|
||||||
x = blk(x)
|
|
||||||
return self.decoder(x), cam_logit, heatmap
|
return self.decoder(x), cam_logit, heatmap
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user