This commit is contained in:
Ray Wong 2020-10-11 10:02:33 +08:00
parent 6ea13df465
commit 04c6366c07
24 changed files with 483 additions and 968 deletions

View File

@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="Always" serverName="14d" remoteFilesAllowedToDisappearOnAutoupload="false">
<component name="PublishConfigData" autoUpload="Always" serverName="21d" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="14d">
<serverdata>

View File

@ -14,11 +14,15 @@ handler:
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 2 # log image `image` times per epoch
image: 4 # log image `image` times per epoch
test:
random: True
images: 10
model:
generator:
_type: UGATIT-Generator
_add_spectral_norm: True
in_channels: 3
out_channels: 3
base_channels: 64
@ -27,11 +31,13 @@ model:
light: True
local_discriminator:
_type: UGATIT-Discriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_blocks: 5
global_discriminator:
_type: UGATIT-Discriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_blocks: 7
@ -50,6 +56,8 @@ loss:
weight: 10.0
cam:
weight: 1000
mgc:
weight: 0
optimizers:
generator:
@ -70,7 +78,7 @@ data:
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 24
batch_size: 4
shuffle: True
num_workers: 2
pin_memory: True

View File

@ -1,16 +1,15 @@
from omegaconf import OmegaConf
import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
import ignite.distributed as idist
from omegaconf import OmegaConf
from loss.gan import GANLoss
from model.GAN.UGATIT import RhoClipper
from model.GAN.base import GANImageBuffer
from util.image import attention_colored_map
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
from loss.gan import GANLoss
from model.image_translation.UGATIT import RhoClipper
from util.image import attention_colored_map
def mse_loss(x, target_flag):
@ -30,9 +29,8 @@ class UGATITEngineKernel(EngineKernel):
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
self.mgc_loss = MyLoss()
self.rho_clipper = RhoClipper(0, 1)
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
self.discriminators.keys()}
self.train_generator_first = False
def build_models(self) -> (dict, dict):
@ -82,6 +80,9 @@ class UGATITEngineKernel(EngineKernel):
loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase])
loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(batch[phase],
generated["images"][f"{phase}2{phase}"])
if self.config.loss.mgc.weight > 0:
loss[f"mgc_{phase}"] = self.config.loss.mgc.weight * self.mgc_loss(
batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
for dk in "lg":
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)

View File

@ -64,7 +64,7 @@ class EngineKernel(object):
self.engine = engine
def build_models(self) -> (dict, dict):
raise NotImplemented
raise NotImplementedError
def to_save(self):
to_save = {}
@ -73,19 +73,19 @@ class EngineKernel(object):
return to_save
def setup_after_g(self):
raise NotImplemented
raise NotImplementedError
def setup_before_g(self):
raise NotImplemented
raise NotImplementedError
def forward(self, batch, inference=False) -> dict:
raise NotImplemented
raise NotImplementedError
def criterion_generators(self, batch, generated) -> dict:
raise NotImplemented
raise NotImplementedError
def criterion_discriminators(self, batch, generated) -> dict:
raise NotImplemented
raise NotImplementedError
def intermediate_images(self, batch, generated) -> dict:
"""
@ -94,7 +94,7 @@ class EngineKernel(object):
:param generated: dict of images
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
raise NotImplemented
raise NotImplementedError
def change_engine(self, config, engine: Engine):
pass

View File

@ -1,18 +1,21 @@
import torch
import ignite.distributed as idist
import torch
import torch.optim as optim
from omegaconf import OmegaConf
from model import MODEL
import torch.optim as optim
from util.misc import add_spectral_norm
def build_model(cfg):
cfg = OmegaConf.to_container(cfg)
bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
add_spectral_norm_flag = cfg.pop("_add_spectral_norm", False)
model = MODEL.build_with(cfg)
if bn_to_sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if add_spectral_norm_flag:
model.apply(add_spectral_norm)
return idist.auto_model(model)

View File

@ -0,0 +1,111 @@
import torch
import torch.nn as nn
def gaussian_radial_basis_function(x, mu, sigma):
    """
    Evaluate one exponential kernel per centre in ``mu`` at every element of ``x``.

    :param x: 4-D tensor (batch_size, c, h, w)
    :param mu: 1-D tensor holding ``kernel_size`` centres
    :param sigma: kernel bandwidth
    :return: tensor (batch_size, kernel_size, c*h*w) holding ``exp((x - mu)^2 / (2 * sigma^2))``
    """
    batch_size = x.size(0)
    num_kernels = mu.size(0)
    flat_len = x.size(1) * x.size(2) * x.size(3)
    # centres: (kernel_size) -> (batch_size, kernel_size, c*h*w), placed on the same device as the data
    centres = mu.view(1, num_kernels, 1).expand(batch_size, -1, flat_len).to(x.device)
    # data: (batch_size, c, h, w) -> (batch_size, kernel_size, c*h*w)
    values = x.view(batch_size, 1, -1).expand(-1, num_kernels, -1)
    # NOTE(review): the exponent is positive, unlike the textbook Gaussian RBF exp(-d^2 / (2 s^2)) - confirm
    squared_distance = (values - centres).pow(2)
    return torch.exp(squared_distance / (2 * sigma ** 2))
class MyLoss(torch.nn.Module):
    """
    Batched ERSMI loss between two image batches.

    The estimate uses every pairing of nine kernel centres (-1.0 ... 1.0 in steps of 0.25) for the two
    images. All helper tensors are created on the device of the first input, so the loss runs on CPU as
    well as on GPU instead of requiring CUDA.
    """

    def __init__(self):
        super(MyLoss, self).__init__()

    def forward(self, fakeI, realI):
        """
        :param fakeI: tensor (batch_size, c, h, w)
        :param realI: tensor of the same shape, or single-channel (then repeated to ``c`` channels)
        :return: scalar tensor, the negated batch mean of the ERSMI estimate
        """

        def batch_ERSMI(I1, I2):
            device = I1.device
            batch_size = I1.shape[0]
            img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
            if I2.shape[1] == 1 and I1.shape[1] != 1:
                # repeat to the channel count of I1 (previously hard-coded to 3)
                I2 = I2.repeat(1, I1.shape[1], 1, 1)

            def kernel_F(y, mu_list, sigma):
                num_kernels = mu_list.numel()
                # (1, K) -> (batch_size, K, img_size)
                tmp_mu = mu_list.view(-1, 1).repeat(1, img_size).repeat(batch_size, 1, 1)
                tmp_y = y.view(batch_size, 1, -1).repeat(1, num_kernels, 1)
                tmp_y = tmp_mu - tmp_y
                mat_L = torch.exp(tmp_y.pow(2) / (2 * sigma ** 2))
                return mat_L

            mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0]).to(device)
            num_centres = mu.numel()
            num_pairs = num_centres ** 2
            # x centres cycle through the grid while y centres repeat each value -> all centre pairings
            x_mu_list = mu.repeat(num_centres).view(-1, num_pairs)
            y_mu_list = mu.unsqueeze(0).t().repeat(1, num_centres).view(-1, num_pairs)

            mat_K = kernel_F(I1, x_mu_list, 1)
            mat_L = kernel_F(I2, y_mu_list, 1)

            H1 = (mat_K.matmul(mat_K.transpose(1, 2))).mul(mat_L.matmul(mat_L.transpose(1, 2))) / (
                    img_size ** 2)
            H2 = (mat_K.mul(mat_L)).matmul((mat_K.mul(mat_L)).transpose(1, 2)) / img_size
            h2 = (mat_K.sum(2).view(batch_size, -1, 1)).mul(mat_L.sum(2).view(batch_size, -1, 1)) / (
                    img_size ** 2)
            H2 = 0.5 * H1 + 0.5 * H2

            # ridge term keeps the matrix invertible
            tmp = H2 + 0.05 * torch.eye(H2.size(1), device=device)
            alpha = tmp.inverse()
            alpha = alpha.matmul(h2)

            ersmi = (2 * (alpha.transpose(1, 2)).matmul(h2) - ((alpha.transpose(1, 2)).matmul(H2)).matmul(
                alpha) - 1).squeeze()
            ersmi = -ersmi.mean()
            return ersmi

        batch_loss = batch_ERSMI(fakeI, realI)
        return batch_loss
class MGCLoss(nn.Module):
    """
    Minimal Geometry-Distortion Constraint Loss from https://openreview.net/forum?id=R5M7Mxl1xZ

    :param beta: mixing weight between the two terms that make up ``h_hat``
    :param lambda_: weight of the identity matrix added before inversion
    """

    def __init__(self, beta=0.5, lambda_=0.05):
        super().__init__()
        self.beta = beta
        self.lambda_ = lambda_
        centres = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0])
        # mu_x cycles through the nine centres, mu_y repeats each centre nine times -> 81 pairings
        self.mu_x = centres.repeat(9)
        self.mu_y = centres.unsqueeze(0).t().repeat(1, 9).view(-1)

    @staticmethod
    def batch_rSMI(img1, img2, mu_x, mu_y, beta, lambda_):
        """
        Compute the rSMI estimate for every sample of two equally-shaped image batches.

        :return: tensor with one 1x1 matrix per sample, i.e. shape (batch_size, 1, 1)
        """
        assert img1.size() == img2.size()

        num_pixel = img1.size(1) * img1.size(2) * img2.size(3)

        mat_k = gaussian_radial_basis_function(img1, mu_x, sigma=1)
        mat_l = gaussian_radial_basis_function(img2, mu_y, sigma=1)

        joint = mat_k * mat_l
        joint_gram = joint.matmul(joint.transpose(1, 2))
        gram_k = mat_k.matmul(mat_k.transpose(1, 2))
        gram_l = mat_l.matmul(mat_l.transpose(1, 2))

        h_hat = (1 - beta) * joint_gram / num_pixel
        h_hat += beta * (gram_k * gram_l) / (num_pixel ** 2)
        small_h_hat = mat_k.sum(2, keepdim=True) * mat_l.sum(2, keepdim=True) / (num_pixel ** 2)

        identity = torch.eye(h_hat.size(1)).to(img1.device)
        alpha = (h_hat + lambda_ * identity).inverse().matmul(small_h_hat)
        alpha_t = alpha.transpose(1, 2)

        rSMI = (2 * alpha_t.matmul(small_h_hat)) - alpha_t.matmul(h_hat).matmul(alpha) - 1
        return rSMI

    def forward(self, fake, real):
        """Return the negated batch mean of the rSMI estimate between ``fake`` and ``real``."""
        per_sample = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_)
        return -per_sample.squeeze().mean()
if __name__ == '__main__':
    # Smoke test: print the loss of a random batch against itself, an affine copy of itself,
    # an unrelated random batch and an exponentiated unrelated batch.

    def norm(x):
        # rescale in place to [0, 1], then return a fresh tensor mapped to [-1, 1]
        x.sub_(x.min())
        x.div_(x.max())
        return (x - 0.5) * 2

    mg = MGCLoss().to("cuda")

    x1 = norm(torch.randn(5, 3, 256, 256))
    x2 = norm(x1 * 2 + 1)
    x3 = norm(torch.randn(5, 3, 256, 256))
    x4 = norm(torch.exp(x3))
    results = [mg(x1, other) for other in (x1, x2, x3, x4)]
    print(*results)

View File

@ -1,62 +0,0 @@
import torch.nn as nn
from model.normalization import select_norm_layer
from model.registry import MODEL
from .base import ResidualBlock
@MODEL.register_module("CyCle-Generator")
class Generator(nn.Module):
    """
    ResNet-style generator: 7x7 stem, two stride-2 down-sampling convolutions, ``num_blocks`` residual
    blocks, two transposed-convolution up-sampling stages and a 7x7 convolution followed by Tanh.

    :param in_channels: channels of the input image
    :param out_channels: channels of the generated image
    :param base_channels: width of the stem; doubled by every down-sampling stage
    :param num_blocks: number of residual blocks in the bottleneck, must be non-negative
    :param padding_mode: padding mode of the 7x7 convolutions and of the residual blocks
    :param norm_type: key handed to ``select_norm_layer``; convolutions get a bias only for "IN"
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect',
                 norm_type="IN"):
        super(Generator, self).__init__()
        assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
        norm_layer = select_norm_layer(norm_type)
        use_bias = norm_type == "IN"

        self.start_conv = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
                      bias=use_bias),
            norm_layer(num_features=base_channels),
            nn.ReLU(inplace=True)
        )

        # down sampling
        submodules = []
        num_down_sampling = 2
        for i in range(num_down_sampling):
            multiple = 2 ** i
            submodules += [
                nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
                          kernel_size=3, stride=2, padding=1, bias=use_bias),
                norm_layer(num_features=base_channels * multiple * 2),
                nn.ReLU(inplace=True)
            ]
        self.encoder = nn.Sequential(*submodules)

        # every down-sampling stage doubles the width, so the bottleneck is 2 ** n times the stem width
        # (the former ``n ** 2`` only matched because n == 2)
        res_block_channels = (2 ** num_down_sampling) * base_channels
        self.resnet_middle = nn.Sequential(
            *[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in
              range(num_blocks)])

        # up sampling
        submodules = []
        for i in range(num_down_sampling):
            multiple = 2 ** (num_down_sampling - i)
            submodules += [
                nn.ConvTranspose2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=3, stride=2,
                                   padding=1, output_padding=1, bias=use_bias),
                norm_layer(num_features=base_channels * multiple // 2),
                nn.ReLU(inplace=True),
            ]
        self.decoder = nn.Sequential(*submodules)

        self.end_conv = nn.Sequential(
            nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
            nn.Tanh()
        )

    def forward(self, x):
        """Run stem, encoder, residual bottleneck, decoder and output head in sequence."""
        x = self.encoder(self.start_conv(x))
        x = self.resnet_middle(x)
        return self.end_conv(self.decoder(x))

