Compare commits

...

5 Commits

Author SHA1 Message Date
0019d4034c change a lot 2020-10-14 18:55:51 +08:00
0927fa3de5 add patch d 2020-10-13 10:31:17 +08:00
611901cbdf add ConvTranspose2d in Conv2d 2020-10-13 10:31:03 +08:00
a6ffab1445 add image buffers for gan 2020-10-13 10:30:27 +08:00
7b05b45156 update SPADE 2020-10-12 19:01:07 +08:00
14 changed files with 503 additions and 214 deletions

View File

@ -1,34 +1,38 @@
name: horse2zebra-CyCleGAN name: selfie2anime-cycleGAN
engine: CyCleGAN engine: CycleGAN
result_dir: ./result result_dir: ./result
max_pairs: 266800 max_pairs: 1000000
misc: misc:
random_seed: 324 random_seed: 324
handler: handler:
clear_cuda_cache: False clear_cuda_cache: True
set_epoch_for_dist_sampler: True set_epoch_for_dist_sampler: True
checkpoint: checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2 n_saved: 2
tensorboard: tensorboard:
scalar: 100 # log scalar `scalar` times per epoch scalar: 100 # log scalar `scalar` times per epoch
image: 2 # log image `image` times per epoch image: 4 # log image `image` times per epoch
test:
random: True
images: 10
model: model:
generator: generator:
_type: CyCle-Generator _type: CycleGAN-Generator
_add_spectral_norm: True
in_channels: 3 in_channels: 3
out_channels: 3 out_channels: 3
base_channels: 64 base_channels: 64
num_blocks: 9 num_blocks: 9
padding_mode: reflect
norm_type: IN
discriminator: discriminator:
_type: PatchDiscriminator _type: PatchDiscriminator
_add_spectral_norm: True
in_channels: 3 in_channels: 3
base_channels: 64 base_channels: 64
num_conv: 4
loss: loss:
gan: gan:
@ -41,17 +45,21 @@ loss:
weight: 10.0 weight: 10.0
id: id:
level: 1 level: 1
weight: 0 weight: 10.0
mgc:
weight: 5
optimizers: optimizers:
generator: generator:
_type: Adam _type: Adam
lr: 2e-4 lr: 0.0001
betas: [ 0.5, 0.999 ] betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
discriminator: discriminator:
_type: Adam _type: Adam
lr: 2e-4 lr: 1e-4
betas: [ 0.5, 0.999 ] betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
data: data:
train: train:
@ -60,15 +68,15 @@ data:
target_lr: 0 target_lr: 0
buffer_size: 50 buffer_size: 50
dataloader: dataloader:
batch_size: 6 batch_size: 1
shuffle: True shuffle: True
num_workers: 2 num_workers: 2
pin_memory: True pin_memory: True
drop_last: True drop_last: True
dataset: dataset:
_type: GenerationUnpairedDataset _type: GenerationUnpairedDataset
root_a: "/data/i2i/horse2zebra/trainA" root_a: "/data/i2i/selfie2anime/trainA"
root_b: "/data/i2i/horse2zebra/trainB" root_b: "/data/i2i/selfie2anime/trainB"
random_pair: True random_pair: True
pipeline: pipeline:
- Load - Load
@ -82,16 +90,17 @@ data:
mean: [ 0.5, 0.5, 0.5 ] mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ] std: [ 0.5, 0.5, 0.5 ]
test: test:
which: video_dataset
dataloader: dataloader:
batch_size: 4 batch_size: 1
shuffle: False shuffle: False
num_workers: 1 num_workers: 1
pin_memory: False pin_memory: False
drop_last: False drop_last: False
dataset: dataset:
_type: GenerationUnpairedDataset _type: GenerationUnpairedDataset
root_a: "/data/i2i/horse2zebra/testA" root_a: "/data/i2i/selfie2anime/testA"
root_b: "/data/i2i/horse2zebra/testB" root_b: "/data/i2i/selfie2anime/testB"
random_pair: False random_pair: False
pipeline: pipeline:
- Load - Load
@ -101,3 +110,15 @@ data:
- Normalize: - Normalize:
mean: [ 0.5, 0.5, 0.5 ] mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ] std: [ 0.5, 0.5, 0.5 ]
video_dataset:
_type: SingleFolderDataset
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
with_path: True
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]

View File

@ -78,7 +78,7 @@ data:
target_lr: 0 target_lr: 0
buffer_size: 50 buffer_size: 50
dataloader: dataloader:
batch_size: 4 batch_size: 1
shuffle: True shuffle: True
num_workers: 2 num_workers: 2
pin_memory: True pin_memory: True
@ -102,7 +102,7 @@ data:
test: test:
which: video_dataset which: video_dataset
dataloader: dataloader:
batch_size: 8 batch_size: 1
shuffle: False shuffle: False
num_workers: 1 num_workers: 1
pin_memory: False pin_memory: False

View File

