Compare commits
5 commits: 2de00d0245 ... 0019d4034c

| SHA1 |
|---|
| 0019d4034c |
| 0927fa3de5 |
| 611901cbdf |
| a6ffab1445 |
| 7b05b45156 |
The training config (YAML), retargeted from horse2zebra to selfie2anime:

```diff
@@ -1,34 +1,38 @@
-name: horse2zebra-CyCleGAN
-engine: CyCleGAN
+name: selfie2anime-cycleGAN
+engine: CycleGAN
 result_dir: ./result
-max_pairs: 266800
+max_pairs: 1000000
 
 misc:
   random_seed: 324
 
 handler:
-  clear_cuda_cache: False
+  clear_cuda_cache: True
   set_epoch_for_dist_sampler: True
 checkpoint:
   epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
   n_saved: 2
 tensorboard:
   scalar: 100 # log scalar `scalar` times per epoch
-  image: 2 # log image `image` times per epoch
+  image: 4 # log image `image` times per epoch
+test:
+  random: True
+  images: 10
 
 model:
   generator:
-    _type: CyCle-Generator
+    _type: CycleGAN-Generator
+    _add_spectral_norm: True
     in_channels: 3
     out_channels: 3
     base_channels: 64
     num_blocks: 9
-    padding_mode: reflect
-    norm_type: IN
   discriminator:
     _type: PatchDiscriminator
+    _add_spectral_norm: True
     in_channels: 3
     base_channels: 64
+    num_conv: 4
 
 loss:
   gan:
@@ -41,17 +45,21 @@ loss:
     weight: 10.0
   id:
     level: 1
-    weight: 0
+    weight: 10.0
+  mgc:
+    weight: 5
 
 optimizers:
   generator:
     _type: Adam
-    lr: 2e-4
+    lr: 0.0001
     betas: [ 0.5, 0.999 ]
+    weight_decay: 0.0001
   discriminator:
     _type: Adam
-    lr: 2e-4
+    lr: 1e-4
     betas: [ 0.5, 0.999 ]
+    weight_decay: 0.0001
 
 data:
   train:
@@ -60,15 +68,15 @@ data:
     target_lr: 0
     buffer_size: 50
     dataloader:
-      batch_size: 6
+      batch_size: 1
      shuffle: True
      num_workers: 2
      pin_memory: True
      drop_last: True
    dataset:
      _type: GenerationUnpairedDataset
-      root_a: "/data/i2i/horse2zebra/trainA"
-      root_b: "/data/i2i/horse2zebra/trainB"
+      root_a: "/data/i2i/selfie2anime/trainA"
+      root_b: "/data/i2i/selfie2anime/trainB"
      random_pair: True
    pipeline:
      - Load
@@ -82,16 +90,17 @@ data:
      - Normalize:
          mean: [ 0.5, 0.5, 0.5 ]
          std: [ 0.5, 0.5, 0.5 ]
  test:
+    which: video_dataset
    dataloader:
-      batch_size: 4
+      batch_size: 1
      shuffle: False
      num_workers: 1
      pin_memory: False
      drop_last: False
    dataset:
      _type: GenerationUnpairedDataset
-      root_a: "/data/i2i/horse2zebra/testA"
-      root_b: "/data/i2i/horse2zebra/testB"
+      root_a: "/data/i2i/selfie2anime/testA"
+      root_b: "/data/i2i/selfie2anime/testB"
      random_pair: False
    pipeline:
      - Load
@@ -101,3 +110,15 @@ data:
      - Normalize:
          mean: [ 0.5, 0.5, 0.5 ]
          std: [ 0.5, 0.5, 0.5 ]
+  video_dataset:
+    _type: SingleFolderDataset
+    root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
+    with_path: True
+    pipeline:
+      - Load
+      - Resize:
+          size: [ 256, 256 ]
+      - ToTensor
+      - Normalize:
+          mean: [ 0.5, 0.5, 0.5 ]
+          std: [ 0.5, 0.5, 0.5 ]
```
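The interesting structural change here is the new `data.test.which` selector: the test stage can now point at an alternate dataset block (`video_dataset`) defined next to the default unpaired one. A minimal sketch of how such a selector could be resolved with OmegaConf — the helper name `resolve_test_dataset_cfg` is hypothetical, not from the repo:

```python
from omegaconf import OmegaConf

# Hypothetical helper: pick the dataset config named by `data.test.which`,
# falling back to the inline `data.test.dataset` block when absent.
def resolve_test_dataset_cfg(config):
    which = config.data.test.get("which", None)
    if which is None:
        return config.data.test.dataset
    return config.data[which]  # e.g. config.data.video_dataset

config = OmegaConf.create({
    "data": {
        "test": {"which": "video_dataset",
                 "dataset": {"_type": "GenerationUnpairedDataset"}},
        "video_dataset": {"_type": "SingleFolderDataset", "with_path": True},
    }
})
print(resolve_test_dataset_cfg(config)["_type"])  # SingleFolderDataset
```

This sketch only illustrates the config shape; the engine's actual lookup code is not shown in this diff.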
A second config gets its batch sizes forced to 1 (its test stage already points at `video_dataset`):

```diff
@@ -78,7 +78,7 @@ data:
     target_lr: 0
     buffer_size: 50
     dataloader:
-      batch_size: 4
+      batch_size: 1
       shuffle: True
       num_workers: 2
       pin_memory: True
@@ -102,7 +102,7 @@ data:
   test:
     which: video_dataset
     dataloader:
-      batch_size: 8
+      batch_size: 1
       shuffle: False
       num_workers: 1
       pin_memory: False
```
The engine kernel, renamed from `TAFGEngineKernel` to `CycleGANEngineKernel` and rebuilt around the new loss utilities:

```diff
@@ -1,26 +1,23 @@
 from itertools import chain
 
-import ignite.distributed as idist
 import torch
-import torch.nn as nn
-from omegaconf import OmegaConf
 
 from engine.base.i2i import EngineKernel, run_kernel
 from engine.util.build import build_model
-from loss.gan import GANLoss
-from model.GAN.base import GANImageBuffer
+from engine.util.container import GANImageBuffer, LossContainer
+from engine.util.loss import pixel_loss, gan_loss
+from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
 from model.weight_init import generation_init_weights
 
 
-class TAFGEngineKernel(EngineKernel):
+class CycleGANEngineKernel(EngineKernel):
     def __init__(self, config):
         super().__init__(config)
 
-        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
-        gan_loss_cfg.pop("weight")
-        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
-        self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
-        self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
+        self.gan_loss = gan_loss(config.loss.gan)
+        self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
+        self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
+        self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss())
         self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
                               self.discriminators.keys()}
 
@@ -56,21 +53,19 @@ class TAFGEngineKernel(EngineKernel):
         images["b2a"] = self.generators["b2a"](batch["b"])
         images["a2b2a"] = self.generators["b2a"](images["a2b"])
         images["b2a2b"] = self.generators["a2b"](images["b2a"])
-        if self.config.loss.id.weight > 0:
+        if self.id_loss.weight > 0:
             images["a2a"] = self.generators["b2a"](batch["a"])
             images["b2b"] = self.generators["a2b"](batch["b"])
         return images
 
     def criterion_generators(self, batch, generated) -> dict:
         loss = dict()
-        for phase in ["a2b", "b2a"]:
-            loss[f"cycle_{phase[0]}"] = self.config.loss.cycle.weight * self.cycle_loss(
-                generated[f"{phase}2{phase[0]}"], batch[phase[0]])
-            loss[f"gan_{phase}"] = self.config.loss.gan.weight * self.gan_loss(
-                self.discriminators[phase[-1]](generated[phase]), True)
-            if self.config.loss.id.weight > 0:
-                loss[f"id_{phase[0]}"] = self.config.loss.id.weight * self.id_loss(
-                    generated[f"{phase[0]}2{phase[0]}"], batch[phase[0]])
+        for ph in "ab":
+            loss[f"cycle_{ph}"] = self.cycle_loss(generated["a2b2a" if ph == "a" else "b2a2b"], batch[ph])
+            loss[f"id_{ph}"] = self.id_loss(generated[f"{ph}2{ph}"], batch[ph])
+            loss[f"mgc_{ph}"] = self.mgc_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph])
+            loss[f"gan_{ph}"] = self.config.loss.gan.weight * self.gan_loss(
+                self.discriminators[ph](generated["a2b" if ph == "b" else "b2a"]), True)
         return loss
 
     def criterion_discriminators(self, batch, generated) -> dict:
@@ -97,5 +92,5 @@ class TAFGEngineKernel(EngineKernel):
 
 
 def run(task, config, _):
-    kernel = TAFGEngineKernel(config)
+    kernel = CycleGANEngineKernel(config)
     run_kernel(task, config, kernel)
```
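The criterion rewrite leans on `LossContainer` (see the container.py diff below): the weight is folded into the call and a zero weight short-circuits to `0.0`, which is why the explicit `if self.config.loss.id.weight > 0` guard could be dropped from `criterion_generators`. A standalone sketch of the behaviour this depends on:

```python
import torch
import torch.nn as nn

class LossContainer:
    def __init__(self, weight, loss):
        self.weight = weight
        self.loss = loss

    def __call__(self, *args, **kwargs):
        # Zero-weight terms cost nothing: the wrapped loss is never evaluated.
        if self.weight > 0:
            return self.weight * self.loss(*args, **kwargs)
        return 0.0

id_loss = LossContainer(0, nn.L1Loss())
fake, real = torch.randn(1, 3, 8, 8), torch.randn(1, 3, 8, 8)
print(id_loss(fake, real))  # 0.0 — safe to add into a loss dict unconditionally
```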
The UGATIT engine kernel switches to the shared loss helpers and takes ownership of `RhoClipper`:

```diff
@@ -1,38 +1,31 @@
-import ignite.distributed as idist
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from omegaconf import OmegaConf
 
 from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
 from engine.util.build import build_model
 from engine.util.container import LossContainer
+from engine.util.loss import bce_loss, mse_loss, pixel_loss, gan_loss
 from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
-from loss.gan import GANLoss
-from model.image_translation.UGATIT import RhoClipper
 from util.image import attention_colored_map
 
 
-def pixel_loss(level):
-    return nn.L1Loss() if level == 1 else nn.MSELoss()
-
-
-def mse_loss(x, target_flag):
-    return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
-
-
-def bce_loss(x, target_flag):
-    return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
+class RhoClipper(object):
+    def __init__(self, clip_min, clip_max):
+        self.clip_min = clip_min
+        self.clip_max = clip_max
+        assert clip_min < clip_max
+
+    def __call__(self, module):
+        if hasattr(module, 'rho'):
+            w = module.rho.data
+            w = w.clamp(self.clip_min, self.clip_max)
+            module.rho.data = w
 
 
 class UGATITEngineKernel(EngineKernel):
     def __init__(self, config):
         super().__init__(config)
 
-        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
-        gan_loss_cfg.pop("weight")
-        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
+        self.gan_loss = gan_loss(config.loss.gan)
 
         self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
         self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
         self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
```
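`RhoClipper` moves out of `model.image_translation.UGATIT` (see the later hunk removing it there) and into the engine. It is meant to be passed to `Module.apply` after each optimizer step, so every submodule exposing a learnable `rho` gets clamped back into range; the original UGATIT code clips to (0, 1). A standalone sketch with a made-up module holding a `rho` parameter:

```python
import torch
import torch.nn as nn

class RhoClipper(object):  # as added in the diff above
    def __init__(self, clip_min, clip_max):
        self.clip_min = clip_min
        self.clip_max = clip_max
        assert clip_min < clip_max

    def __call__(self, module):
        if hasattr(module, 'rho'):
            module.rho.data = module.rho.data.clamp(self.clip_min, self.clip_max)

class BlendBlock(nn.Module):  # stand-in for a UGATIT layer with a rho parameter
    def __init__(self):
        super().__init__()
        self.rho = nn.Parameter(torch.tensor(1.7))  # drifted out of range

generator = nn.Sequential(BlendBlock())
generator.apply(RhoClipper(0.0, 1.0))  # call after each optimizer step
print(generator[0].rho.item())  # 1.0
```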
engine/util/container.py gains `GANImageBuffer`:

```diff
@@ -1,3 +1,6 @@
+import torch
+
+
 class LossContainer:
     def __init__(self, weight, loss):
         self.weight = weight
@@ -7,3 +10,57 @@ class LossContainer:
         if self.weight > 0:
             return self.weight * self.loss(*args, **kwargs)
         return 0.0
+
+
+class GANImageBuffer:
+    """This class implements an image buffer that stores previously
+    generated images.
+
+    This buffer allows us to update the discriminator using a history of
+    generated images rather than the ones produced by the latest generator
+    to reduce model oscillation.
+
+    Args:
+        buffer_size (int): The size of image buffer. If buffer_size = 0,
+            no buffer will be created.
+        buffer_ratio (float): The chance / possibility to use the images
+            previously stored in the buffer.
+    """
+
+    def __init__(self, buffer_size, buffer_ratio=0.5):
+        self.buffer_size = buffer_size
+        # create an empty buffer
+        if self.buffer_size > 0:
+            self.img_num = 0
+            self.image_buffer = []
+        self.buffer_ratio = buffer_ratio
+
+    def query(self, images):
+        """Query current image batch using a history of generated images.
+
+        Args:
+            images (Tensor): Current image batch without history information.
+        """
+        if self.buffer_size == 0:  # if the buffer size is 0, do nothing
+            return images
+        return_images = []
+        for image in images:
+            image = torch.unsqueeze(image.data, 0)
+            # if the buffer is not full, keep inserting current images
+            if self.img_num < self.buffer_size:
+                self.img_num = self.img_num + 1
+                self.image_buffer.append(image)
+                return_images.append(image)
+            else:
+                use_buffer = torch.rand(1) < self.buffer_ratio
+                # by self.buffer_ratio, the buffer will return a previously
+                # stored image, and insert the current image into the buffer
+                if use_buffer:
+                    random_id = torch.randint(0, self.buffer_size, (1,)).item()
+                    image_tmp = self.image_buffer[random_id].clone()
+                    self.image_buffer[random_id] = image
+                    return_images.append(image_tmp)
+                # by (1 - self.buffer_ratio), the buffer will return the
+                # current image
+                else:
+                    return_images.append(image)
+        # collect all the images and return
+        return_images = torch.cat(return_images, 0)
+        return return_images
```
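`GANImageBuffer` (relocated here from `model.GAN.base`, per the engine diff above) is queried with freshly generated fakes before they reach the discriminator; once the buffer is full, roughly half of each batch (at the default `buffer_ratio=0.5`) is swapped for fakes from earlier generator states, which damps discriminator oscillation. A usage sketch:

```python
import torch

# assumes the GANImageBuffer from the diff above is importable
from engine.util.container import GANImageBuffer

buffer = GANImageBuffer(buffer_size=50)
for step in range(3):
    fake = torch.randn(4, 3, 16, 16)          # generator output for this step
    fake_for_d = buffer.query(fake.detach())  # mix in historical fakes
    # d_loss = gan_loss(discriminator(fake_for_d), False) + ...
    print(step, fake_for_d.shape)             # torch.Size([4, 3, 16, 16])
```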
engine/util/loss.py (new file, +25 lines):

```diff
@@ -0,0 +1,25 @@
+import ignite.distributed as idist
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from omegaconf import OmegaConf
+
+from loss.gan import GANLoss
+
+
+def gan_loss(config):
+    gan_loss_cfg = OmegaConf.to_container(config)
+    gan_loss_cfg.pop("weight")
+    return GANLoss(**gan_loss_cfg).to(idist.device())
+
+
+def pixel_loss(level):
+    return nn.L1Loss() if level == 1 else nn.MSELoss()
+
+
+def mse_loss(x, target_flag):
+    return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
+
+
+def bce_loss(x, target_flag):
+    return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
```
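The new module collects the helpers both kernels previously defined inline. Note that `mse_loss`/`bce_loss` take a boolean real/fake flag and build the target tensor themselves. A standalone usage sketch (the two helpers copied from the diff; the PatchGAN-shaped logits are made up):

```python
import torch
import torch.nn.functional as F

def mse_loss(x, target_flag):
    return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))

def bce_loss(x, target_flag):
    return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))

logits = torch.randn(2, 1, 30, 30)  # PatchGAN-style discriminator score map
print(mse_loss(logits, True))       # LSGAN real target: all-ones
print(bce_loss(logits, False))      # vanilla GAN fake target: all-zeros
```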
loss/I2I/minimal_geometry_distortion_constraint_loss.py:

```diff
@@ -1,3 +1,4 @@
+import ignite.distributed as idist
 import torch
 import torch.nn as nn
 
@@ -5,17 +6,59 @@ import torch.nn as nn
 def gaussian_radial_basis_function(x, mu, sigma):
     # (kernel_size) -> (batch_size, kernel_size, c*h*w)
     mu = mu.view(1, mu.size(0), 1).expand(x.size(0), -1, x.size(1) * x.size(2) * x.size(3))
-    mu = mu.to(x.device)
     # (batch_size, c, h, w) -> (batch_size, kernel_size, c*h*w)
     x = x.view(x.size(0), 1, -1).expand(-1, mu.size(1), -1)
     return torch.exp((x - mu).pow(2) / (2 * sigma ** 2))
 
 
+class ImporveMyLoss(torch.nn.Module):
+    def __init__(self, device=idist.device()):
+        super().__init__()
+        mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0]).to(device)
+        self.x_mu_list = mu.repeat(9).view(-1, 81)
+        self.y_mu_list = mu.unsqueeze(0).t().repeat(1, 9).view(-1, 81)
+        self.R = torch.eye(81).to(device)
+
+    def batch_ERSMI(self, I1, I2):
+        batch_size = I1.shape[0]
+        img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
+        if I2.shape[1] == 1 and I1.shape[1] != 1:
+            I2 = I2.repeat(1, 3, 1, 1)
+
+        def kernel_F(y, mu_list, sigma):
+            tmp_mu = mu_list.view(-1, 1).repeat(1, img_size).repeat(batch_size, 1, 1)  # [81, 784]
+            tmp_y = y.view(batch_size, 1, -1).repeat(1, 81, 1)
+            tmp_y = tmp_mu - tmp_y
+            mat_L = torch.exp(tmp_y.pow(2) / (2 * sigma ** 2))
+            return mat_L
+
+        mat_K = kernel_F(I1, self.x_mu_list, 1)
+        mat_L = kernel_F(I2, self.y_mu_list, 1)
+        mat_k_l = mat_K * mat_L
+
+        H1 = (mat_K @ mat_K.transpose(1, 2)) * (mat_L @ mat_L.transpose(1, 2)) / (img_size ** 2)
+        h_hat = mat_k_l @ mat_k_l.transpose(1, 2) / img_size
+        small_h_hat = mat_K.sum(2).view(batch_size, -1, 1) * mat_L.sum(2).view(batch_size, -1, 1) / (img_size ** 2)
+        h_hat = 0.5 * H1 + 0.5 * h_hat
+        alpha = (h_hat + 0.05 * self.R).inverse() @ small_h_hat
+
+        ersmi = 2 * alpha.transpose(1, 2) @ small_h_hat - alpha.transpose(1, 2) @ h_hat @ alpha - 1
+
+        ersmi = -ersmi.squeeze().mean()
+        return ersmi
+
+    def forward(self, fakeI, realI):
+        return self.batch_ERSMI(fakeI, realI)
+
+
 class MyLoss(torch.nn.Module):
     def __init__(self):
         super(MyLoss, self).__init__()
 
     def forward(self, fakeI, realI):
+        fakeI = fakeI.cuda()
+        realI = realI.cuda()
+
         def batch_ERSMI(I1, I2):
             batch_size = I1.shape[0]
             img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
@@ -49,6 +92,7 @@ class MyLoss(torch.nn.Module):
             alpha = alpha.matmul(h2)
             ersmi = (2 * (alpha.transpose(1, 2)).matmul(h2) - ((alpha.transpose(1, 2)).matmul(H2)).matmul(
                 alpha) - 1).squeeze()
+
             ersmi = -ersmi.mean()
             return ersmi
 
@@ -61,16 +105,17 @@ class MGCLoss(nn.Module):
     Minimal Geometry-Distortion Constraint Loss from https://openreview.net/forum?id=R5M7Mxl1xZ
     """
 
-    def __init__(self, beta=0.5, lambda_=0.05):
+    def __init__(self, beta=0.5, lambda_=0.05, device=idist.device()):
         super().__init__()
         self.beta = beta
         self.lambda_ = lambda_
-        mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0])
-        self.mu_x = mu.repeat(9)
-        self.mu_y = mu.unsqueeze(0).t().repeat(1, 9).view(-1)
+        mu_y, mu_x = torch.meshgrid([torch.arange(-1, 1.25, 0.25), torch.arange(-1, 1.25, 0.25)])
+        self.mu_x = mu_x.flatten().to(device)
+        self.mu_y = mu_y.flatten().to(device)
+        self.R = torch.eye(81).unsqueeze(0).to(device)
 
     @staticmethod
-    def batch_rSMI(img1, img2, mu_x, mu_y, beta, lambda_):
+    def batch_rSMI(img1, img2, mu_x, mu_y, beta, lambda_, R):
         assert img1.size() == img2.size()
 
         num_pixel = img1.size(1) * img1.size(2) * img2.size(3)
@@ -79,33 +124,102 @@ class MGCLoss(nn.Module):
         mat_l = gaussian_radial_basis_function(img2, mu_y, sigma=1)
 
         mat_k_mul_mat_l = mat_k * mat_l
-        h_hat = (1 - beta) * (mat_k_mul_mat_l.matmul(mat_k_mul_mat_l.transpose(1, 2))) / num_pixel
-        h_hat += beta * (mat_k.matmul(mat_k.transpose(1, 2)) * mat_l.matmul(mat_l.transpose(1, 2))) / (num_pixel ** 2)
+        h_hat = (1 - beta) * (mat_k_mul_mat_l @ mat_k_mul_mat_l.transpose(1, 2)) / num_pixel
+        h_hat += beta * ((mat_k @ mat_k.transpose(1, 2)) * (mat_l @ mat_l.transpose(1, 2))) / (num_pixel ** 2)
         small_h_hat = mat_k.sum(2, keepdim=True) * mat_l.sum(2, keepdim=True) / (num_pixel ** 2)
 
-        R = torch.eye(h_hat.size(1)).to(img1.device)
-        alpha = (h_hat + lambda_ * R).inverse().matmul(small_h_hat)
-
-        rSMI = (2 * alpha.transpose(1, 2).matmul(small_h_hat)) - alpha.transpose(1, 2).matmul(h_hat).matmul(alpha) - 1
-        return rSMI
+        alpha = (h_hat + lambda_ * R).inverse() @ small_h_hat
+        rSMI = 2 * alpha.transpose(1, 2) @ small_h_hat - alpha.transpose(1, 2) @ h_hat @ alpha - 1
+        return rSMI.squeeze()
 
     def forward(self, fake, real):
-        rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_)
-        return -rSMI.squeeze().mean()
+        rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_, self.R)
+        return -rSMI.mean()
 
 
 if __name__ == '__main__':
-    mg = MGCLoss().to("cuda")
+    mg = MGCLoss(device=torch.device("cpu"))
+    my = MyLoss().to("cuda")
+    imy = ImporveMyLoss()
+
+    from data.transform import transform_pipeline
 
-    def norm(x):
-        x -= x.min()
-        x /= x.max()
-        return (x - 0.5) * 2
+    pipeline = transform_pipeline(
+        ['Load', 'ToTensor', {'Normalize': {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5]}}])
 
+    img_a1 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_1.jpg")
+    img_a2 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_2.jpg")
+    img_a3 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_3.jpg")
+    img_b1 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_1.jpg")
+    img_b2 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_2.jpg")
+    img_b3 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_3.jpg")
 
-    x1 = norm(torch.randn(5, 3, 256, 256))
-    x2 = norm(x1 * 2 + 1)
-    x3 = norm(torch.randn(5, 3, 256, 256))
-    x4 = norm(torch.exp(x3))
-    print(mg(x1, x1), mg(x1, x2), mg(x1, x3), mg(x1, x4))
+    img_a1.requires_grad_(True)
+    img_a2.requires_grad_(True)
+    img_a3.requires_grad_(True)
+
+    # print("MyLoss")
+    # l1 = my(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
+    # l2 = my(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
+    # l3 = my(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
+    # l = (l1+l2+l3)/3
+    # l.backward()
+    # print(img_a1.grad[0][0][0:10])
+    # print(img_a2.grad[0][0][0:10])
+    # print(img_a3.grad[0][0][0:10])
+    #
+    # img_a1.grad = None
+    # img_a2.grad = None
+    # img_a3.grad = None
+    #
+    # print("---")
+    # l = my(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
+    # l.backward()
+    # print(img_a1.grad[0][0][0:10])
+    # print(img_a2.grad[0][0][0:10])
+    # print(img_a3.grad[0][0][0:10])
+    # img_a1.grad = None
+    # img_a2.grad = None
+    # img_a3.grad = None
+
+    print("MGCLoss")
+    l1 = mg(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
+    l2 = mg(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
+    l3 = mg(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
+    l = (l1 + l2 + l3) / 3
+    l.backward()
+    print(img_a1.grad[0][0][0:10])
+    print(img_a2.grad[0][0][0:10])
+    print(img_a3.grad[0][0][0:10])
+
+    img_a1.grad = None
+    img_a2.grad = None
+    img_a3.grad = None
+
+    print("---")
+    l = mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
+    l.backward()
+    print(img_a1.grad[0][0][0:10])
+    print(img_a2.grad[0][0][0:10])
+    print(img_a3.grad[0][0][0:10])
+
+    # print("\nMGCLoss")
+    # mg(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
+    # mg(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
+    # mg(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
+    #
+    # print("---")
+    # mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
+    #
+    # import pprofile
+    #
+    # profiler = pprofile.Profile()
+    # with profiler:
+    #     iter_times = 1000
+    #     for _ in range(iter_times):
+    #         mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
+    #     for _ in range(iter_times):
+    #         my(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
+    #     for _ in range(iter_times):
+    #         imy(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
+    # profiler.print_stats()
```
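The rewritten `MGCLoss.__init__` builds the 9×9 grid of RBF centres with `torch.meshgrid` instead of repeat/transpose gymnastics, and pre-computes the 81×81 ridge regulariser `R` once instead of rebuilding it on every call. A quick standalone check that the two constructions agree (the meshgrid call keeps the diff's list form; `indexing='ij'` is the pre-1.10 default being relied on):

```python
import torch

# old construction
mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0])
mu_x_old = mu.repeat(9)
mu_y_old = mu.unsqueeze(0).t().repeat(1, 9).view(-1)

# new construction, as in the diff
mu_y_new, mu_x_new = torch.meshgrid([torch.arange(-1, 1.25, 0.25),
                                     torch.arange(-1, 1.25, 0.25)])
print(torch.equal(mu_x_old, mu_x_new.flatten()))  # True
print(torch.equal(mu_y_old, mu_y_new.flatten()))  # True
print(mu_x_old.numel())  # 81 centres -> R is eye(81)
```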
The model package imports, so the new registrations run on import:

```diff
@@ -1,3 +1,4 @@
 from model.registry import MODEL, NORMALIZATION
 import model.base.normalization
-import model.image_translation
+import model.image_translation.UGATIT
+import model.image_translation.CycleGAN
```
@ -52,35 +52,37 @@ class LinearBlock(nn.Module):
|
|||||||
|
|
||||||
class Conv2dBlock(nn.Module):
|
class Conv2dBlock(nn.Module):
|
||||||
def __init__(self, in_channels: int, out_channels: int, bias=None,
|
def __init__(self, in_channels: int, out_channels: int, bias=None,
|
||||||
activation_type="ReLU", norm_type="NONE",
|
activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None,
|
||||||
additional_norm_kwargs=None, **conv_kwargs):
|
pre_activation=False, use_transpose_conv=False, **conv_kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm_type = norm_type
|
self.norm_type = norm_type
|
||||||
self.activation_type = activation_type
|
self.activation_type = activation_type
|
||||||
|
self.pre_activation = pre_activation
|
||||||
|
|
||||||
# if caller not set bias, set bias automatically.
|
if use_transpose_conv:
|
||||||
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
|
# Only "zeros" padding mode is supported for ConvTranspose2d
|
||||||
|
conv_kwargs["padding_mode"] = "zeros"
|
||||||
|
conv = nn.ConvTranspose2d
|
||||||
|
else:
|
||||||
|
conv = nn.Conv2d
|
||||||
|
|
||||||
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
|
if pre_activation:
|
||||||
self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
|
self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
|
||||||
self.activation = _activation(activation_type)
|
self.activation = _activation(activation_type, inplace=False)
|
||||||
|
self.convolution = conv(in_channels, out_channels, **conv_kwargs)
|
||||||
|
else:
|
||||||
|
# if caller not set bias, set bias automatically.
|
||||||
|
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
|
||||||
|
self.convolution = conv(in_channels, out_channels, **conv_kwargs)
|
||||||
|
self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
|
||||||
|
self.activation = _activation(activation_type)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
if self.pre_activation:
|
||||||
|
return self.convolution(self.activation(self.normalization(x)))
|
||||||
return self.activation(self.normalization(self.convolution(x)))
|
return self.activation(self.normalization(self.convolution(x)))
|
||||||
|
|
||||||
|
|
||||||
class ReverseConv2dBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels: int, out_channels: int,
|
|
||||||
activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
|
|
||||||
super().__init__()
|
|
||||||
self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
|
|
||||||
self.activation = _activation(activation_type, inplace=False)
|
|
||||||
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.convolution(self.activation(self.normalization(x)))
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualBlock(nn.Module):
|
class ResidualBlock(nn.Module):
|
||||||
def __init__(self, in_channels,
|
def __init__(self, in_channels,
|
||||||
padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
|
padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
|
||||||
@ -109,16 +111,15 @@ class ResidualBlock(nn.Module):
|
|||||||
|
|
||||||
self.learn_skip_connection = in_channels != out_channels
|
self.learn_skip_connection = in_channels != out_channels
|
||||||
|
|
||||||
conv_block = ReverseConv2dBlock if pre_activation else Conv2dBlock
|
|
||||||
conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
|
conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
|
||||||
additional_norm_kwargs=additional_norm_kwargs,
|
additional_norm_kwargs=additional_norm_kwargs, pre_activation=pre_activation,
|
||||||
padding_mode=padding_mode)
|
padding_mode=padding_mode)
|
||||||
|
|
||||||
self.conv1 = conv_block(in_channels, in_channels, **conv_param)
|
self.conv1 = Conv2dBlock(in_channels, in_channels, **conv_param)
|
||||||
self.conv2 = conv_block(in_channels, out_channels, **conv_param)
|
self.conv2 = Conv2dBlock(in_channels, out_channels, **conv_param)
|
||||||
|
|
||||||
if self.learn_skip_connection:
|
if self.learn_skip_connection:
|
||||||
self.res_conv = conv_block(in_channels, out_channels, **conv_param)
|
self.res_conv = Conv2dBlock(in_channels, out_channels, **conv_param)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
res = x if not self.learn_skip_connection else self.res_conv(x)
|
res = x if not self.learn_skip_connection else self.res_conv(x)
|
||||||
|
|||||||
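`Conv2dBlock` absorbs the deleted `ReverseConv2dBlock`: `pre_activation=True` switches the order to norm → activation → convolution (normalizing over `in_channels`), and `use_transpose_conv=True` swaps in `nn.ConvTranspose2d` with the padding mode forced to `"zeros"`. A standalone sketch of the two orderings with plain torch modules (fixed IN/ReLU stand in for the file's `_normalization`/`_activation` helpers):

```python
import torch
import torch.nn as nn

in_ch, out_ch = 8, 16

# post-activation (default): conv -> norm(out_ch) -> act
post = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1),
                     nn.InstanceNorm2d(out_ch), nn.ReLU())

# pre-activation: norm(in_ch) -> act -> conv, as in the new pre_activation branch
pre = nn.Sequential(nn.InstanceNorm2d(in_ch), nn.ReLU(inplace=False),
                    nn.Conv2d(in_ch, out_ch, 3, padding=1))

x = torch.randn(1, in_ch, 32, 32)
print(post(x).shape, pre(x).shape)  # both torch.Size([1, 16, 32, 32])
```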
model/image_translation/CycleGAN.py registers a `Generator` and `PatchDiscriminator`:

```diff
@@ -1,5 +1,6 @@
 import torch.nn as nn
 
+from model import MODEL
 from model.base.module import Conv2dBlock, ResidualBlock
 
 
@@ -20,7 +21,7 @@ class Encoder(nn.Module):
             multiple_now = min(2 ** i, 2 ** max_down_sampling_multiple)
             sequence.append(Conv2dBlock(
                 multiple_prev * base_channels, multiple_now * base_channels,
-                kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode=padding_mode,
+                kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode="zeros",
                 activation_type=activation_type, norm_type=down_conv_norm_type
             ))
             self.out_channels = multiple_now * base_channels
@@ -43,7 +44,7 @@ class Decoder(nn.Module):
     def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
                  activation_type="ReLU", padding_mode='reflect',
                  up_conv_kernel_size=5, up_conv_norm_type="LN",
-                 res_norm_type="AdaIN", pre_activation=False):
+                 res_norm_type="AdaIN", pre_activation=False, use_transpose_conv=False):
         super().__init__()
         self.residual_blocks = nn.ModuleList([
             ResidualBlock(
@@ -57,13 +58,23 @@ class Decoder(nn.Module):
 
         sequence = list()
         channels = in_channels
+        padding = (up_conv_kernel_size - 1) // 2
         for i in range(num_up_sampling):
-            sequence.append(nn.Sequential(
-                nn.Upsample(scale_factor=2),
-                Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
-                            padding=int(up_conv_kernel_size / 2), padding_mode=padding_mode,
-                            activation_type=activation_type, norm_type=up_conv_norm_type),
-            ))
+            if use_transpose_conv:
+                sequence.append(Conv2dBlock(
+                    channels, channels // 2, kernel_size=up_conv_kernel_size, stride=2,
+                    padding=padding, output_padding=padding,
+                    padding_mode=padding_mode,
+                    activation_type=activation_type, norm_type=up_conv_norm_type,
+                    use_transpose_conv=True
+                ))
+            else:
+                sequence.append(nn.Sequential(
+                    nn.Upsample(scale_factor=2),
+                    Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
+                                padding=padding, padding_mode=padding_mode,
+                                activation_type=activation_type, norm_type=up_conv_norm_type),
+                ))
             channels = channels // 2
         sequence.append(Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3,
                                     padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"))
@@ -74,3 +85,67 @@ class Decoder(nn.Module):
         for i, blk in enumerate(self.residual_blocks):
             x = blk(x)
         return self.up_sequence(x)
+
+
+@MODEL.register_module("CycleGAN-Generator")
+class Generator(nn.Module):
+    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, activation_type="ReLU",
+                 padding_mode='reflect', norm_type="IN", pre_activation=False, use_transpose_conv=True):
+        super().__init__()
+        self.encoder = Encoder(in_channels, base_channels, num_conv=2, num_res=num_blocks,
+                               padding_mode=padding_mode, activation_type=activation_type,
+                               down_conv_norm_type=norm_type, res_norm_type=norm_type, pre_activation=pre_activation)
+        self.decoder = Decoder(self.encoder.out_channels, out_channels, num_up_sampling=2, num_residual_blocks=0,
+                               padding_mode=padding_mode, activation_type=activation_type,
+                               up_conv_kernel_size=3, up_conv_norm_type=norm_type,
+                               pre_activation=pre_activation, use_transpose_conv=use_transpose_conv)
+
+    def forward(self, x):
+        return self.decoder(self.encoder(x))
+
+
+@MODEL.register_module("PatchDiscriminator")
+class PatchDiscriminator(nn.Module):
+    def __init__(self, in_channels, base_channels=64, num_conv=4, need_intermediate_feature=False,
+                 norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
+        super().__init__()
+        self.need_intermediate_feature = need_intermediate_feature
+        kernel_size = 4
+        padding = (kernel_size - 1) // 2
+        sequence = [Conv2dBlock(
+            in_channels, base_channels, kernel_size=kernel_size, stride=2, padding=padding, padding_mode=padding_mode,
+            activation_type=activation_type, norm_type=norm_type
+        )]
+
+        multiple_now = 1
+        for i in range(1, num_conv):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 2 ** 3)
+            stride = 1 if i == num_conv - 1 else 2
+            sequence.append(Conv2dBlock(
+                multiple_prev * base_channels, multiple_now * base_channels,
+                kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=padding_mode,
+                activation_type=activation_type, norm_type=norm_type
+            ))
+        sequence.append(nn.Conv2d(
+            base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding, padding_mode=padding_mode))
+        if self.need_intermediate_feature:
+            self.sequence = nn.ModuleList(sequence)
+        else:
+            self.sequence = nn.Sequential(*sequence)
+
+    def forward(self, x):
+        if self.need_intermediate_feature:
+            intermediate_feature = []
+            for layer in self.sequence:
+                x = layer(x)
+                intermediate_feature.append(x)
+            return tuple(intermediate_feature)
+        else:
+            return self.sequence(x)
+
+
+if __name__ == '__main__':
+    g = Generator(**dict(in_channels=3, out_channels=3))
+    print(g)
+    pd = PatchDiscriminator(**dict(in_channels=3, base_channels=64, num_conv=4))
+    print(pd)
```
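In the Decoder's new transpose-conv branch, `padding = output_padding = (up_conv_kernel_size - 1) // 2` with `stride=2` doubles the spatial size exactly: out = (H − 1)·2 − 2·padding + k + output_padding = 2H for k = 3. Note PyTorch requires `output_padding < stride`, so this identity only holds for small odd kernels like the `up_conv_kernel_size=3` the new `Generator` passes. A quick check:

```python
import torch
import torch.nn as nn

k = 3                    # Generator passes up_conv_kernel_size=3
padding = (k - 1) // 2   # 1, also used as output_padding in the diff
up = nn.ConvTranspose2d(64, 32, kernel_size=k, stride=2,
                        padding=padding, output_padding=padding)
x = torch.randn(1, 64, 64, 64)
print(up(x).shape)  # torch.Size([1, 32, 128, 128]) — exactly 2x upsampling
```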
The SPADE generator module gains an `ImprovedSPADEGenerator` whose forward pass is still a stub:

```diff
@@ -1,8 +1,12 @@
+from collections import OrderedDict
+from functools import partial
+
+import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from model.base.module import ResidualBlock, ReverseConv2dBlock, Conv2dBlock
+from model.base.module import ResidualBlock, Conv2dBlock, LinearBlock
 
 
 class StyleEncoder(nn.Module):
@@ -33,6 +37,92 @@ class StyleEncoder(nn.Module):
         return self.fc_avg(x), self.fc_var(x)
 
 
+class ImprovedSPADEGenerator(nn.Module):
+    def __init__(self, in_channels, out_channels, output_size, have_style_input, style_dim, start_size=(4, 4),
+                 base_channels=64, padding_mode='reflect', activation_type="LeakyReLU", pre_activation=False):
+        super().__init__()
+
+        assert output_size in (128, 256, 512, 1024)
+        self.output_size = output_size
+
+        kernel_size = 3
+
+        if have_style_input:
+            self.style_converter = nn.Sequential(
+                LinearBlock(style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
+                LinearBlock(2 * style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
+            )
+
+        base_conv = partial(
+            Conv2dBlock,
+            pre_activation=pre_activation, activation_type=activation_type,
+            norm_type="AdaIN" if have_style_input else "NONE",
+            kernel_size=kernel_size, padding=(kernel_size - 1) // 2, padding_mode=padding_mode
+        )
+
+        base_residual_block = partial(
+            ResidualBlock,
+            padding_mode=padding_mode,
+            activation_type=activation_type,
+            norm_type="SPADE",
+            pre_activation=True,
+            additional_norm_kwargs=dict(
+                condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
+                activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
+            )
+        )
+
+        sequence = OrderedDict()
+        channels = (2 ** 4) * base_channels
+        sequence["block_head"] = nn.Sequential(OrderedDict([
+            ("conv_input", base_conv(in_channels=in_channels, out_channels=channels)),
+            ("conv_style", base_conv(in_channels=channels, out_channels=channels)),
+            ("res_a", base_residual_block(in_channels=channels, out_channels=channels)),
+            ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
+            ("up", nn.Upsample(scale_factor=2, mode='nearest'))
+        ]))
+
+        for i in range(4, 9 - min(int(math.log(self.output_size, 2)), 8), -1):
+            channels = (2 ** (i - 1)) * base_channels
+            sequence[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
+                ("res_a", base_residual_block(in_channels=channels * 2, out_channels=channels)),
+                ("conv_style", base_conv(in_channels=channels, out_channels=channels)),
+                ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
+                ("up", nn.Upsample(scale_factor=2, mode='nearest'))
+            ]))
+        self.sequence = nn.Sequential(sequence)
+        # channels = 2*base_channels when output size is 256, 512, 1024
+        # channels = 5*base_channels when output size is 128
+        out_modules = OrderedDict()
+        out_modules["out_1"] = nn.Sequential(
+            Conv2dBlock(
+                channels, out_channels, kernel_size=5, stride=1, padding=2,
+                pre_activation=pre_activation,
+                padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"
+            ),
+            nn.Tanh()
+        )
+        for i in range(int(math.log(self.output_size, 2)) - 8):
+            channels = channels // 2
+            out_modules[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
+                ("res_a", base_residual_block(in_channels=2 * channels, out_channels=channels)),
+                ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
+                ("up", nn.Upsample(scale_factor=2, mode='nearest'))
+            ]))
+            out_modules[f"out_{i + 2}"] = nn.Sequential(
+                Conv2dBlock(
+                    channels, out_channels, kernel_size=5, stride=1, padding=2,
+                    pre_activation=pre_activation,
+                    padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"
+                ),
+                nn.Tanh()
+            )
+        self.out_modules = nn.ModuleDict(out_modules)
+
+    def forward(self, seg, style=None):
+        pass
+
+
 class SPADEGenerator(nn.Module):
     def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
                  padding_mode='reflect', activation_type="LeakyReLU"):
@@ -89,7 +179,8 @@ class SPADEGenerator(nn.Module):
             x = blk(x)
         return self.output_converter(x)
 
+
 if __name__ == '__main__':
     g = SPADEGenerator(3, 3, 7, False, 256)
     print(g)
     print(g(torch.randn(2, 3, 256, 256)).size())
```
model/image_translation/UGATIT.py loses `RhoClipper`, now defined in the engine:

```diff
@@ -6,19 +6,6 @@ from model.base.module import Conv2dBlock, LinearBlock
 from model.image_translation.CycleGAN import Encoder, Decoder
 
 
-class RhoClipper(object):
-    def __init__(self, clip_min, clip_max):
-        self.clip_min = clip_min
-        self.clip_max = clip_max
-        assert clip_min < clip_max
-
-    def __call__(self, module):
-        if hasattr(module, 'rho'):
-            w = module.rho.data
-            w = w.clamp(self.clip_min, self.clip_max)
-            module.rho.data = w
-
-
 class CAMClassifier(nn.Module):
     def __init__(self, in_channels, activation_type="ReLU"):
         super(CAMClassifier, self).__init__()
```
A normalization helper module (76 lines) is deleted outright:

```diff
@@ -1,76 +0,0 @@
-import functools
-
-import torch
-import torch.nn as nn
-
-
-def select_norm_layer(norm_type):
-    if norm_type == "BN":
-        return functools.partial(nn.BatchNorm2d)
-    elif norm_type == "IN":
-        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
-    elif norm_type == "LN":
-        return functools.partial(LayerNorm2d, affine=True)
-    elif norm_type == "NONE":
-        return lambda num_features: nn.Identity()
-    elif norm_type == "AdaIN":
-        return functools.partial(AdaptiveInstanceNorm2d, affine=False, track_running_stats=False)
-    else:
-        raise NotImplemented(f'normalization layer {norm_type} is not found')
-
-
-class LayerNorm2d(nn.Module):
-    def __init__(self, num_features, eps: float = 1e-5, affine: bool = True):
-        super().__init__()
-        self.num_features = num_features
-        self.eps = eps
-        self.affine = affine
-        if self.affine:
-            self.channel_gamma = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
-            self.channel_beta = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
-            self.reset_parameters()
-
-    def reset_parameters(self):
-        if self.affine:
-            nn.init.uniform_(self.channel_gamma)
-            nn.init.zeros_(self.channel_beta)
-
-    def forward(self, x):
-        ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
-        x = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
-        if self.affine:
-            return self.channel_gamma * x + self.channel_beta
-        return x
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(num_features={self.num_features}, affine={self.affine})"
-
-
-class AdaptiveInstanceNorm2d(nn.Module):
-    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
-                 affine: bool = False, track_running_stats: bool = False):
-        super().__init__()
-        self.num_features = num_features
-        self.affine = affine
-        self.track_running_stats = track_running_stats
-        self.norm = nn.InstanceNorm2d(num_features, eps, momentum, affine, track_running_stats)
-
-        self.gamma = None
-        self.beta = None
-        self.have_set_style = False
-
-    def set_style(self, style):
-        style = style.view(*style.size(), 1, 1)
-        self.gamma, self.beta = style.chunk(2, 1)
-        self.have_set_style = True
-
-    def forward(self, x):
-        assert self.have_set_style
-        x = self.norm(x)
-        x = self.gamma * x + self.beta
-        self.have_set_style = False
-        return x
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(num_features={self.num_features}, " \
-               f"affine={self.affine}, track_running_stats={self.track_running_stats})"
```
model/registry.py sorts its imports and relaxes duplicate registration:

```diff
@@ -1,8 +1,10 @@
 import inspect
-from omegaconf.dictconfig import DictConfig
-from omegaconf import OmegaConf
-from types import ModuleType
 import warnings
+from types import ModuleType
+
+from omegaconf import OmegaConf
+from omegaconf.dictconfig import DictConfig
 
 
 class _Registry:
     def __init__(self, name):
@@ -136,8 +138,11 @@ class Registry(_Registry):
         if module_name is None:
             module_name = module_class.__name__
         if not force and module_name in self._module_dict:
-            raise KeyError(f'{module_name} is already registered '
-                           f'in {self.name}')
+            if self._module_dict[module_name] == module_class:
+                warnings.warn(f'{module_name} is already registered in {self.name}, but is the same class')
+                return
+            raise KeyError(f'{module_name}:{self._module_dict[module_name]} is already registered in {self.name} '
+                           f'so {module_class} can not be registered')
         self._module_dict[module_name] = module_class
 
     def register_module(self, name=None, force=False, module=None):
```
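The registry now tolerates re-registering the *same* class under an existing name (a warning instead of a `KeyError`). This matters once `model.image_translation.UGATIT` and `model.image_translation.CycleGAN` are both imported eagerly and UGATIT itself imports from CycleGAN, so import order can trigger the same registration decorator twice. A standalone sketch of the new behaviour (simplified `Registry`, not the repo's full class):

```python
import warnings

class Registry:
    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register(self, module_class, module_name=None, force=False):
        module_name = module_name or module_class.__name__
        if not force and module_name in self._module_dict:
            if self._module_dict[module_name] == module_class:
                # idempotent: same class re-registered is now only a warning
                warnings.warn(f'{module_name} is already registered in {self.name}, '
                              f'but is the same class')
                return
            raise KeyError(f'{module_name} is already registered in {self.name}')
        self._module_dict[module_name] = module_class

MODEL = Registry("model")
class Generator: ...
MODEL.register(Generator)
MODEL.register(Generator)  # warns, returns quietly
class Other: ...
try:
    MODEL.register(Other, "Generator")  # a *different* class still raises
except KeyError as e:
    print(e)
```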