View File

@ -1,150 +0,0 @@
import torch
import torch.nn as nn
from model import MODEL
from model.GAN.base import Conv2dBlock, ResBlock
from model.normalization import select_norm_layer
class StyleEncoder(nn.Module):
    """
    Encode an image into a flat style vector of length ``out_dim``.

    A 7x7 stem is followed by ``num_conv`` stride-2 blocks whose width doubles until it reaches
    ``2 ** max_multiple`` times ``base_channels``, then global average pooling and a 1x1 convolution.
    """

    def __init__(self, in_channels, out_dim, num_conv, base_channels=64, use_spectral_norm=False,
                 max_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
        super(StyleEncoder, self).__init__()

        # options shared by every Conv2dBlock of this encoder
        shared = dict(padding_mode=padding_mode, use_spectral_norm=use_spectral_norm,
                      activation_type=activation_type, norm_type=norm_type)

        layers = [Conv2dBlock(in_channels, base_channels, kernel_size=7, stride=1, padding=3, **shared)]

        width = 1
        for depth in range(1, num_conv + 1):
            previous_width, width = width, min(2 ** depth, 2 ** max_multiple)
            layers.append(Conv2dBlock(previous_width * base_channels, width * base_channels,
                                      kernel_size=4, stride=2, padding=1, **shared))

        layers.append(nn.AdaptiveAvgPool2d(1))
        # conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code
        layers.append(nn.Conv2d(width * base_channels, out_dim, kernel_size=1, stride=1, padding=0))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the style code flattened to (batch_size, -1)."""
        code = self.model(x)
        return code.view(x.size(0), -1)
class ContentEncoder(nn.Module):
    """
    Encode an image into a content feature map.

    A 7x7 stem is followed by ``num_down_sampling`` stride-2 blocks, each doubling the width, and by
    ``num_res_blocks`` residual blocks at the final width.
    """

    def __init__(self, in_channels, num_down_sampling, num_res_blocks, base_channels=64, use_spectral_norm=False,
                 padding_mode='reflect', activation_type="ReLU", norm_type="IN"):
        super().__init__()

        # options shared by every Conv2dBlock of this encoder
        shared = dict(padding_mode=padding_mode, use_spectral_norm=use_spectral_norm,
                      activation_type=activation_type, norm_type=norm_type)

        layers = [Conv2dBlock(in_channels, base_channels, kernel_size=7, stride=1, padding=3, **shared)]

        for level in range(num_down_sampling):
            layers.append(Conv2dBlock(base_channels * (2 ** level), base_channels * (2 ** (level + 1)),
                                      kernel_size=4, stride=2, padding=1, **shared))

        for _ in range(num_res_blocks):
            layers.append(ResBlock(base_channels * (2 ** num_down_sampling), use_spectral_norm, padding_mode,
                                   norm_type, activation_type))

        self.sequence = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the whole encoder stack to ``x``."""
        return self.sequence(x)
class Decoder(nn.Module):
    """
    Decode a content feature map into an image.

    ``num_res_blocks`` residual blocks (kept in ``res_blocks``) are followed by ``num_up_sampling``
    stages of 2x up-sampling plus a 5x5 block that halves the width, and by a 7x7 Tanh output block.
    """

    def __init__(self, in_channels, out_channels, num_up_sampling, num_res_blocks,
                 use_spectral_norm=False, res_norm_type="AdaIN", norm_type="LN", activation_type="ReLU",
                 padding_mode='reflect'):
        super(Decoder, self).__init__()
        self.res_norm_type = res_norm_type

        res_blocks = []
        for _ in range(num_res_blocks):
            res_blocks.append(
                ResBlock(in_channels, use_spectral_norm, padding_mode, res_norm_type, activation_type=activation_type))
        self.res_blocks = nn.ModuleList(res_blocks)

        stages = []
        width = in_channels
        for _ in range(num_up_sampling):
            halved = width // 2
            stages.append(nn.Sequential(
                nn.Upsample(scale_factor=2),
                Conv2dBlock(width, halved,
                            kernel_size=5, stride=1, padding=2, padding_mode=padding_mode,
                            use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
                            ),
            ))
            width = halved

        # output block: always reflect padding, Tanh activation and no normalization
        stages.append(
            Conv2dBlock(width, out_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect",
                        use_spectral_norm=use_spectral_norm, activation_type="Tanh", norm_type="NONE"))
        self.sequence = nn.Sequential(*stages)

    def forward(self, x):
        """Run the residual blocks one by one, then the up-sampling stack."""
        hidden = x
        for block in self.res_blocks:
            hidden = block(hidden)
        return self.sequence(hidden)
class Fusion(nn.Module):
    """
    Fully connected stack: one input layer, ``n_blocks - 2`` hidden layers (each Linear + norm + ReLU)
    and one plain Linear output layer.
    """

    def __init__(self, in_features, out_features, base_features, n_blocks, norm_type="NONE"):
        super().__init__()
        norm_layer = select_norm_layer(norm_type)

        def fc_block(num_inputs):
            # Linear -> norm -> ReLU, always producing ``base_features`` outputs
            return nn.Sequential(
                nn.Linear(num_inputs, base_features),
                norm_layer(base_features),
                nn.ReLU(True),
            )

        self.start_fc = fc_block(in_features)
        hidden = [fc_block(base_features) for _ in range(n_blocks - 2)]
        self.fcs = nn.Sequential(*hidden)
        self.end_fc = nn.Sequential(
            nn.Linear(base_features, out_features),
        )

    def forward(self, x):
        """Apply input layer, hidden layers and output layer in order."""
        return self.end_fc(self.fcs(self.start_fc(x)))
@MODEL.register_module("MUNIT-Generator")
class Generator(nn.Module):
    """
    Generator made of a content encoder, a style encoder, a fusion network that turns the style code
    into normalization parameters, and a decoder whose residual blocks receive those parameters.

    :param in_channels: channels of the input image
    :param out_channels: channels of the generated image
    :param base_channels: stem width of both encoders
    :param num_sampling: number of down-sampling stages (content encoder) and up-sampling stages (decoder)
    :param num_style_dim: length of the style code
    :param num_style_conv: number of stride-2 blocks in the style encoder
    :param num_content_res_blocks: residual blocks in the content encoder
    :param num_decoder_res_blocks: residual blocks in the decoder
    :param num_fusion_dim: hidden width of the fusion network
    :param num_fusion_blocks: number of layers of the fusion network
    """

    def __init__(self, in_channels, out_channels, base_channels, num_sampling, num_style_dim, num_style_conv,
                 num_content_res_blocks, num_decoder_res_blocks, num_fusion_dim, num_fusion_blocks,
                 use_spectral_norm=False, activation_type="ReLU", padding_mode='reflect'):
        super().__init__()
        self.num_decoder_res_blocks = num_decoder_res_blocks
        self.content_encoder = ContentEncoder(in_channels, num_sampling, num_content_res_blocks, base_channels,
                                              use_spectral_norm, padding_mode, activation_type, norm_type="IN")
        # StyleEncoder has ``max_multiple`` between ``use_spectral_norm`` and ``padding_mode``, so the
        # trailing options must be passed by keyword; passed positionally they shift by one slot.
        self.style_encoder = StyleEncoder(in_channels, num_style_dim, num_style_conv, base_channels, use_spectral_norm,
                                          padding_mode=padding_mode, activation_type=activation_type,
                                          norm_type="NONE")
        # the content encoder doubles the width once per down-sampling stage
        content_channels = base_channels * (2 ** num_sampling)
        self.decoder = Decoder(content_channels, out_channels, num_sampling,
                               num_decoder_res_blocks, use_spectral_norm, "AdaIN", norm_type="LN",
                               activation_type=activation_type, padding_mode=padding_mode)
        # two normalization layers per decoder residual block, each fed ``2 * content_channels`` values
        self.fusion = Fusion(num_style_dim, num_decoder_res_blocks * 2 * content_channels * 2,
                             base_features=num_fusion_dim, n_blocks=num_fusion_blocks, norm_type="NONE")

    def encode(self, x):
        """Return ``(content, style)`` for image batch ``x``."""
        return self.content_encoder(x), self.style_encoder(x)

    def decode(self, content, style):
        """Load the style into the decoder's residual-block normalizations, then decode ``content``."""
        as_param_style = torch.chunk(self.fusion(style), self.num_decoder_res_blocks * 2, dim=1)
        # set style for decoder
        for i, blk in enumerate(self.decoder.res_blocks):
            blk.conv1.normalization.set_style(as_param_style[2 * i])
            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
        return self.decoder(content)

    def forward(self, x):
        """Reconstruct ``x`` from its own content and style codes."""
        content, style = self.encode(x)
        return self.decode(content, style)

View File

@ -1,171 +0,0 @@
import torch
import torch.nn as nn
from torchvision.models import vgg19
from model.normalization import select_norm_layer
from model.registry import MODEL
from .MUNIT import ContentEncoder, Fusion, Decoder, StyleEncoder
from .base import ResBlock
class VGG19StyleEncoder(nn.Module):
    """
    Style encoder that concatenates intermediate VGG19 activations to its own convolutional features.

    :param vgg19_layers: indices of the VGG19 feature layers whose outputs are collected
    :param fix_vgg19: when True the VGG19 weights get ``requires_grad=False``
    """

    def __init__(self, in_channels, base_channels=64, style_dim=512, padding_mode='reflect', norm_type="NONE",
                 vgg19_layers=(0, 5, 10, 19), fix_vgg19=True):
        super().__init__()
        self.vgg19_layers = vgg19_layers
        # keep only the layers up to (and including) the last tapped one
        last_layer = vgg19_layers[-1]
        self.vgg19 = vgg19(pretrained=True).features[:last_layer + 1]
        self.vgg19.requires_grad_(not fix_vgg19)

        norm_layer = select_norm_layer(norm_type)

        self.conv0 = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
                      bias=True),
            norm_layer(base_channels),
            nn.ReLU(True),
        )
        stages = []
        for i in range(1, 4):
            width = base_channels * (2 ** i)
            stages.append(nn.Sequential(
                nn.Conv2d(width, width, kernel_size=4, stride=2, padding=1,
                          padding_mode=padding_mode, bias=True),
                norm_layer(base_channels),
                nn.ReLU(True),
            ))
        self.conv = nn.ModuleList(stages)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv1x1 = nn.Conv2d(base_channels * (2 ** 4), style_dim, kernel_size=1, stride=1, padding=0)

    def fixed_style_features(self, x):
        """Run ``x`` through the VGG19 slice and return the outputs of the tapped layers, in order."""
        collected = []
        hidden = x
        for index in range(len(self.vgg19)):
            hidden = self.vgg19[index](hidden)
            if index in self.vgg19_layers:
                collected.append(hidden)
        return collected

    def forward(self, x):
        """Return the style code flattened to (batch_size, -1)."""
        vgg_features = self.fixed_style_features(x)
        hidden = self.conv0(x)
        for index, stage in enumerate(self.conv):
            # each stage consumes its input concatenated with the matching VGG19 activation
            hidden = stage(torch.cat([hidden, vgg_features[index]], dim=1))
        hidden = self.pool(torch.cat([hidden, vgg_features[-1]], dim=1))
        hidden = self.conv1x1(hidden)
        return hidden.view(hidden.size(0), -1)