@ -1,26 +1,23 @@
from itertools import chain from itertools import chain
import ignite.distributed as idist
import torch import torch
import torch.nn as nn
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model from engine.util.build import build_model
from loss.gan import GANLoss from engine.util.container import GANImageBuffer, LossContainer
from model.GAN.base import GANImageBuffer from engine.util.loss import pixel_loss, gan_loss
from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
from model.weight_init import generation_init_weights from model.weight_init import generation_init_weights
class TAFGEngineKernel(EngineKernel): class CycleGANEngineKernel(EngineKernel):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
gan_loss_cfg = OmegaConf.to_container(config.loss.gan) self.gan_loss = gan_loss(config.loss.gan)
gan_loss_cfg.pop("weight") self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss() self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss())
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
self.discriminators.keys()} self.discriminators.keys()}
@ -56,21 +53,19 @@ class TAFGEngineKernel(EngineKernel):
images["b2a"] = self.generators["b2a"](batch["b"]) images["b2a"] = self.generators["b2a"](batch["b"])
images["a2b2a"] = self.generators["b2a"](images["a2b"]) images["a2b2a"] = self.generators["b2a"](images["a2b"])
images["b2a2b"] = self.generators["a2b"](images["b2a"]) images["b2a2b"] = self.generators["a2b"](images["b2a"])
if self.config.loss.id.weight > 0: if self.id_loss.weight > 0:
images["a2a"] = self.generators["b2a"](batch["a"]) images["a2a"] = self.generators["b2a"](batch["a"])
images["b2b"] = self.generators["a2b"](batch["b"]) images["b2b"] = self.generators["a2b"](batch["b"])
return images return images
def criterion_generators(self, batch, generated) -> dict: def criterion_generators(self, batch, generated) -> dict:
loss = dict() loss = dict()
for phase in ["a2b", "b2a"]: for ph in "ab":
loss[f"cycle_{phase[0]}"] = self.config.loss.cycle.weight * self.cycle_loss( loss[f"cycle_{ph}"] = self.cycle_loss(generated["a2b2a" if ph == "a" else "b2a2b"], batch[ph])
generated[f"{phase}2{phase[0]}"], batch[phase[0]]) loss[f"id_{ph}"] = self.id_loss(generated[f"{ph}2{ph}"], batch[ph])
loss[f"gan_{phase}"] = self.config.loss.gan.weight * self.gan_loss( loss[f"mgc_{ph}"] = self.mgc_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph])
self.discriminators[phase[-1]](generated[phase]), True) loss[f"gan_{ph}"] = self.config.loss.gan.weight * self.gan_loss(
if self.config.loss.id.weight > 0: self.discriminators[ph](generated["a2b" if ph == "b" else "b2a"]), True)
loss[f"id_{phase[0]}"] = self.config.loss.id.weight * self.id_loss(
generated[f"{phase[0]}2{phase[0]}"], batch[phase[0]])
return loss return loss
def criterion_discriminators(self, batch, generated) -> dict: def criterion_discriminators(self, batch, generated) -> dict:
@ -97,5 +92,5 @@ class TAFGEngineKernel(EngineKernel):
def run(task, config, _): def run(task, config, _):
kernel = TAFGEngineKernel(config) kernel = CycleGANEngineKernel(config)
run_kernel(task, config, kernel) run_kernel(task, config, kernel)

View File

@ -1,38 +1,31 @@
import ignite.distributed as idist
import torch import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model from engine.util.build import build_model
from engine.util.container import LossContainer from engine.util.container import LossContainer
from engine.util.loss import bce_loss, mse_loss, pixel_loss, gan_loss
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
from loss.gan import GANLoss
from model.image_translation.UGATIT import RhoClipper
from util.image import attention_colored_map from util.image import attention_colored_map
def pixel_loss(level): class RhoClipper(object):
return nn.L1Loss() if level == 1 else nn.MSELoss() def __init__(self, clip_min, clip_max):
self.clip_min = clip_min
self.clip_max = clip_max
assert clip_min < clip_max
def __call__(self, module):
def mse_loss(x, target_flag): if hasattr(module, 'rho'):
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x)) w = module.rho.data
w = w.clamp(self.clip_min, self.clip_max)
module.rho.data = w
def bce_loss(x, target_flag):
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
class UGATITEngineKernel(EngineKernel): class UGATITEngineKernel(EngineKernel):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
gan_loss_cfg = OmegaConf.to_container(config.loss.gan) self.gan_loss = gan_loss(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level)) self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss()) self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level)) self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))

View File