@MODEL.register_module("TAFG-ResGenerator")
class ResGenerator(nn.Module):
    """
    Content encoder (two down-sampling stages plus ``num_res_blocks`` residual blocks) followed by a
    decoder with two up-sampling stages and no residual blocks.

    :param in_channels: channels of the input image
    :param out_channels: channels of the generated image
    :param use_spectral_norm: forwarded to encoder and decoder
    :param num_res_blocks: residual blocks in the content encoder
    :param base_channels: stem width of the content encoder
    """

    def __init__(self, in_channels, out_channels=3, use_spectral_norm=False, num_res_blocks=8, base_channels=64):
        super().__init__()
        # forward base_channels so the encoder output width matches ``resnet_channels`` below
        self.content_encoder = ContentEncoder(in_channels, 2, num_res_blocks=num_res_blocks,
                                              base_channels=base_channels,
                                              use_spectral_norm=use_spectral_norm)
        resnet_channels = 2 ** 2 * base_channels
        self.decoder = Decoder(resnet_channels, out_channels, 2,
                               0, use_spectral_norm, "IN", norm_type="LN", padding_mode="reflect")

    def forward(self, x):
        """Encode ``x`` and decode the result."""
        return self.decoder(self.content_encoder(x))
@MODEL.register_module("TAFG-SingleGenerator")
class SingleGenerator(nn.Module):
    """
    Generator that decodes the content of one image with the style of another.

    :param style_in_channels: channels of the style image
    :param content_in_channels: channels of the content image
    :param style_encoder_type: "StyleEncoder" or "VGG19StyleEncoder"
    :param num_adain_blocks: residual blocks in the decoder that receive the style
    :param num_res_blocks: residual blocks in the content encoder
    :raises NotImplementedError: for any other ``style_encoder_type``
    """

    def __init__(self, style_in_channels, content_in_channels, out_channels=3, use_spectral_norm=False,
                 style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
                 num_res_blocks=8, base_channels=64, padding_mode="reflect"):
        super().__init__()
        self.num_adain_blocks = num_adain_blocks
        if style_encoder_type == "StyleEncoder":
            self.style_encoder = StyleEncoder(
                style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
                max_multiple=4, padding_mode=padding_mode, norm_type="NONE"
            )
        elif style_encoder_type == "VGG19StyleEncoder":
            self.style_encoder = VGG19StyleEncoder(
                style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE"
            )
        else:
            # NotImplemented is a constant, not an exception; calling it raises TypeError
            raise NotImplementedError(f"do not support {style_encoder_type}")

        resnet_channels = 2 ** 2 * base_channels
        self.style_converter = Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256,
                                      n_blocks=3, norm_type="NONE")
        # forward base_channels so the encoder output width matches ``resnet_channels``
        self.content_encoder = ContentEncoder(content_in_channels, 2, num_res_blocks=num_res_blocks,
                                              base_channels=base_channels,
                                              use_spectral_norm=use_spectral_norm)

        self.decoder = Decoder(resnet_channels, out_channels, 2,
                               num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode)

    def forward(self, content_img, style_img):
        """Decode the content of ``content_img`` with the style extracted from ``style_img``."""
        content = self.content_encoder(content_img)
        style = self.style_encoder(style_img)
        as_param_style = torch.chunk(self.style_converter(style), self.num_adain_blocks * 2, dim=1)
        # set style for decoder
        for i, blk in enumerate(self.decoder.res_blocks):
            blk.conv1.normalization.set_style(as_param_style[2 * i])
            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
        return self.decoder(content)
@MODEL.register_module("TAFG-Generator")
class Generator(nn.Module):
    """
    Two-domain generator with per-domain ("a"/"b") style encoders, style converters, content encoders
    and decoders, plus one residual stack shared by both content encoders.

    :param style_in_channels: channels of the style image
    :param content_in_channels: channels of the content image of domain "a" (domain "b" uses 1)
    :param style_encoder_type: "StyleEncoder" or "VGG19StyleEncoder"
    :param num_adain_blocks: residual blocks in each decoder that receive the style
    :param num_res_blocks: residual blocks in the shared content stack
    :raises NotImplementedError: for any other ``style_encoder_type``
    """

    def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,
                 style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
                 num_res_blocks=8, base_channels=64, padding_mode="reflect"):
        super(Generator, self).__init__()
        self.num_adain_blocks = num_adain_blocks
        if style_encoder_type == "StyleEncoder":
            self.style_encoders = nn.ModuleDict(dict(
                a=StyleEncoder(style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
                               max_multiple=4, padding_mode=padding_mode, norm_type="NONE"),
                b=StyleEncoder(style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
                               max_multiple=4, padding_mode=padding_mode, norm_type="NONE"),
            ))
        elif style_encoder_type == "VGG19StyleEncoder":
            self.style_encoders = nn.ModuleDict(dict(
                a=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
                                    norm_type="NONE"),
                b=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
                                    norm_type="NONE", fix_vgg19=False)
            ))
        else:
            # NotImplemented is a constant, not an exception; calling it raises TypeError
            raise NotImplementedError(f"do not support {style_encoder_type}")
        resnet_channels = 2 ** 2 * base_channels
        self.style_converters = nn.ModuleDict(dict(
            a=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,
                     norm_type="NONE"),
            b=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,
                     norm_type="NONE"),
        ))
        # forward base_channels so the encoder output width matches ``resnet_channels``
        self.content_encoders = nn.ModuleDict({
            "a": ContentEncoder(content_in_channels, 2, num_res_blocks=0, base_channels=base_channels,
                                use_spectral_norm=use_spectral_norm),
            "b": ContentEncoder(1, 2, num_res_blocks=0, base_channels=base_channels,
                                use_spectral_norm=use_spectral_norm)
        })

        self.content_resnet = nn.Sequential(*[
            ResBlock(resnet_channels, use_spectral_norm, padding_mode, "IN")
            for _ in range(num_res_blocks)
        ])
        self.decoders = nn.ModuleDict(dict(
            a=Decoder(resnet_channels, out_channels, 2,
                      num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode),
            b=Decoder(resnet_channels, out_channels, 2,
                      num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode),
        ))

    def encode(self, content_img, style_img, which_content, which_style):
        """Return ``(content, style)`` using the encoders selected by the two domain keys."""
        content = self.content_resnet(self.content_encoders[which_content](content_img))
        style = self.style_encoders[which_style](style_img)
        return content, style

    def decode(self, content, style, which):
        """Load ``style`` into decoder ``which`` and decode ``content`` with it."""
        decoder = self.decoders[which]
        as_param_style = torch.chunk(self.style_converters[which](style), self.num_adain_blocks * 2, dim=1)
        # set style for decoder
        for i, blk in enumerate(decoder.res_blocks):
            blk.conv1.normalization.set_style(as_param_style[2 * i])
            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
        return decoder(content)

    def forward(self, content_img, style_img, which_content, which_style):
        """Encode both images and decode with the decoder of the style domain."""
        content, style = self.encode(content_img, style_img, which_content, which_style)
        return self.decode(content, style, which_style)

View File

@ -1,88 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import MODEL
from model.base.module import Conv2dBlock, ResidualBlock, ReverseResidualBlock
class Interpolation(nn.Module):
    """Module wrapper around ``F.interpolate`` with a fixed scale factor, mode and corner alignment."""

    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
        super(Interpolation, self).__init__()
        self.scale_factor, self.mode, self.align_corners = scale_factor, mode, align_corners

    def forward(self, x):
        """Resize ``x`` with the stored options; the scale factor is never recomputed."""
        options = dict(
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
            recompute_scale_factor=False,
        )
        return F.interpolate(x, **options)

    def __repr__(self):
        return "DownSampling(scale_factor={}, mode={}, align_corners={})".format(
            self.scale_factor, self.mode, self.align_corners)
@MODEL.register_module("TSIT-Generator")
class Generator(nn.Module):
    """
    Generator with a down-sampling content stream whose features condition a stack of up-sampling
    residual blocks; the stack starts from ``z`` (random noise when not given).
    """

    def __init__(self, content_in_channels=3, out_channels=3, base_channels=64, num_blocks=7,
                 padding_mode="reflect", activation_type="ReLU"):
        super().__init__()
        self.num_blocks = num_blocks
        self.base_channels = base_channels

        self.content_stream = self.build_stream(padding_mode, activation_type)
        self.start_conv = Conv2dBlock(content_in_channels, base_channels, activation_type=activation_type,
                                      norm_type="IN", kernel_size=7, padding_mode=padding_mode, padding=3)

        max_multiple = 2 ** 4
        up_blocks = []
        width = min(2 ** self.num_blocks, max_multiple)
        # exponents run from num_blocks - 1 down to 0
        for exponent in range(self.num_blocks - 1, -1, -1):
            previous_width, width = width, min(2 ** exponent, max_multiple)
            fade_options = dict(
                condition_in_channels=previous_width * base_channels,
                base_norm_type="BN",
                padding_mode=padding_mode
            )
            up_blocks.append(nn.Sequential(
                ReverseResidualBlock(
                    previous_width * base_channels, width * base_channels,
                    padding_mode=padding_mode, norm_type="FADE",
                    additional_norm_kwargs=fade_options
                ),
                Interpolation(2, mode="nearest")
            ))
        self.generator = nn.Sequential(*up_blocks)
        self.end_conv = Conv2dBlock(base_channels, out_channels, activation_type="Tanh",
                                    kernel_size=7, padding_mode=padding_mode, padding=3)

    def build_stream(self, padding_mode, activation_type):
        """Build ``num_blocks`` stages, each halving the resolution and then applying a residual block."""
        stages = []
        width = 1
        for exponent in range(1, self.num_blocks + 1):
            previous_width, width = width, min(2 ** exponent, 2 ** 4)
            stages.append(nn.Sequential(
                Interpolation(scale_factor=0.5, mode="nearest"),
                ResidualBlock(
                    previous_width * self.base_channels, width * self.base_channels,
                    padding_mode=padding_mode, activation_type=activation_type, norm_type="IN")
            ))
        return nn.ModuleList(stages)

    def forward(self, content, z=None):
        """
        :param content: content image batch
        :param z: optional start tensor; defaults to noise shaped like the deepest content feature
        :return: output of the final convolution block
        """
        hidden = self.start_conv(content)
        content_features = []
        for index in range(self.num_blocks):
            hidden = self.content_stream[index](hidden)
            content_features.append(hidden)

        if z is None:
            deepest = content_features[-1]
            z = torch.randn(size=deepest.size(), device=deepest.device)

        # the i-th up-sampling block is conditioned on the i-th deepest content feature
        for up_stage, feature in zip(self.generator, reversed(content_features)):
            res_block = up_stage[0]
            res_block.conv1.normalization.set_feature(feature)
            res_block.conv2.normalization.set_feature(feature)
            if res_block.learn_skip_connection:
                res_block.res_conv.normalization.set_feature(feature)
        return self.end_conv(self.generator(z))

View File

@ -1,236 +0,0 @@
import torch
import torch.nn as nn
from .base import ResidualBlock
from model.registry import MODEL
class RhoClipper(object):
    """Callable that clamps the ``rho`` attribute of a module into ``[clip_min, clip_max]``."""

    def __init__(self, clip_min, clip_max):
        self.clip_min = clip_min
        self.clip_max = clip_max
        assert clip_min < clip_max

    def __call__(self, module):
        # modules without a ``rho`` attribute are left untouched
        if not hasattr(module, 'rho'):
            return
        module.rho.data = module.rho.data.clamp(self.clip_min, self.clip_max)