@ -1,3 +1,6 @@
import torch
class LossContainer: class LossContainer:
def __init__(self, weight, loss): def __init__(self, weight, loss):
self.weight = weight self.weight = weight
@ -7,3 +10,57 @@ class LossContainer:
if self.weight > 0: if self.weight > 0:
return self.weight * self.loss(*args, **kwargs) return self.weight * self.loss(*args, **kwargs)
return 0.0 return 0.0
class GANImageBuffer:
"""This class implements an image buffer that stores previously
generated images.
This buffer allows us to update the discriminator using a history of
generated images rather than the ones produced by the latest generator
to reduce model oscillation.
Args:
buffer_size (int): The size of image buffer. If buffer_size = 0,
no buffer will be created.
buffer_ratio (float): The chance / possibility to use the images
previously stored in the buffer.
"""
def __init__(self, buffer_size, buffer_ratio=0.5):
self.buffer_size = buffer_size
# create an empty buffer
if self.buffer_size > 0:
self.img_num = 0
self.image_buffer = []
self.buffer_ratio = buffer_ratio
def query(self, images):
"""Query current image batch using a history of generated images.
Args:
images (Tensor): Current image batch without history information.
"""
if self.buffer_size == 0: # if the buffer size is 0, do nothing
return images
return_images = []
for image in images:
image = torch.unsqueeze(image.data, 0)
# if the buffer is not full, keep inserting current images
if self.img_num < self.buffer_size:
self.img_num = self.img_num + 1
self.image_buffer.append(image)
return_images.append(image)
else:
use_buffer = torch.rand(1) < self.buffer_ratio
# by self.buffer_ratio, the buffer will return a previously
# stored image, and insert the current image into the buffer
if use_buffer:
random_id = torch.randint(0, self.buffer_size, (1,)).item()
image_tmp = self.image_buffer[random_id].clone()
self.image_buffer[random_id] = image
return_images.append(image_tmp)
# by (1 - self.buffer_ratio), the buffer will return the
# current image
else:
return_images.append(image)
# collect all the images and return
return_images = torch.cat(return_images, 0)
return return_images

25
engine/util/loss.py Normal file
View File

@ -0,0 +1,25 @@
import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from loss.gan import GANLoss
def gan_loss(config):
gan_loss_cfg = OmegaConf.to_container(config)
gan_loss_cfg.pop("weight")
return GANLoss(**gan_loss_cfg).to(idist.device())
def pixel_loss(level):
return nn.L1Loss() if level == 1 else nn.MSELoss()
def mse_loss(x, target_flag):
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
def bce_loss(x, target_flag):
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))

View File