@MODEL.register_module("UGATIT-Generator")
class Generator(nn.Module):
    """
    Generator with an encoder, a class-activation-map (CAM) branch, a small network predicting
    ``gamma``/``beta`` for the AdaILN residual blocks, and an up-sampling decoder.

    :param light: when True the gamma/beta network consumes globally pooled features instead of the
        flattened feature map
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False):
        assert (num_blocks >= 0)
        super(Generator, self).__init__()
        self.input_channels = in_channels
        self.output_channels = out_channels
        self.base_channels = base_channels
        self.num_blocks = num_blocks
        self.img_size = img_size
        self.light = light

        n_down_sampling = 2

        # stem followed by the stride-2 down-sampling convolutions
        encoder_layers = [
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3,
                      padding_mode="reflect", bias=False),
            nn.InstanceNorm2d(base_channels),
            nn.ReLU(True),
        ]
        for level in range(n_down_sampling):
            width = base_channels * 2 ** level
            encoder_layers.extend([
                nn.Conv2d(width, width * 2, kernel_size=3, stride=2,
                          padding=1, bias=False, padding_mode="reflect"),
                nn.InstanceNorm2d(width * 2),
                nn.ReLU(True),
            ])

        # Down-Sampling Bottleneck
        mult = 2 ** n_down_sampling
        feature_channels = base_channels * mult
        for _ in range(num_blocks):
            encoder_layers.append(ResidualBlock(feature_channels, use_bias=False))
        self.down_encoder = nn.Sequential(*encoder_layers)

        # Class Activation Map
        self.gap_fc = nn.Linear(feature_channels, 1, bias=False)
        self.gmp_fc = nn.Linear(feature_channels, 1, bias=False)
        self.conv1x1 = nn.Conv2d(feature_channels * 2, feature_channels, kernel_size=1, stride=1, bias=True)
        self.relu = nn.ReLU(True)

        # Gamma, Beta block: only the width of the first layer depends on ``light``
        if self.light:
            fc_in_features = feature_channels
        else:
            fc_in_features = img_size // mult * img_size // mult * feature_channels
        self.fc = nn.Sequential(
            nn.Linear(fc_in_features, feature_channels, bias=False),
            nn.ReLU(True),
            nn.Linear(feature_channels, feature_channels, bias=False),
            nn.ReLU(True),
        )

        self.gamma = nn.Linear(feature_channels, feature_channels, bias=False)
        self.beta = nn.Linear(feature_channels, feature_channels, bias=False)

        # Up-Sampling Bottleneck
        self.up_bottleneck = nn.ModuleList(
            [ResnetAdaILNBlock(feature_channels, use_bias=False) for _ in range(num_blocks)])

        # Up-Sampling
        decoder_layers = []
        for level in range(n_down_sampling):
            width = base_channels * 2 ** (n_down_sampling - level)
            decoder_layers.extend([
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(width, width // 2, kernel_size=3, stride=1,
                          padding=1, padding_mode="reflect", bias=False),
                ILN(width // 2),
                nn.ReLU(True),
            ])
        decoder_layers.extend([
            nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3,
                      padding_mode="reflect", bias=False),
            nn.Tanh(),
        ])
        self.up_decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """
        :return: tuple ``(image, cam_logit, heatmap)``; ``cam_logit`` concatenates the average-pool and
            max-pool logits, ``heatmap`` is the channel sum of the fused CAM features
        """
        features = self.down_encoder(x)
        batch_size = features.shape[0]

        pooled_avg = torch.nn.functional.adaptive_avg_pool2d(features, 1)
        gap_logit = self.gap_fc(pooled_avg.view(batch_size, -1))
        # re-weight the feature map with the weights of the pooled classifier
        gap = features * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)

        pooled_max = torch.nn.functional.adaptive_max_pool2d(features, 1)
        gmp_logit = self.gmp_fc(pooled_max.view(batch_size, -1))
        gmp = features * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        fused = self.relu(self.conv1x1(torch.cat([gap, gmp], 1)))

        heatmap = torch.sum(fused, dim=1, keepdim=True)

        if self.light:
            summary = torch.nn.functional.adaptive_avg_pool2d(fused, 1)
            summary = self.fc(summary.view(summary.shape[0], -1))
        else:
            summary = self.fc(fused.view(fused.shape[0], -1))
        gamma, beta = self.gamma(summary), self.beta(summary)

        out = fused
        for block in self.up_bottleneck:
            out = block(out, gamma, beta)

        out = self.up_decoder(out)
        return out, cam_logit, heatmap
class ResnetAdaILNBlock(nn.Module):
    """Residual block of two 3x3 reflect-padded convolutions, each followed by AdaILN."""

    def __init__(self, dim, use_bias):
        super(ResnetAdaILNBlock, self).__init__()
        conv_options = dict(kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect")
        self.conv1 = nn.Conv2d(dim, dim, **conv_options)
        self.norm1 = AdaILN(dim)
        self.relu1 = nn.ReLU(True)

        self.conv2 = nn.Conv2d(dim, dim, **conv_options)
        self.norm2 = AdaILN(dim)

    def forward(self, x, gamma, beta):
        """Apply conv-norm-relu-conv-norm with the given ``gamma``/``beta`` and add the input back."""
        residual = self.relu1(self.norm1(self.conv1(x), gamma, beta))
        residual = self.norm2(self.conv2(residual), gamma, beta)
        return residual + x
def instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
    """
    Blend instance normalization and layer normalization of ``x``, then scale and shift.

    :param x: tensor (batch_size, c, h, w)
    :param gamma: per-sample, per-channel scale, shape (batch_size, c)
    :param beta: per-sample, per-channel shift, shape (batch_size, c)
    :param rho: blend weight expandable to (batch_size, c, 1, 1); 1 selects instance norm, 0 layer norm
    :param eps: added to the variance before the square root
    :return: tensor with the shape of ``x``
    """

    def standardize(dims):
        # zero-mean / unit-variance over ``dims`` (torch.var default estimator)
        mean = torch.mean(x, dim=dims, keepdim=True)
        var = torch.var(x, dim=dims, keepdim=True)
        return (x - mean) / torch.sqrt(var + eps)

    out_in = standardize([2, 3])
    out_ln = standardize([1, 2, 3])
    blend = rho.expand(x.shape[0], -1, -1, -1)
    mixed = blend * out_in + (1 - blend) * out_ln
    scale = gamma.unsqueeze(2).unsqueeze(3)
    shift = beta.unsqueeze(2).unsqueeze(3)
    return mixed * scale + shift
class AdaILN(nn.Module):
    """
    Instance/layer normalization blend with a learnable per-channel ``rho`` and externally supplied
    ``gamma``/``beta``.
    """

    def __init__(self, num_features, eps=1e-5, default_rho=0.9):
        super(AdaILN, self).__init__()
        self.eps = eps
        # one blend weight per channel, every entry starting at ``default_rho``
        self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        nn.init.constant_(self.rho, default_rho)

    def forward(self, x, gamma, beta):
        """Normalize ``x`` with this layer's ``rho`` and the caller's ``gamma``/``beta``."""
        return instance_layer_normalization(x, gamma, beta, self.rho, self.eps)
class ILN(nn.Module):
    """Instance/layer normalization blend whose ``rho``, ``gamma`` and ``beta`` are all learnable."""

    def __init__(self, num_features, eps=1e-5):
        super(ILN, self).__init__()
        self.eps = eps
        self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        self.gamma = nn.Parameter(torch.Tensor(1, num_features))
        self.beta = nn.Parameter(torch.Tensor(1, num_features))
        # start as pure layer normalization (rho = 0) with identity scale and zero shift
        nn.init.constant_(self.rho, 0.0)
        nn.init.constant_(self.gamma, 1.0)
        nn.init.constant_(self.beta, 0.0)

    def forward(self, x):
        """Normalize ``x`` with the layer's own parameters, broadcast over the batch."""
        batch_size = x.shape[0]
        gamma = self.gamma.expand(batch_size, -1)
        beta = self.beta.expand(batch_size, -1)
        return instance_layer_normalization(x, gamma, beta, self.rho, self.eps)
@MODEL.register_module("UGATIT-Discriminator")
class Discriminator(nn.Module):
    """Spectral-normalized convolutional discriminator with a class-activation-map head.

    The encoder is one stem block, ``num_blocks - 3`` further stride-2 blocks
    and one stride-1 block. Two bias-free linear classifiers score the average-
    and max-pooled encoder output; their weights re-weight the feature map
    before a 1x1 fusion conv and the final single-channel conv.
    """

    def __init__(self, in_channels, base_channels=64, num_blocks=5):
        super().__init__()

        layers = [self.build_conv_block(in_channels, base_channels)]
        for depth in range(1, num_blocks - 2):
            width = base_channels * 2 ** (depth - 1)
            layers.append(self.build_conv_block(width, width * 2))
        width = base_channels * 2 ** (num_blocks - 2 - 1)
        layers.append(self.build_conv_block(width, width * 2, stride=1))
        self.encoder = nn.Sequential(*layers)

        # Class Activation Map
        cam_channels = base_channels * 2 ** (num_blocks - 2)
        self.gap_fc = nn.utils.spectral_norm(nn.Linear(cam_channels, 1, bias=False))
        self.gmp_fc = nn.utils.spectral_norm(nn.Linear(cam_channels, 1, bias=False))
        self.conv1x1 = nn.Conv2d(cam_channels * 2, cam_channels, kernel_size=1, stride=1, bias=True)
        self.leaky_relu = nn.LeakyReLU(0.2, True)

        self.conv = nn.utils.spectral_norm(
            nn.Conv2d(cam_channels, 1, kernel_size=4, stride=1, padding=1, bias=False, padding_mode="reflect"))

    @staticmethod
    def build_conv_block(in_channels, out_channels, kernel_size=4, stride=2, padding=1, padding_mode="reflect"):
        """Return a spectral-normalized ``Conv2d`` followed by ``LeakyReLU(0.2)``."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                         bias=True, padding=padding, padding_mode=padding_mode)
        return nn.Sequential(nn.utils.spectral_norm(conv), nn.LeakyReLU(0.2, True))

    def forward(self, x, return_heatmap=False):
        """Return ``(patch_logits, cam_logit)`` and, on request, a channel-summed heatmap."""
        features = self.encoder(x)
        rows = features.size(0)

        # Each fc is called before its ``.weight`` is read: spectral_norm
        # refreshes ``weight`` when the wrapped module's forward runs.
        avg_pooled = torch.nn.functional.adaptive_avg_pool2d(features, 1)  # B x C x 1 x 1
        avg_logit = self.gap_fc(avg_pooled.view(rows, -1))
        avg_weighted = features * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)

        max_pooled = torch.nn.functional.adaptive_max_pool2d(features, 1)
        max_logit = self.gmp_fc(max_pooled.view(rows, -1))
        max_weighted = features * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([avg_logit, max_logit], 1)
        fused = self.leaky_relu(self.conv1x1(torch.cat([avg_weighted, max_weighted], 1)))

        if not return_heatmap:
            return self.conv(fused), cam_logit
        heatmap = torch.sum(fused, dim=1, keepdim=True)
        return self.conv(fused), cam_logit, heatmap

View File

@ -1,3 +0,0 @@
from util.misc import import_submodules
__all__ = import_submodules(__name__).keys()

View File

@ -1,203 +0,0 @@
from functools import partial
import math
import torch
import torch.nn as nn
from model import MODEL
from model.normalization import select_norm_layer
class GANImageBuffer(object):
    """History pool of previously generated images.

    Querying the pool lets a discriminator be updated with a mix of fresh and
    older generator outputs instead of only the latest ones, to reduce model
    oscillation.

    Args:
        buffer_size (int): Capacity of the pool. With ``buffer_size = 0`` no
            pool is created and ``query`` returns its input untouched.
        buffer_ratio (float): Probability of swapping an incoming image for a
            stored one once the pool is full.
    """

    def __init__(self, buffer_size, buffer_ratio=0.5):
        self.buffer_size = buffer_size
        # The pool state only exists when pooling is enabled.
        if self.buffer_size > 0:
            self.img_num = 0
            self.image_buffer = []
        self.buffer_ratio = buffer_ratio

    def query(self, images):
        """Return a batch built from ``images`` and the stored history.

        Args:
            images (Tensor): Current image batch without history information.
        """
        if self.buffer_size == 0:
            # Pooling disabled: hand the batch straight back.
            return images

        picked = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)

            # While the pool still has room, store and return the new image.
            if self.img_num < self.buffer_size:
                self.img_num = self.img_num + 1
                self.image_buffer.append(image)
                picked.append(image)
                continue

            if torch.rand(1) < self.buffer_ratio:
                # Swap: return a random stored image and keep the new one.
                slot = torch.randint(0, self.buffer_size, (1,)).item()
                stored = self.image_buffer[slot].clone()
                self.image_buffer[slot] = image
                picked.append(stored)
            else:
                # Otherwise the incoming image is returned and not stored.
                picked.append(image)

        return torch.cat(picked, 0)
# Based on the SPADE / pix2pixHD discriminator.
@MODEL.register_module("PatchDiscriminator")
class PatchDiscriminator(nn.Module):
    """Stack of strided conv blocks ending in a single-channel conv.

    Args:
        in_channels: Channels of the input passed to the first conv.
        base_channels: Output channels of the first conv; later blocks use
            ``base_channels * min(2 ** i, 8)``.
        num_conv: Number of conv blocks before the final single-channel conv.
        use_spectral: Wrap the intermediate convs in spectral normalization.
        norm_type: Name handed to ``select_norm_layer``; conv bias is enabled
            only when it equals ``"IN"``.
        need_intermediate_feature: When true, ``forward`` returns the output of
            every stage as a tuple instead of only the last one.
    """

    def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN",
                 need_intermediate_feature=False):
        super().__init__()
        self.need_intermediate_feature = need_intermediate_feature

        kernel_size = 4
        padding = math.ceil((kernel_size - 1.0) / 2)
        norm_layer = select_norm_layer(norm_type)
        use_bias = norm_type == "IN"
        padding_mode = "zeros"

        sequence = [nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size, stride=2, padding=padding),
            nn.LeakyReLU(0.2, False)
        )]
        multiple_now = 1
        for i in range(1, num_conv):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 3)
            # Only the last block keeps the spatial resolution.
            stride = 1 if i == num_conv - 1 else 2
            sequence.append(nn.Sequential(
                self.build_conv2d(use_spectral, base_channels * multiple_prev, base_channels * multiple_now,
                                  kernel_size, stride, padding, bias=use_bias, padding_mode=padding_mode),
                norm_layer(base_channels * multiple_now),
                nn.LeakyReLU(0.2, inplace=False),
            ))
        # Feed the head from the width the last block actually produced.
        # Recomputing it as min(2 ** num_conv, 8) only agreed for num_conv >= 4
        # and produced a channel mismatch in forward() for smaller values.
        sequence.append(nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding,
                                  padding_mode=padding_mode))
        self.conv_blocks = nn.ModuleList(sequence)

    @staticmethod
    def build_conv2d(use_spectral, in_channels: int, out_channels: int, kernel_size, stride, padding,
                     bias=True, padding_mode: str = 'zeros'):
        """Build a ``Conv2d``, wrapped in spectral normalization when ``use_spectral`` is true."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias, padding_mode=padding_mode)
        if not use_spectral:
            return conv
        return nn.utils.spectral_norm(conv)

    def forward(self, x):
        """Run all stages; return the last output, or every stage output as a tuple."""
        if self.need_intermediate_feature:
            intermediate_feature = []
            for layer in self.conv_blocks:
                x = layer(x)
                intermediate_feature.append(x)
            return tuple(intermediate_feature)
        for layer in self.conv_blocks:
            x = layer(x)
        return x
@MODEL.register_module()
class ResidualBlock(nn.Module):
    """Two 3x3 conv + norm stages with a ReLU in between and a skip connection."""

    def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None):
        super().__init__()
        if use_bias is None:
            # Bias defaults to on only for "IN" (author's note: it has no affine parameters).
            use_bias = norm_type == "IN"
        make_norm = select_norm_layer(norm_type)
        conv_options = dict(kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias)

        self.conv1 = nn.Conv2d(num_channels, num_channels, **conv_options)
        self.norm1 = make_norm(num_channels)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(num_channels, num_channels, **conv_options)
        self.norm2 = make_norm(num_channels)

    def forward(self, x):
        """Return the transformed features plus the unchanged input."""
        hidden = self.conv1(x)
        hidden = self.norm1(hidden)
        hidden = self.relu1(hidden)
        hidden = self.conv2(hidden)
        hidden = self.norm2(hidden)
        return hidden + x
# Identity function returned for the "NONE" activation.
_DO_NO_THING_FUNC = lambda x: x


def select_activation(t):
    """Map an activation name to a constructor for that activation.

    Args:
        t: One of ``"ReLU"``, ``"LeakyReLU"``, ``"Tanh"`` or ``"NONE"``.

    Returns:
        A ``functools.partial`` that builds the matching ``nn`` module when
        called without arguments (ReLU and LeakyReLU are in-place, LeakyReLU
        uses a negative slope of 0.2). For ``"NONE"`` the identity function
        ``_DO_NO_THING_FUNC`` is returned instead.

    Raises:
        NotImplementedError: If ``t`` is not a supported name.
    """
    if t == "ReLU":
        return partial(nn.ReLU, inplace=True)
    if t == "LeakyReLU":
        return partial(nn.LeakyReLU, negative_slope=0.2, inplace=True)
    if t == "Tanh":
        return partial(nn.Tanh)
    if t == "NONE":
        return _DO_NO_THING_FUNC
    # ``NotImplemented`` is a comparison sentinel, not an exception; raising it
    # fails with a TypeError instead of reporting the unsupported name.
    raise NotImplementedError(f"{t} not valid")
def _use_bias_checker(norm_type):
return norm_type not in ["IN", "BN", "AdaIN"]
class Conv2dBlock(nn.Module):
    """``Conv2d`` optionally followed by a normalization layer and an activation.

    Extra keyword arguments are forwarded to ``nn.Conv2d``. When ``bias`` is
    ``None`` it is derived from ``norm_type`` via ``_use_bias_checker``.
    A ``norm_type`` / ``activation_type`` of ``"NONE"`` skips that stage.
    """

    def __init__(self, in_channels: int, out_channels: int, use_spectral_norm=False, activation_type="ReLU",
                 bias=None, norm_type="NONE", **conv_kwargs):
        super().__init__()
        self.norm_type = norm_type
        self.activation_type = activation_type

        if bias is None:
            bias = _use_bias_checker(norm_type)
        conv = nn.Conv2d(in_channels, out_channels, bias=bias, **conv_kwargs)
        if use_spectral_norm:
            conv = nn.utils.spectral_norm(conv)
        self.convolution = conv

        # The optional stages are only created (and registered) when requested.
        if norm_type != "NONE":
            self.normalization = select_norm_layer(norm_type)(out_channels)
        if activation_type != "NONE":
            self.activation = select_activation(activation_type)()

    def forward(self, x):
        """Apply convolution, then normalization and activation when configured."""
        out = self.convolution(x)
        if self.norm_type != "NONE":
            out = self.normalization(out)
        if self.activation_type != "NONE":
            out = self.activation(out)
        return out
class ResBlock(nn.Module):
    """Residual pair of 3x3 ``Conv2dBlock``s; the second one has no activation."""

    def __init__(self, num_channels, use_spectral_norm=False, padding_mode='reflect',
                 norm_type="IN", activation_type="ReLU", use_bias=None):
        super().__init__()
        self.norm_type = norm_type

        if use_bias is None:
            # Author's note: bias is cancelled by channel-wise normalization.
            use_bias = _use_bias_checker(norm_type)

        shared = dict(kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias, norm_type=norm_type)
        self.conv1 = Conv2dBlock(num_channels, num_channels, use_spectral_norm,
                                 activation_type=activation_type, **shared)
        self.conv2 = Conv2dBlock(num_channels, num_channels, use_spectral_norm,
                                 activation_type="NONE", **shared)

    def forward(self, x):
        """Return ``conv2(conv1(x)) + x``."""
        transformed = self.conv1(x)
        transformed = self.conv2(transformed)
        return transformed + x

View File