@ -1,3 +1,4 @@
import ignite.distributed as idist
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -5,17 +6,59 @@ import torch.nn as nn
def gaussian_radial_basis_function(x, mu, sigma): def gaussian_radial_basis_function(x, mu, sigma):
# (kernel_size) -> (batch_size, kernel_size, c*h*w) # (kernel_size) -> (batch_size, kernel_size, c*h*w)
mu = mu.view(1, mu.size(0), 1).expand(x.size(0), -1, x.size(1) * x.size(2) * x.size(3)) mu = mu.view(1, mu.size(0), 1).expand(x.size(0), -1, x.size(1) * x.size(2) * x.size(3))
mu = mu.to(x.device)
# (batch_size, c, h, w) -> (batch_size, kernel_size, c*h*w) # (batch_size, c, h, w) -> (batch_size, kernel_size, c*h*w)
x = x.view(x.size(0), 1, -1).expand(-1, mu.size(1), -1) x = x.view(x.size(0), 1, -1).expand(-1, mu.size(1), -1)
return torch.exp((x - mu).pow(2) / (2 * sigma ** 2)) return torch.exp((x - mu).pow(2) / (2 * sigma ** 2))
class ImporveMyLoss(torch.nn.Module):
def __init__(self, device=idist.device()):
super().__init__()
mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0]).to(device)
self.x_mu_list = mu.repeat(9).view(-1, 81)
self.y_mu_list = mu.unsqueeze(0).t().repeat(1, 9).view(-1, 81)
self.R = torch.eye(81).to(device)
def batch_ERSMI(self, I1, I2):
batch_size = I1.shape[0]
img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
if I2.shape[1] == 1 and I1.shape[1] != 1:
I2 = I2.repeat(1, 3, 1, 1)
def kernel_F(y, mu_list, sigma):
tmp_mu = mu_list.view(-1, 1).repeat(1, img_size).repeat(batch_size, 1, 1) # [81, 784]
tmp_y = y.view(batch_size, 1, -1).repeat(1, 81, 1)
tmp_y = tmp_mu - tmp_y
mat_L = torch.exp(tmp_y.pow(2) / (2 * sigma ** 2))
return mat_L
mat_K = kernel_F(I1, self.x_mu_list, 1)
mat_L = kernel_F(I2, self.y_mu_list, 1)
mat_k_l = mat_K * mat_L
H1 = (mat_K @ mat_K.transpose(1, 2)) * (mat_L @ mat_L.transpose(1, 2)) / (img_size ** 2)
h_hat = mat_k_l @ mat_k_l.transpose(1, 2) / img_size
small_h_hat = mat_K.sum(2).view(batch_size, -1, 1) * mat_L.sum(2).view(batch_size, -1, 1) / (img_size ** 2)
h_hat = 0.5 * H1 + 0.5 * h_hat
alpha = (h_hat + 0.05 * self.R).inverse() @ small_h_hat
ersmi = 2 * alpha.transpose(1, 2) @ small_h_hat - alpha.transpose(1, 2) @ h_hat @ alpha - 1
ersmi = -ersmi.squeeze().mean()
return ersmi
def forward(self, fakeI, realI):
return self.batch_ERSMI(fakeI, realI)
class MyLoss(torch.nn.Module): class MyLoss(torch.nn.Module):
def __init__(self): def __init__(self):
super(MyLoss, self).__init__() super(MyLoss, self).__init__()
def forward(self, fakeI, realI): def forward(self, fakeI, realI):
fakeI = fakeI.cuda()
realI = realI.cuda()
def batch_ERSMI(I1, I2): def batch_ERSMI(I1, I2):
batch_size = I1.shape[0] batch_size = I1.shape[0]
img_size = I1.shape[1] * I1.shape[2] * I1.shape[3] img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
@ -49,6 +92,7 @@ class MyLoss(torch.nn.Module):
alpha = alpha.matmul(h2) alpha = alpha.matmul(h2)
ersmi = (2 * (alpha.transpose(1, 2)).matmul(h2) - ((alpha.transpose(1, 2)).matmul(H2)).matmul( ersmi = (2 * (alpha.transpose(1, 2)).matmul(h2) - ((alpha.transpose(1, 2)).matmul(H2)).matmul(
alpha) - 1).squeeze() alpha) - 1).squeeze()
ersmi = -ersmi.mean() ersmi = -ersmi.mean()
return ersmi return ersmi
@ -61,16 +105,17 @@ class MGCLoss(nn.Module):
Minimal Geometry-Distortion Constraint Loss from https://openreview.net/forum?id=R5M7Mxl1xZ Minimal Geometry-Distortion Constraint Loss from https://openreview.net/forum?id=R5M7Mxl1xZ
""" """
def __init__(self, beta=0.5, lambda_=0.05): def __init__(self, beta=0.5, lambda_=0.05, device=idist.device()):
super().__init__() super().__init__()
self.beta = beta self.beta = beta
self.lambda_ = lambda_ self.lambda_ = lambda_
mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0]) mu_y, mu_x = torch.meshgrid([torch.arange(-1, 1.25, 0.25), torch.arange(-1, 1.25, 0.25)])
self.mu_x = mu.repeat(9) self.mu_x = mu_x.flatten().to(device)
self.mu_y = mu.unsqueeze(0).t().repeat(1, 9).view(-1) self.mu_y = mu_y.flatten().to(device)
self.R = torch.eye(81).unsqueeze(0).to(device)
@staticmethod @staticmethod
def batch_rSMI(img1, img2, mu_x, mu_y, beta, lambda_): def batch_rSMI(img1, img2, mu_x, mu_y, beta, lambda_, R):
assert img1.size() == img2.size() assert img1.size() == img2.size()
num_pixel = img1.size(1) * img1.size(2) * img2.size(3) num_pixel = img1.size(1) * img1.size(2) * img2.size(3)
@ -79,33 +124,102 @@ class MGCLoss(nn.Module):
mat_l = gaussian_radial_basis_function(img2, mu_y, sigma=1) mat_l = gaussian_radial_basis_function(img2, mu_y, sigma=1)
mat_k_mul_mat_l = mat_k * mat_l mat_k_mul_mat_l = mat_k * mat_l
h_hat = (1 - beta) * (mat_k_mul_mat_l.matmul(mat_k_mul_mat_l.transpose(1, 2))) / num_pixel h_hat = (1 - beta) * (mat_k_mul_mat_l @ mat_k_mul_mat_l.transpose(1, 2)) / num_pixel
h_hat += beta * (mat_k.matmul(mat_k.transpose(1, 2)) * mat_l.matmul(mat_l.transpose(1, 2))) / (num_pixel ** 2) h_hat += beta * ((mat_k @ mat_k.transpose(1, 2)) * (mat_l @ mat_l.transpose(1, 2))) / (num_pixel ** 2)
small_h_hat = mat_k.sum(2, keepdim=True) * mat_l.sum(2, keepdim=True) / (num_pixel ** 2) small_h_hat = mat_k.sum(2, keepdim=True) * mat_l.sum(2, keepdim=True) / (num_pixel ** 2)
R = torch.eye(h_hat.size(1)).to(img1.device) alpha = (h_hat + lambda_ * R).inverse() @ small_h_hat
alpha = (h_hat + lambda_ * R).inverse().matmul(small_h_hat) rSMI = 2 * alpha.transpose(1, 2) @ small_h_hat - alpha.transpose(1, 2) @ h_hat @ alpha - 1
return rSMI.squeeze()
rSMI = (2 * alpha.transpose(1, 2).matmul(small_h_hat)) - alpha.transpose(1, 2).matmul(h_hat).matmul(alpha) - 1
return rSMI
def forward(self, fake, real): def forward(self, fake, real):
rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_) rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_, self.R)
return -rSMI.squeeze().mean() return -rSMI.mean()
if __name__ == '__main__': if __name__ == '__main__':
mg = MGCLoss().to("cuda") mg = MGCLoss(device=torch.device("cpu"))
my = MyLoss().to("cuda")
imy = ImporveMyLoss()
from data.transform import transform_pipeline
def norm(x): pipeline = transform_pipeline(
x -= x.min() ['Load', 'ToTensor', {'Normalize': {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5]}}])
x /= x.max()
return (x - 0.5) * 2
img_a1 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_1.jpg")
img_a2 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_2.jpg")
img_a3 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_3.jpg")
img_b1 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_1.jpg")
img_b2 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_2.jpg")
img_b3 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_3.jpg")
x1 = norm(torch.randn(5, 3, 256, 256)) img_a1.requires_grad_(True)
x2 = norm(x1 * 2 + 1) img_a2.requires_grad_(True)
x3 = norm(torch.randn(5, 3, 256, 256)) img_a3.requires_grad_(True)
x4 = norm(torch.exp(x3))
print(mg(x1, x1), mg(x1, x2), mg(x1, x3), mg(x1, x4)) # print("MyLoss")
# l1 = my(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
# l2 = my(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
# l3 = my(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
# l = (l1+l2+l3)/3
# l.backward()
# print(img_a1.grad[0][0][0:10])
# print(img_a2.grad[0][0][0:10])
# print(img_a3.grad[0][0][0:10])
#
# img_a1.grad = None
# img_a2.grad = None
# img_a3.grad = None
#
# print("---")
# l = my(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
# l.backward()
# print(img_a1.grad[0][0][0:10])
# print(img_a2.grad[0][0][0:10])
# print(img_a3.grad[0][0][0:10])
# img_a1.grad = None
# img_a2.grad = None
# img_a3.grad = None
print("MGCLoss")
l1 = mg(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
l2 = mg(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
l3 = mg(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
l = (l1 + l2 + l3) / 3
l.backward()
print(img_a1.grad[0][0][0:10])
print(img_a2.grad[0][0][0:10])
print(img_a3.grad[0][0][0:10])
img_a1.grad = None
img_a2.grad = None
img_a3.grad = None
print("---")
l = mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
l.backward()
print(img_a1.grad[0][0][0:10])
print(img_a2.grad[0][0][0:10])
print(img_a3.grad[0][0][0:10])
# print("\nMGCLoss")
# mg(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
# mg(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
# mg(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
#
# print("---")
# mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
#
# import pprofile
#
# profiler = pprofile.Profile()
# with profiler:
# iter_times = 1000
# for _ in range(iter_times):
# mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
# for _ in range(iter_times):
# my(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
# for _ in range(iter_times):
# imy(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
# profiler.print_stats()

View File

@ -1,3 +1,4 @@
from model.registry import MODEL, NORMALIZATION from model.registry import MODEL, NORMALIZATION
import model.base.normalization import model.base.normalization
import model.image_translation import model.image_translation.UGATIT
import model.image_translation.CycleGAN

View File

@ -52,35 +52,37 @@ class LinearBlock(nn.Module):
class Conv2dBlock(nn.Module): class Conv2dBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, bias=None, def __init__(self, in_channels: int, out_channels: int, bias=None,
activation_type="ReLU", norm_type="NONE", activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None,
additional_norm_kwargs=None, **conv_kwargs): pre_activation=False, use_transpose_conv=False, **conv_kwargs):
super().__init__() super().__init__()
self.norm_type = norm_type self.norm_type = norm_type
self.activation_type = activation_type self.activation_type = activation_type
self.pre_activation = pre_activation
# if caller not set bias, set bias automatically. if use_transpose_conv:
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias # Only "zeros" padding mode is supported for ConvTranspose2d
conv_kwargs["padding_mode"] = "zeros"
conv = nn.ConvTranspose2d
else:
conv = nn.Conv2d
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs) if pre_activation:
self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs) self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
self.activation = _activation(activation_type) self.activation = _activation(activation_type, inplace=False)
self.convolution = conv(in_channels, out_channels, **conv_kwargs)
else:
# if caller not set bias, set bias automatically.
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
self.convolution = conv(in_channels, out_channels, **conv_kwargs)
self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
self.activation = _activation(activation_type)
def forward(self, x): def forward(self, x):
if self.pre_activation:
return self.convolution(self.activation(self.normalization(x)))
return self.activation(self.normalization(self.convolution(x))) return self.activation(self.normalization(self.convolution(x)))
class ReverseConv2dBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int,
activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
super().__init__()
self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
self.activation = _activation(activation_type, inplace=False)
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
def forward(self, x):
return self.convolution(self.activation(self.normalization(x)))
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
def __init__(self, in_channels, def __init__(self, in_channels,
padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False, padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
@ -109,16 +111,15 @@ class ResidualBlock(nn.Module):
self.learn_skip_connection = in_channels != out_channels self.learn_skip_connection = in_channels != out_channels
conv_block = ReverseConv2dBlock if pre_activation else Conv2dBlock
conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type, conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
additional_norm_kwargs=additional_norm_kwargs, additional_norm_kwargs=additional_norm_kwargs, pre_activation=pre_activation,
padding_mode=padding_mode) padding_mode=padding_mode)
self.conv1 = conv_block(in_channels, in_channels, **conv_param) self.conv1 = Conv2dBlock(in_channels, in_channels, **conv_param)
self.conv2 = conv_block(in_channels, out_channels, **conv_param) self.conv2 = Conv2dBlock(in_channels, out_channels, **conv_param)
if self.learn_skip_connection: if self.learn_skip_connection:
self.res_conv = conv_block(in_channels, out_channels, **conv_param) self.res_conv = Conv2dBlock(in_channels, out_channels, **conv_param)
def forward(self, x): def forward(self, x):
res = x if not self.learn_skip_connection else self.res_conv(x) res = x if not self.learn_skip_connection else self.res_conv(x)