@ -1,25 +0,0 @@
import torch.nn as nn
import torch.nn.functional as F
from model import MODEL
@MODEL.register_module()
class MultiScaleDiscriminator(nn.Module):
    """Runs ``num_scale`` discriminators on successively down-sampled inputs.

    Every discriminator is built from the same ``discriminator_cfg`` through
    ``MODEL.build_with``.
    """

    def __init__(self, num_scale, discriminator_cfg):
        super().__init__()
        members = [MODEL.build_with(discriminator_cfg) for _ in range(num_scale)]
        self.discriminator_list = nn.ModuleList(members)

    @staticmethod
    def down_sample(x):
        """Average-pool ``x`` with a 3x3 window and stride 2, ignoring padded cells."""
        return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)

    def forward(self, x):
        """Return one result per discriminator, in list order."""
        outputs = []
        current = x
        for discriminator in self.discriminator_list:
            outputs.append(discriminator(current))
            # The next discriminator sees the pooled version of this input.
            current = self.down_sample(current)
        return outputs

View File

@ -1,4 +1,3 @@
from model.registry import MODEL, NORMALIZATION
import model.GAN
import model.base.normalization
import model.image_translation

View File

@ -30,7 +30,24 @@ def _activation(activation):
elif activation == "Tanh":
return nn.Tanh()
else:
raise NotImplemented(activation)
raise NotImplementedError(f"{activation} not valid")
class LinearBlock(nn.Module):
    """``nn.Linear`` followed by the configured normalization and activation.

    When ``bias`` is ``None`` it is derived from ``norm_type`` through
    ``_use_bias_checker``. The normalization and activation modules come from
    ``_normalization`` and ``_activation``.
    """

    def __init__(self, in_features: int, out_features: int, bias=None, activation_type="ReLU", norm_type="NONE"):
        super().__init__()
        self.norm_type = norm_type
        self.activation_type = activation_type

        if bias is None:
            bias = _use_bias_checker(norm_type)
        self.linear = nn.Linear(in_features, out_features, bias)
        self.normalization = _normalization(norm_type, out_features)
        self.activation = _activation(activation_type)

    def forward(self, x):
        """Apply linear -> normalization -> activation."""
        projected = self.linear(x)
        normalized = self.normalization(projected)
        return self.activation(normalized)
class Conv2dBlock(nn.Module):

View File

@ -93,7 +93,7 @@ class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
def _instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
out = rho * F.instance_norm(x, eps=eps) + (1 - rho) * F.layer_norm(x, x.size()[1:], eps=eps)
out = out * gamma + beta
out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
return out
@ -115,7 +115,7 @@ class ILN(nn.Module):
def forward(self, x):
return _instance_layer_normalization(
x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps)
x, self.gamma.view(1, -1), self.beta.view(1, -1), self.rho.view(1, -1, 1, 1), self.eps)
@NORMALIZATION.register_module("AdaILN")
@ -136,7 +136,6 @@ class AdaILN(nn.Module):
def forward(self, x):
assert self.have_set_condition
out = _instance_layer_normalization(
x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps)
out = _instance_layer_normalization(x, self.gamma, self.beta, self.rho.view(1, -1, 1, 1), self.eps)
self.have_set_condition = False
return out

View File

View File

View File

@ -0,0 +1,149 @@
import torch
import torch.nn as nn
from model import MODEL
from model.base.module import Conv2dBlock, ResidualBlock, LinearBlock
def _get_down_sampling_sequence(in_channels, base_channels, num_conv, max_down_sampling_multiple=2,
                                padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
    """Build a 7x7 stem followed by ``num_conv`` stride-2 ``Conv2dBlock``s.

    The channel multiplier doubles with each stride-2 block and is capped at
    ``2 ** max_down_sampling_multiple``.

    Returns:
        ``(blocks, out_channels)``: the list of blocks and the channel count
        produced by the last one.
    """
    shared = dict(padding_mode=padding_mode, activation_type=activation_type, norm_type=norm_type)
    blocks = [Conv2dBlock(in_channels, base_channels, kernel_size=7, stride=1, padding=3, **shared)]

    cap = 2 ** max_down_sampling_multiple
    multiple_now = 1
    for step in range(1, num_conv + 1):
        multiple_prev, multiple_now = multiple_now, min(2 ** step, cap)
        blocks.append(Conv2dBlock(
            multiple_prev * base_channels, multiple_now * base_channels,
            kernel_size=4, stride=2, padding=1, **shared
        ))
    return blocks, multiple_now * base_channels
class StyleEncoder(nn.Module):
    """Down-sampling conv stack, global average pooling and a 1x1 conv to ``out_dim``."""

    def __init__(self, in_channels, out_dim, num_conv, base_channels=64,
                 max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
        super().__init__()
        layers, last_channels = _get_down_sampling_sequence(
            in_channels, base_channels, num_conv,
            max_down_sampling_multiple, padding_mode, activation_type, norm_type
        )
        layers.append(nn.AdaptiveAvgPool2d(1))
        # After pooling to 1x1 a 1x1 conv acts as a fully connected layer;
        # per the author's note this mirrors the original implementation.
        layers.append(nn.Conv2d(last_channels, out_dim, kernel_size=1, stride=1, padding=0))
        self.sequence = nn.Sequential(*layers)

    def forward(self, image):
        """Return the encoded input flattened to ``(image.size(0), -1)``."""
        encoded = self.sequence(image)
        return encoded.view(image.size(0), -1)
class ContentEncoder(nn.Module):
    """Down-sampling conv stack followed by ``num_residual_blocks`` residual blocks."""

    def __init__(self, in_channels, num_down_sampling, num_residual_blocks, base_channels=64,
                 max_down_sampling_multiple=2,
                 padding_mode='reflect', activation_type="ReLU", norm_type="IN"):
        super().__init__()
        layers, last_channels = _get_down_sampling_sequence(
            in_channels, base_channels, num_down_sampling,
            max_down_sampling_multiple, padding_mode, activation_type, norm_type
        )
        # Residual blocks keep the channel count produced by the last down-sampling block.
        layers.extend(
            ResidualBlock(last_channels, last_channels, padding_mode, activation_type, norm_type)
            for _ in range(num_residual_blocks)
        )
        self.sequence = nn.Sequential(*layers)

    def forward(self, image):
        """Return the feature map produced by the full stack."""
        return self.sequence(image)
class Decoder(nn.Module):
    """Style-conditioned residual blocks followed by an up-sampling conv stack.

    Each up-sampling stage doubles the spatial size (``nn.Upsample``) and halves
    the channel count; a final 7x7 ``Conv2dBlock`` with Tanh maps to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
                 res_norm_type="AdaIN", norm_type="LN", activation_type="ReLU", padding_mode='reflect'):
        super().__init__()
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(in_channels, in_channels, padding_mode, activation_type, norm_type=res_norm_type)
            for _ in range(num_residual_blocks)
        ])

        stages = []
        width = in_channels
        for _ in range(num_up_sampling):
            half = width // 2
            stages.append(nn.Sequential(
                nn.Upsample(scale_factor=2),
                Conv2dBlock(width, half,
                            kernel_size=5, stride=1, padding=2, padding_mode=padding_mode,
                            activation_type=activation_type, norm_type=norm_type),
            ))
            width = half
        stages.append(Conv2dBlock(width, out_channels,
                                  kernel_size=7, stride=1, padding=3, padding_mode="reflect",
                                  activation_type="Tanh", norm_type="NONE"))
        self.up_sequence = nn.Sequential(*stages)

    def forward(self, x, style):
        """Decode ``x``, conditioning every residual block on a slice of ``style``.

        ``style`` is split along dim 1 into two chunks per residual block; the
        chunks are handed to the blocks' normalization layers via ``set_style``
        right before each block runs.
        """
        style_chunks = torch.chunk(style, 2 * len(self.residual_blocks), dim=1)
        for index, block in enumerate(self.residual_blocks):
            block.conv1.normalization.set_style(style_chunks[2 * index])
            block.conv2.normalization.set_style(style_chunks[2 * index + 1])
            x = block(x)
        return self.up_sequence(x)
class MLPFusion(nn.Module):
    """Stack of ``LinearBlock``s mapping ``in_features`` to ``out_features``.

    The stack is an input block, ``n_blocks - 2`` hidden blocks of width
    ``base_features`` and an output block; all use the same activation and
    normalization settings.
    """

    def __init__(self, in_features, out_features, base_features, n_blocks, activation_type="ReLU", norm_type="NONE"):
        super().__init__()
        shared = dict(activation_type=activation_type, norm_type=norm_type)
        layers = [LinearBlock(in_features, base_features, **shared)]
        layers.extend(LinearBlock(base_features, base_features, **shared) for _ in range(n_blocks - 2))
        layers.append(LinearBlock(base_features, out_features, **shared))
        self.sequence = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` passed through every block in order."""
        return self.sequence(x)
@MODEL.register_module("MUNIT-Generator")
class Generator(nn.Module):
    """Content encoder + style encoder + MLP fusion + style-conditioned decoder.

    ``encode`` splits an input into a content feature map and a style code;
    ``decode`` maps the style code through the fusion MLP and feeds it, with
    the content, to the decoder. ``forward`` is ``decode(*encode(x))``.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, style_dim=8,
                 num_mlp_base_feature=256, num_mlp_blocks=3,
                 max_down_sampling_multiple=2, num_content_down_sampling=2, num_style_down_sampling=2,
                 encoder_num_residual_blocks=4, decoder_num_residual_blocks=4,
                 padding_mode='reflect', activation_type="ReLU"):
        super().__init__()
        self.content_encoder = ContentEncoder(
            in_channels, num_content_down_sampling, encoder_num_residual_blocks,
            base_channels, max_down_sampling_multiple,
            padding_mode, activation_type, norm_type="IN")

        self.style_encoder = StyleEncoder(in_channels, style_dim, num_style_down_sampling, base_channels,
                                          max_down_sampling_multiple, padding_mode, activation_type,
                                          norm_type="NONE")

        # NOTE(review): this equals the content encoder's output width only when
        # num_content_down_sampling >= max_down_sampling_multiple — confirm for
        # non-default configurations.
        content_channels = base_channels * (2 ** max_down_sampling_multiple)

        # The fusion output is chunked by the decoder into two pieces per residual block.
        self.fusion = MLPFusion(style_dim, decoder_num_residual_blocks * 2 * content_channels * 2,
                                num_mlp_base_feature, num_mlp_blocks, activation_type,
                                norm_type="NONE")

        self.decoder = Decoder(content_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks,
                               res_norm_type="AdaIN", norm_type="LN", activation_type=activation_type,
                               padding_mode=padding_mode)

    def encode(self, x):
        """Return ``(content, style)`` produced by the two encoders for ``x``."""
        return self.content_encoder(x), self.style_encoder(x)

    def decode(self, content, style):
        """Return the decoder output for ``content`` conditioned on ``style``."""
        # The decoder result was previously computed and dropped, so both
        # decode() and forward() returned None.
        return self.decoder(content, self.fusion(style))

    def forward(self, x):
        """Encode ``x`` and decode it again with its own style code."""
        content, style = self.encode(x)
        return self.decode(content, style)

View File

View File

@ -0,0 +1,166 @@
import torch
import torch.nn as nn
from model import MODEL
from model.base.module import Conv2dBlock, ResidualBlock, LinearBlock
class RhoClipper(object):
    """Callable that clamps a module's ``rho`` values into ``[clip_min, clip_max]``.

    Modules without a ``rho`` attribute are left untouched.
    """

    def __init__(self, clip_min, clip_max):
        self.clip_min = clip_min
        self.clip_max = clip_max
        assert clip_min < clip_max

    def __call__(self, module):
        if not hasattr(module, 'rho'):
            return
        # Rebind ``.data`` to the clamped tensor rather than clamping in place.
        module.rho.data = module.rho.data.clamp(self.clip_min, self.clip_max)
class CAMClassifier(nn.Module):
    """Class-activation-map head shared by the generator and discriminator.

    Two bias-free linear classifiers score the average- and max-pooled input;
    their weights re-weight the input channels, and the two re-weighted maps
    are concatenated and fused back to ``in_channels`` by a 1x1 ``Conv2dBlock``.
    """

    def __init__(self, in_channels, activation_type="ReLU"):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.avg_fc = nn.Linear(in_channels, 1, bias=False)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.max_fc = nn.Linear(in_channels, 1, bias=False)
        self.fusion_conv = Conv2dBlock(in_channels * 2, in_channels, activation_type=activation_type,
                                       norm_type="NONE", kernel_size=1, stride=1, bias=True)

    def forward(self, x):
        """Return ``(fused_features, cam_logit)`` for ``x``."""
        rows = x.size(0)
        # Both classifiers run before their weights are read, as in the
        # original ordering.
        avg_logit = self.avg_fc(self.avg_pool(x).view(rows, -1))
        max_logit = self.max_fc(self.max_pool(x).view(rows, -1))

        avg_weighted = x * self.avg_fc.weight.unsqueeze(2).unsqueeze(3)
        max_weighted = x * self.max_fc.weight.unsqueeze(2).unsqueeze(3)
        fused = self.fusion_conv(torch.cat([avg_weighted, max_weighted], dim=1))
        return fused, torch.cat([avg_logit, max_logit], 1)
@MODEL.register_module("UGATIT-Generator")
class Generator(nn.Module):
    """Encoder, CAM head, gamma/beta MLP, AdaILN bottleneck and up-sampling decoder.

    ``forward`` returns ``(image, cam_logit, heatmap)``, where ``heatmap`` is
    the channel-wise sum of the CAM-fused features.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False,
                 activation_type="ReLU", norm_type="IN", padding_mode='reflect'):
        super().__init__()
        self.light = light

        n_down_sampling = 2
        encoder_layers = [Conv2dBlock(
            in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
            activation_type=activation_type, norm_type=norm_type
        )]
        for level in range(n_down_sampling):
            width = base_channels * 2 ** level
            encoder_layers.append(Conv2dBlock(
                width, width * 2,
                kernel_size=3, stride=2, padding=1, padding_mode=padding_mode,
                activation_type=activation_type, norm_type=norm_type
            ))

        mult = 2 ** n_down_sampling
        bottleneck_channels = base_channels * mult
        encoder_layers.extend(
            ResidualBlock(bottleneck_channels, bottleneck_channels, padding_mode,
                          activation_type=activation_type, norm_type=norm_type)
            for _ in range(num_blocks)
        )
        self.encoder = nn.Sequential(*encoder_layers)

        self.cam = CAMClassifier(bottleneck_channels, activation_type)

        # Gamma, Beta block: the light variant consumes globally pooled
        # features, the full variant the flattened feature map.
        if self.light:
            fc_in_features = bottleneck_channels
        else:
            fc_in_features = img_size // mult * img_size // mult * base_channels * mult
        self.fc = nn.Sequential(
            LinearBlock(fc_in_features, bottleneck_channels, False, "ReLU", "NONE"),
            LinearBlock(bottleneck_channels, bottleneck_channels, False, "ReLU", "NONE")
        )

        self.gamma = nn.Linear(bottleneck_channels, bottleneck_channels, bias=False)
        self.beta = nn.Linear(bottleneck_channels, bottleneck_channels, bias=False)

        # Up-Sampling Bottleneck
        self.up_bottleneck = nn.ModuleList([
            ResidualBlock(bottleneck_channels, bottleneck_channels, padding_mode,
                          activation_type, norm_type="AdaILN")
            for _ in range(num_blocks)
        ])

        decoder_layers = []
        channels = bottleneck_channels
        for _ in range(n_down_sampling):
            halved = channels // 2
            decoder_layers.append(nn.Sequential(
                nn.Upsample(scale_factor=2),
                Conv2dBlock(channels, halved,
                            kernel_size=3, stride=1, padding=1, bias=False, padding_mode=padding_mode,
                            activation_type=activation_type, norm_type="ILN"),
            ))
            channels = halved
        decoder_layers.append(Conv2dBlock(channels, out_channels,
                                          kernel_size=7, stride=1, padding=3, padding_mode="reflect",
                                          activation_type="Tanh", norm_type="NONE"))
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Translate ``x`` and return ``(output, cam_logit, heatmap)``."""
        features, cam_logit = self.cam(self.encoder(x))
        heatmap = torch.sum(features, dim=1, keepdim=True)

        if self.light:
            pooled = torch.nn.functional.adaptive_avg_pool2d(features, (1, 1))
            flat = pooled.view(pooled.shape[0], -1)
        else:
            flat = features.view(features.shape[0], -1)
        condition = self.fc(flat)
        gamma, beta = self.gamma(condition), self.beta(condition)

        # Every bottleneck block receives the same gamma/beta right before it runs.
        for blk in self.up_bottleneck:
            blk.conv1.normalization.set_condition(gamma, beta)
            blk.conv2.normalization.set_condition(gamma, beta)
            features = blk(features)
        return self.decoder(features), cam_logit, heatmap
@MODEL.register_module("UGATIT-Discriminator")
class Discriminator(nn.Module):
    """Conv encoder, CAM head and a final single-channel conv.

    The encoder is one stem block, ``num_blocks - 3`` further stride-2 blocks
    and one stride-1 block, all 4x4 ``Conv2dBlock``s.
    """

    def __init__(self, in_channels, base_channels=64, num_blocks=5,
                 activation_type="LeakyReLU", norm_type="NONE", padding_mode='reflect'):
        super().__init__()
        shared = dict(kernel_size=4, padding=1, padding_mode=padding_mode,
                      activation_type=activation_type, norm_type=norm_type)

        layers = [Conv2dBlock(in_channels, base_channels, stride=2, **shared)]
        for level in range(num_blocks - 3):
            width = base_channels * (2 ** level)
            layers.append(Conv2dBlock(width, width * 2, stride=2, **shared))
        # The last encoder block keeps the spatial resolution (stride 1).
        layers.append(Conv2dBlock(base_channels * (2 ** (num_blocks - 3)),
                                  base_channels * (2 ** (num_blocks - 2)),
                                  stride=1, **shared))
        self.sequence = nn.Sequential(*layers)

        cam_channels = base_channels * 2 ** (num_blocks - 2)
        self.cam = CAMClassifier(cam_channels, activation_type)
        self.conv = nn.Conv2d(cam_channels, 1, kernel_size=4, stride=1, padding=1, bias=False,
                              padding_mode="reflect")

    def forward(self, x, return_heatmap=False):
        """Return ``(patch_logits, cam_logit)`` and, on request, a channel-summed heatmap."""
        fused, cam_logit = self.cam(self.sequence(x))
        if not return_heatmap:
            return self.conv(fused), cam_logit
        heatmap = torch.sum(fused, dim=1, keepdim=True)
        return self.conv(fused), cam_logit, heatmap

View File

View File

@ -8,7 +8,7 @@ import torch.nn as nn
def add_spectral_norm(module):
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
return nn.utils.spectral_norm(module)
else:
return module