View File

@ -1,5 +1,6 @@
import torch.nn as nn import torch.nn as nn
from model import MODEL
from model.base.module import Conv2dBlock, ResidualBlock from model.base.module import Conv2dBlock, ResidualBlock
@ -20,7 +21,7 @@ class Encoder(nn.Module):
multiple_now = min(2 ** i, 2 ** max_down_sampling_multiple) multiple_now = min(2 ** i, 2 ** max_down_sampling_multiple)
sequence.append(Conv2dBlock( sequence.append(Conv2dBlock(
multiple_prev * base_channels, multiple_now * base_channels, multiple_prev * base_channels, multiple_now * base_channels,
kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode=padding_mode, kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode="zeros",
activation_type=activation_type, norm_type=down_conv_norm_type activation_type=activation_type, norm_type=down_conv_norm_type
)) ))
self.out_channels = multiple_now * base_channels self.out_channels = multiple_now * base_channels
@ -43,7 +44,7 @@ class Decoder(nn.Module):
def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks, def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
activation_type="ReLU", padding_mode='reflect', activation_type="ReLU", padding_mode='reflect',
up_conv_kernel_size=5, up_conv_norm_type="LN", up_conv_kernel_size=5, up_conv_norm_type="LN",
res_norm_type="AdaIN", pre_activation=False): res_norm_type="AdaIN", pre_activation=False, use_transpose_conv=False):
super().__init__() super().__init__()
self.residual_blocks = nn.ModuleList([ self.residual_blocks = nn.ModuleList([
ResidualBlock( ResidualBlock(
@ -57,13 +58,23 @@ class Decoder(nn.Module):
sequence = list() sequence = list()
channels = in_channels channels = in_channels
padding = (up_conv_kernel_size - 1) // 2
for i in range(num_up_sampling): for i in range(num_up_sampling):
sequence.append(nn.Sequential( if use_transpose_conv:
nn.Upsample(scale_factor=2), sequence.append(Conv2dBlock(
Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1, channels, channels // 2, kernel_size=up_conv_kernel_size, stride=2,
padding=int(up_conv_kernel_size / 2), padding_mode=padding_mode, padding=padding, output_padding=padding,
activation_type=activation_type, norm_type=up_conv_norm_type), padding_mode=padding_mode,
)) activation_type=activation_type, norm_type=up_conv_norm_type,
use_transpose_conv=True
))
else:
sequence.append(nn.Sequential(
nn.Upsample(scale_factor=2),
Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
padding=padding, padding_mode=padding_mode,
activation_type=activation_type, norm_type=up_conv_norm_type),
))
channels = channels // 2 channels = channels // 2
sequence.append(Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3, sequence.append(Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3,
padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE")) padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"))
@ -74,3 +85,67 @@ class Decoder(nn.Module):
for i, blk in enumerate(self.residual_blocks): for i, blk in enumerate(self.residual_blocks):
x = blk(x) x = blk(x)
return self.up_sequence(x) return self.up_sequence(x)
@MODEL.register_module("CycleGAN-Generator")
class Generator(nn.Module):
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, activation_type="ReLU",
padding_mode='reflect', norm_type="IN", pre_activation=False, use_transpose_conv=True):
super().__init__()
self.encoder = Encoder(in_channels, base_channels, num_conv=2, num_res=num_blocks,
padding_mode=padding_mode, activation_type=activation_type,
down_conv_norm_type=norm_type, res_norm_type=norm_type, pre_activation=pre_activation)
self.decoder = Decoder(self.encoder.out_channels, out_channels, num_up_sampling=2, num_residual_blocks=0,
padding_mode=padding_mode, activation_type=activation_type,
up_conv_kernel_size=3, up_conv_norm_type=norm_type,
pre_activation=pre_activation, use_transpose_conv=use_transpose_conv)
def forward(self, x):
return self.decoder(self.encoder(x))
@MODEL.register_module("PatchDiscriminator")
class PatchDiscriminator(nn.Module):
def __init__(self, in_channels, base_channels=64, num_conv=4, need_intermediate_feature=False,
norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
super().__init__()
self.need_intermediate_feature = need_intermediate_feature
kernel_size = 4
padding = (kernel_size - 1) // 2
sequence = [Conv2dBlock(
in_channels, base_channels, kernel_size=kernel_size, stride=2, padding=padding, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type
)]
multiple_now = 1
for i in range(1, num_conv):
multiple_prev = multiple_now
multiple_now = min(2 ** i, 2 ** 3)
stride = 1 if i == num_conv - 1 else 2
sequence.append(Conv2dBlock(
multiple_prev * base_channels, multiple_now * base_channels,
kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type
))
sequence.append(nn.Conv2d(
base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding, padding_mode=padding_mode))
if self.need_intermediate_feature:
self.sequence = nn.ModuleList(sequence)
else:
self.sequence = nn.Sequential(*sequence)
def forward(self, x):
if self.need_intermediate_feature:
intermediate_feature = []
for layer in self.sequence:
x = layer(x)
intermediate_feature.append(x)
return tuple(intermediate_feature)
else:
return self.sequence(x)
if __name__ == '__main__':
g = Generator(**dict(in_channels=3, out_channels=3))
print(g)
pd = PatchDiscriminator(**dict(in_channels=3, base_channels=64, num_conv=4))
print(pd)

View File

@ -1,8 +1,12 @@
from collections import OrderedDict
from functools import partial
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from model.base.module import ResidualBlock, ReverseConv2dBlock, Conv2dBlock from model.base.module import ResidualBlock, Conv2dBlock, LinearBlock
class StyleEncoder(nn.Module): class StyleEncoder(nn.Module):
@ -33,6 +37,92 @@ class StyleEncoder(nn.Module):
return self.fc_avg(x), self.fc_var(x) return self.fc_avg(x), self.fc_var(x)
class ImprovedSPADEGenerator(nn.Module):
def __init__(self, in_channels, out_channels, output_size, have_style_input, style_dim, start_size=(4, 4),
base_channels=64, padding_mode='reflect', activation_type="LeakyReLU", pre_activation=False):
super().__init__()
assert output_size in (128, 256, 512, 1024)
self.output_size = output_size
kernel_size = 3
if have_style_input:
self.style_converter = nn.Sequential(
LinearBlock(style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
LinearBlock(2 * style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
)
base_conv = partial(
Conv2dBlock,
pre_activation=pre_activation, activation_type=activation_type,
norm_type="AdaIN" if have_style_input else "NONE",
kernel_size=kernel_size, padding=(kernel_size - 1) // 2, padding_mode=padding_mode
)
base_residual_block = partial(
ResidualBlock,
padding_mode=padding_mode,
activation_type=activation_type,
norm_type="SPADE",
pre_activation=True,
additional_norm_kwargs=dict(
condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
)
)
sequence = OrderedDict()
channels = (2 ** 4) * base_channels
sequence["block_head"] = nn.Sequential(OrderedDict([
("conv_input", base_conv(in_channels=in_channels, out_channels=channels)),
("conv_style", base_conv(in_channels=channels, out_channels=channels)),
("res_a", base_residual_block(in_channels=channels, out_channels=channels)),
("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
("up", nn.Upsample(scale_factor=2, mode='nearest'))
]))
for i in range(4, 9 - min(int(math.log(self.output_size, 2)), 8), -1):
channels = (2 ** (i - 1)) * base_channels
sequence[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
("res_a", base_residual_block(in_channels=channels * 2, out_channels=channels)),
("conv_style", base_conv(in_channels=channels, out_channels=channels)),
("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
("up", nn.Upsample(scale_factor=2, mode='nearest'))
]))
self.sequence = nn.Sequential(sequence)
# channels = 2*base_channels when output size is 256, 512, 1024
# channels = 5*base_channels when output size is 128
out_modules = OrderedDict()
out_modules["out_1"] = nn.Sequential(
Conv2dBlock(
channels, out_channels, kernel_size=5, stride=1, padding=2,
pre_activation=pre_activation,
padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"
),
nn.Tanh()
)
for i in range(int(math.log(self.output_size, 2)) - 8):
channels = channels // 2
out_modules[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
("res_a", base_residual_block(in_channels=2 * channels, out_channels=channels)),
("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
("up", nn.Upsample(scale_factor=2, mode='nearest'))
]))
out_modules[f"out_{i + 2}"] = nn.Sequential(
Conv2dBlock(
channels, out_channels, kernel_size=5, stride=1, padding=2,
pre_activation=pre_activation,
padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"
),
nn.Tanh()
)
self.out_modules = nn.ModuleDict(out_modules)
def forward(self, seg, style=None):
pass
class SPADEGenerator(nn.Module): class SPADEGenerator(nn.Module):
def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64, def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
padding_mode='reflect', activation_type="LeakyReLU"): padding_mode='reflect', activation_type="LeakyReLU"):
@ -89,7 +179,8 @@ class SPADEGenerator(nn.Module):
x = blk(x) x = blk(x)
return self.output_converter(x) return self.output_converter(x)
if __name__ == '__main__': if __name__ == '__main__':
g = SPADEGenerator(3, 3, 7, False, 256) g = SPADEGenerator(3, 3, 7, False, 256)
print(g) print(g)
print(g(torch.randn(2, 3, 256, 256)).size()) print(g(torch.randn(2, 3, 256, 256)).size())

View File

@ -6,19 +6,6 @@ from model.base.module import Conv2dBlock, LinearBlock
from model.image_translation.CycleGAN import Encoder, Decoder from model.image_translation.CycleGAN import Encoder, Decoder
class RhoClipper(object):
def __init__(self, clip_min, clip_max):
self.clip_min = clip_min
self.clip_max = clip_max
assert clip_min < clip_max
def __call__(self, module):
if hasattr(module, 'rho'):
w = module.rho.data
w = w.clamp(self.clip_min, self.clip_max)
module.rho.data = w
class CAMClassifier(nn.Module): class CAMClassifier(nn.Module):
def __init__(self, in_channels, activation_type="ReLU"): def __init__(self, in_channels, activation_type="ReLU"):
super(CAMClassifier, self).__init__() super(CAMClassifier, self).__init__()

View File

@ -1,76 +0,0 @@
import functools
import torch
import torch.nn as nn
def select_norm_layer(norm_type):
if norm_type == "BN":
return functools.partial(nn.BatchNorm2d)
elif norm_type == "IN":
return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
elif norm_type == "LN":
return functools.partial(LayerNorm2d, affine=True)
elif norm_type == "NONE":
return lambda num_features: nn.Identity()
elif norm_type == "AdaIN":
return functools.partial(AdaptiveInstanceNorm2d, affine=False, track_running_stats=False)
else:
raise NotImplemented(f'normalization layer {norm_type} is not found')
class LayerNorm2d(nn.Module):
def __init__(self, num_features, eps: float = 1e-5, affine: bool = True):
super().__init__()
self.num_features = num_features
self.eps = eps
self.affine = affine
if self.affine:
self.channel_gamma = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
self.channel_beta = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
self.reset_parameters()
def reset_parameters(self):
if self.affine:
nn.init.uniform_(self.channel_gamma)
nn.init.zeros_(self.channel_beta)
def forward(self, x):
ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
x = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
if self.affine:
return self.channel_gamma * x + self.channel_beta
return x
def __repr__(self):
return f"{self.__class__.__name__}(num_features={self.num_features}, affine={self.affine})"
class AdaptiveInstanceNorm2d(nn.Module):
def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
affine: bool = False, track_running_stats: bool = False):
super().__init__()
self.num_features = num_features
self.affine = affine
self.track_running_stats = track_running_stats
self.norm = nn.InstanceNorm2d(num_features, eps, momentum, affine, track_running_stats)
self.gamma = None
self.beta = None
self.have_set_style = False
def set_style(self, style):
style = style.view(*style.size(), 1, 1)
self.gamma, self.beta = style.chunk(2, 1)
self.have_set_style = True
def forward(self, x):
assert self.have_set_style
x = self.norm(x)
x = self.gamma * x + self.beta
self.have_set_style = False
return x
def __repr__(self):
return f"{self.__class__.__name__}(num_features={self.num_features}, " \
f"affine={self.affine}, track_running_stats={self.track_running_stats})"

View File

@ -1,8 +1,10 @@
import inspect import inspect
from omegaconf.dictconfig import DictConfig
from omegaconf import OmegaConf
from types import ModuleType
import warnings import warnings
from types import ModuleType
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
class _Registry: class _Registry:
def __init__(self, name): def __init__(self, name):
@ -136,8 +138,11 @@ class Registry(_Registry):
if module_name is None: if module_name is None:
module_name = module_class.__name__ module_name = module_class.__name__
if not force and module_name in self._module_dict: if not force and module_name in self._module_dict:
raise KeyError(f'{module_name} is already registered ' if self._module_dict[module_name] == module_class:
f'in {self.name}') warnings.warn(f'{module_name} is already registered in {self.name}, but is the same class')
return
raise KeyError(f'{module_name}:{self._module_dict[module_name]} is already registered in {self.name}'
f'so {module_class} can not be registered')
self._module_dict[module_name] = module_class self._module_dict[module_name] = module_class
def register_module(self, name=None, force=False, module=None): def register_module(self, name=None, force=False, module=None):