From 0019d4034c578da5c6b27b936f61fed498bf29bc Mon Sep 17 00:00:00 2001
From: budui
Date: Wed, 14 Oct 2020 18:55:51 +0800
Subject: [PATCH] rename CyCleGAN to CycleGAN, share loss helpers, add MGC
 loss, switch config to selfie2anime

---
 configs/synthesizers/CyCleGAN.yml                  |  55 ++++--
 configs/synthesizers/UGATIT.yml                    |   4 +-
 engine/{CyCleGAN.py => CycleGAN.py}                |  37 ++--
 engine/U-GAT-IT.py                                 |  31 ++--
 engine/util/loss.py                                |  25 +++
 ...mal_geometry_distortion_constraint_loss.py      | 164 +++++++++++++++---
 model/__init__.py                                  |   3 +-
 model/base/module.py                               |   7 +-
 model/image_translation/CycleGAN.py                |  16 +-
 model/image_translation/UGATIT.py                  |  13 --
 util/registry.py                                   |  15 +-
 11 files changed, 261 insertions(+), 109 deletions(-)
 rename engine/{CyCleGAN.py => CycleGAN.py} (72%)
 create mode 100644 engine/util/loss.py

diff --git a/configs/synthesizers/CyCleGAN.yml b/configs/synthesizers/CyCleGAN.yml
index be3c4b8..21e5eb4 100644
--- a/configs/synthesizers/CyCleGAN.yml
+++ b/configs/synthesizers/CyCleGAN.yml
@@ -1,34 +1,38 @@
-name: horse2zebra-CyCleGAN
-engine: CyCleGAN
+name: selfie2anime-cycleGAN
+engine: CycleGAN
 result_dir: ./result
-max_pairs: 266800
+max_pairs: 1000000
 
 misc:
   random_seed: 324
   handler:
-    clear_cuda_cache: False
+    clear_cuda_cache: True
     set_epoch_for_dist_sampler: True
   checkpoint:
     epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
     n_saved: 2
   tensorboard:
     scalar: 100 # log scalar `scalar` times per epoch
-    image: 2 # log image `image` times per epoch
+    image: 4 # log image `image` times per epoch
+  test:
+    random: True
+    images: 10
 
 model:
   generator:
-    _type: CyCle-Generator
+    _type: CycleGAN-Generator
+    _add_spectral_norm: True
     in_channels: 3
     out_channels: 3
     base_channels: 64
     num_blocks: 9
-    padding_mode: reflect
-    norm_type: IN
   discriminator:
     _type: PatchDiscriminator
+    _add_spectral_norm: True
     in_channels: 3
     base_channels: 64
+    num_conv: 4
 
 loss:
   gan:
@@ -41,17 +45,21 @@ loss:
     weight: 10.0
   id:
     level: 1
-    weight: 0
+    weight: 10.0
+  mgc:
+    weight: 5
 
 optimizers:
   generator:
     _type: Adam
-    lr: 2e-4
+    lr: 1e-4
     betas: [ 0.5, 0.999 ]
+    weight_decay: 1e-4
   discriminator:
     _type: Adam
-    lr: 2e-4
+    lr: 1e-4
     betas: [ 0.5, 0.999 ]
+    weight_decay: 1e-4
 
 data:
   train:
@@ -60,15 +68,15 @@ data:
     target_lr: 0
     buffer_size: 50
     dataloader:
-      batch_size: 6
+      batch_size: 1
       shuffle: True
       num_workers: 2
       pin_memory: True
       drop_last: True
     dataset:
       _type: GenerationUnpairedDataset
-      root_a: "/data/i2i/horse2zebra/trainA"
-      root_b: "/data/i2i/horse2zebra/trainB"
+      root_a: "/data/i2i/selfie2anime/trainA"
+      root_b: "/data/i2i/selfie2anime/trainB"
       random_pair: True
       pipeline:
         - Load
@@ -82,16 +90,17 @@ data:
             mean: [ 0.5, 0.5, 0.5 ]
             std: [ 0.5, 0.5, 0.5 ]
   test:
+    which: video_dataset
     dataloader:
-      batch_size: 4
+      batch_size: 1
       shuffle: False
       num_workers: 1
       pin_memory: False
       drop_last: False
     dataset:
       _type: GenerationUnpairedDataset
-      root_a: "/data/i2i/horse2zebra/testA"
-      root_b: "/data/i2i/horse2zebra/testB"
+      root_a: "/data/i2i/selfie2anime/testA"
+      root_b: "/data/i2i/selfie2anime/testB"
       random_pair: False
       pipeline:
         - Load
@@ -101,3 +110,15 @@
         - Normalize:
             mean: [ 0.5, 0.5, 0.5 ]
             std: [ 0.5, 0.5, 0.5 ]
+    video_dataset:
+      _type: SingleFolderDataset
+      root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
+      with_path: True
+      pipeline:
+        - Load
+        - Resize:
+            size: [ 256, 256 ]
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
diff --git a/configs/synthesizers/UGATIT.yml b/configs/synthesizers/UGATIT.yml
index bf1eb1a..68f6228 100644
--- a/configs/synthesizers/UGATIT.yml
+++ b/configs/synthesizers/UGATIT.yml
@@ -78,7 +78,7 @@ data:
     target_lr: 0
     buffer_size: 50
     dataloader:
-      batch_size: 4
+      batch_size: 1
       shuffle: True
       num_workers: 2
       pin_memory: True
       drop_last: True
@@ -102,7 +102,7 @@ data:
   test:
     which: video_dataset
     dataloader:
-      batch_size: 8
+      batch_size: 1
       shuffle: False
       num_workers: 1
       pin_memory: False
diff --git a/engine/CyCleGAN.py b/engine/CycleGAN.py
similarity index 72%
rename from engine/CyCleGAN.py
rename to engine/CycleGAN.py
index 96d6db4..72e484d 100644
--- a/engine/CyCleGAN.py
+++ b/engine/CycleGAN.py
@@ -1,26 +1,23 @@
 from itertools import chain
 
-import ignite.distributed as idist
 import torch
-import torch.nn as nn
-from omegaconf import OmegaConf
 
 from engine.base.i2i import EngineKernel, run_kernel
 from engine.util.build import build_model
-from loss.gan import GANLoss
-from model.GAN.base import GANImageBuffer
+from engine.util.container import GANImageBuffer, LossContainer
+from engine.util.loss import pixel_loss, gan_loss
+from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
 from model.weight_init import generation_init_weights
 
 
-class TAFGEngineKernel(EngineKernel):
+class CycleGANEngineKernel(EngineKernel):
     def __init__(self, config):
         super().__init__(config)
 
-        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
-        gan_loss_cfg.pop("weight")
-        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
-        self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
-        self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
+        self.gan_loss = gan_loss(config.loss.gan)
+        self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
+        self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
+        self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss())
 
         self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50)
                               for k in self.discriminators.keys()}
@@ -56,21 +53,19 @@ class TAFGEngineKernel(EngineKernel):
         images["b2a"] = self.generators["b2a"](batch["b"])
         images["a2b2a"] = self.generators["b2a"](images["a2b"])
         images["b2a2b"] = self.generators["a2b"](images["b2a"])
-        if self.config.loss.id.weight > 0:
+        if self.id_loss.weight > 0:
             images["a2a"] = self.generators["b2a"](batch["a"])
             images["b2b"] = self.generators["a2b"](batch["b"])
         return images
 
     def criterion_generators(self, batch, generated) -> dict:
         loss = dict()
-        for phase in ["a2b", "b2a"]:
-            loss[f"cycle_{phase[0]}"] = self.config.loss.cycle.weight * self.cycle_loss(
-                generated[f"{phase}2{phase[0]}"], batch[phase[0]])
-            loss[f"gan_{phase}"] = self.config.loss.gan.weight * self.gan_loss(
-                self.discriminators[phase[-1]](generated[phase]), True)
-            if self.config.loss.id.weight > 0:
-                loss[f"id_{phase[0]}"] = self.config.loss.id.weight * self.id_loss(
-                    generated[f"{phase[0]}2{phase[0]}"], batch[phase[0]])
+        for ph in "ab":
+            loss[f"cycle_{ph}"] = self.cycle_loss(generated["a2b2a" if ph == "a" else "b2a2b"], batch[ph])
+            # generated only holds identity images when the id loss is enabled (see forward above)
+            if self.id_loss.weight > 0:
+                loss[f"id_{ph}"] = self.id_loss(generated[f"{ph}2{ph}"], batch[ph])
+            loss[f"mgc_{ph}"] = self.mgc_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph])
+            loss[f"gan_{ph}"] = self.config.loss.gan.weight * self.gan_loss(
+                self.discriminators[ph](generated["a2b" if ph == "b" else "b2a"]), True)
         return loss
 
     def criterion_discriminators(self, batch, generated) -> dict:
@@ -97,5 +92,5 @@
 
 
 def run(task, config, _):
-    kernel = TAFGEngineKernel(config)
+    kernel = CycleGANEngineKernel(config)
     run_kernel(task, config, kernel)
diff --git a/engine/U-GAT-IT.py b/engine/U-GAT-IT.py
index 3044c2c..957bfcb 100644
--- a/engine/U-GAT-IT.py
+++ b/engine/U-GAT-IT.py
@@ -1,38 +1,31 @@
-import ignite.distributed as idist
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from omegaconf import OmegaConf
 
 from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
 from engine.util.build import build_model
 from engine.util.container import LossContainer
+from engine.util.loss import bce_loss, mse_loss, pixel_loss, gan_loss
 from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
-from loss.gan import GANLoss
-from model.image_translation.UGATIT import RhoClipper
 from util.image import attention_colored_map
 
 
-def pixel_loss(level):
-    return nn.L1Loss() if level == 1 else nn.MSELoss()
+class RhoClipper(object):
+    def __init__(self, clip_min, clip_max):
+        self.clip_min = clip_min
+        self.clip_max = clip_max
+        assert clip_min < clip_max
 
-
-def mse_loss(x, target_flag):
-    return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
-
-
-def bce_loss(x, target_flag):
-    return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
+    def __call__(self, module):
+        if hasattr(module, 'rho'):
+            w = module.rho.data
+            w = w.clamp(self.clip_min, self.clip_max)
+            module.rho.data = w
 
 
 class UGATITEngineKernel(EngineKernel):
     def __init__(self, config):
         super().__init__(config)
 
-        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
-        gan_loss_cfg.pop("weight")
-        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
-
+        self.gan_loss = gan_loss(config.loss.gan)
         self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
         self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
         self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
diff --git a/engine/util/loss.py b/engine/util/loss.py
new file mode 100644
index 0000000..94e5e5d
--- /dev/null
+++ b/engine/util/loss.py
@@ -0,0 +1,25 @@
+import ignite.distributed as idist
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from omegaconf import OmegaConf
+
+from loss.gan import GANLoss
+
+
+def gan_loss(config):
+    gan_loss_cfg = OmegaConf.to_container(config)
+    gan_loss_cfg.pop("weight")
+    return GANLoss(**gan_loss_cfg).to(idist.device())
+
+
+def pixel_loss(level):
+    return nn.L1Loss() if level == 1 else nn.MSELoss()
+
+
+def mse_loss(x, target_flag):
+    return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
+
+
+def bce_loss(x, target_flag):
+    return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
diff --git a/loss/I2I/minimal_geometry_distortion_constraint_loss.py b/loss/I2I/minimal_geometry_distortion_constraint_loss.py
index ed662c3..c972990 100644
--- a/loss/I2I/minimal_geometry_distortion_constraint_loss.py
+++ b/loss/I2I/minimal_geometry_distortion_constraint_loss.py
@@ -1,3 +1,4 @@
+import ignite.distributed as idist
 import torch
 import torch.nn as nn
 
@@ -5,17 +6,59 @@
 def gaussian_radial_basis_function(x, mu, sigma):
     # (kernel_size) -> (batch_size, kernel_size, c*h*w)
     mu = mu.view(1, mu.size(0), 1).expand(x.size(0), -1, x.size(1) * x.size(2) * x.size(3))
-    mu = mu.to(x.device)
     # (batch_size, c, h, w) -> (batch_size, kernel_size, c*h*w)
     x = x.view(x.size(0), 1, -1).expand(-1, mu.size(1), -1)
     return torch.exp(-(x - mu).pow(2) / (2 * sigma ** 2))
 
 
+class ImproveMyLoss(torch.nn.Module):
+    def __init__(self, device=idist.device()):
+        super().__init__()
+        mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0]).to(device)
+        self.x_mu_list = mu.repeat(9).view(-1, 81)
+        self.y_mu_list = mu.unsqueeze(0).t().repeat(1, 9).view(-1, 81)
+        self.R = torch.eye(81).to(device)
+
+    def batch_ERSMI(self, I1, I2):
+        batch_size = I1.shape[0]
+        img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
+        if I2.shape[1] == 1 and I1.shape[1] != 1:
+            I2 = I2.repeat(1, 3, 1, 1)
+
+        def kernel_F(y, mu_list, sigma):
+            tmp_mu = mu_list.view(-1, 1).repeat(1, img_size).repeat(batch_size, 1, 1)  # (batch_size, 81, img_size)
+            tmp_y = y.view(batch_size, 1, -1).repeat(1, 81, 1)
+            tmp_y = tmp_mu - tmp_y
+            mat_L = torch.exp(-tmp_y.pow(2) / (2 * sigma ** 2))
+            return mat_L
+
+        mat_K = kernel_F(I1, self.x_mu_list, 1)
+        mat_L = kernel_F(I2, self.y_mu_list, 1)
+        mat_k_l = mat_K * mat_L
+
+        H1 = (mat_K @ mat_K.transpose(1, 2)) * (mat_L @ mat_L.transpose(1, 2)) / (img_size ** 2)
+        h_hat = mat_k_l @ mat_k_l.transpose(1, 2) / img_size
+        small_h_hat = mat_K.sum(2).view(batch_size, -1, 1) * mat_L.sum(2).view(batch_size, -1, 1) / (img_size ** 2)
+        h_hat = 0.5 * H1 + 0.5 * h_hat
+        alpha = (h_hat + 0.05 * self.R).inverse() @ small_h_hat
+
+        ersmi = 2 * alpha.transpose(1, 2) @ small_h_hat - alpha.transpose(1, 2) @ h_hat @ alpha - 1
+
+        ersmi = -ersmi.squeeze().mean()
+        return ersmi
+
+    def forward(self, fakeI, realI):
+        return self.batch_ERSMI(fakeI, realI)
+
+
 class MyLoss(torch.nn.Module):
     def __init__(self):
         super(MyLoss, self).__init__()
 
     def forward(self, fakeI, realI):
+        fakeI = fakeI.cuda()
+        realI = realI.cuda()
+
         def batch_ERSMI(I1, I2):
             batch_size = I1.shape[0]
             img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
@@ -49,6 +92,7 @@ class MyLoss(torch.nn.Module):
             alpha = alpha.matmul(h2)
             ersmi = (2 * (alpha.transpose(1, 2)).matmul(h2) - ((alpha.transpose(1, 2)).matmul(H2)).matmul(
                 alpha) - 1).squeeze()
+            ersmi = -ersmi.mean()
 
             return ersmi
 
@@ -61,16 +105,17 @@ class MGCLoss(nn.Module):
     Minimal Geometry-Distortion Constraint Loss
     from https://openreview.net/forum?id=R5M7Mxl1xZ
     """
-    def __init__(self, beta=0.5, lambda_=0.05):
+    def __init__(self, beta=0.5, lambda_=0.05, device=idist.device()):
         super().__init__()
         self.beta = beta
         self.lambda_ = lambda_
-        mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0])
-        self.mu_x = mu.repeat(9)
-        self.mu_y = mu.unsqueeze(0).t().repeat(1, 9).view(-1)
+        mu_y, mu_x = torch.meshgrid([torch.arange(-1, 1.25, 0.25), torch.arange(-1, 1.25, 0.25)])
+        self.mu_x = mu_x.flatten().to(device)
+        self.mu_y = mu_y.flatten().to(device)
+        self.R = torch.eye(81).unsqueeze(0).to(device)
 
     @staticmethod
-    def batch_rSMI(img1, img2, mu_x, mu_y, beta, lambda_):
+    def batch_rSMI(img1, img2, mu_x, mu_y, beta, lambda_, R):
         assert img1.size() == img2.size()
 
         num_pixel = img1.size(1) * img1.size(2) * img2.size(3)
@@ -79,33 +124,102 @@
         mat_l = gaussian_radial_basis_function(img2, mu_y, sigma=1)
         mat_k_mul_mat_l = mat_k * mat_l
 
-        h_hat = (1 - beta) * (mat_k_mul_mat_l.matmul(mat_k_mul_mat_l.transpose(1, 2))) / num_pixel
-        h_hat += beta * (mat_k.matmul(mat_k.transpose(1, 2)) * mat_l.matmul(mat_l.transpose(1, 2))) / (num_pixel ** 2)
+        h_hat = (1 - beta) * (mat_k_mul_mat_l @ mat_k_mul_mat_l.transpose(1, 2)) / num_pixel
+        h_hat += beta * ((mat_k @ mat_k.transpose(1, 2)) * (mat_l @ mat_l.transpose(1, 2))) / (num_pixel ** 2)
 
         small_h_hat = mat_k.sum(2, keepdim=True) * mat_l.sum(2, keepdim=True) / (num_pixel ** 2)
 
-        R = torch.eye(h_hat.size(1)).to(img1.device)
-        alpha = (h_hat + lambda_ * R).inverse().matmul(small_h_hat)
-
-        rSMI = (2 * alpha.transpose(1, 2).matmul(small_h_hat)) - alpha.transpose(1, 2).matmul(h_hat).matmul(alpha) - 1
-        return rSMI
+        alpha = (h_hat + lambda_ * R).inverse() @ small_h_hat
+        rSMI = 2 * alpha.transpose(1, 2) @ small_h_hat - alpha.transpose(1, 2) @ h_hat @ alpha - 1
+        return rSMI.squeeze()
 
     def forward(self, fake, real):
-        rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_)
-        return -rSMI.squeeze().mean()
+        rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_, self.R)
+        return -rSMI.mean()
 
 
 if __name__ == '__main__':
-    mg = MGCLoss().to("cuda")
+    mg = MGCLoss(device=torch.device("cpu"))
+    my = MyLoss().to("cuda")
+    imy = ImproveMyLoss()
+
+    from data.transform import transform_pipeline
 
-    def norm(x):
-        x -= x.min()
-        x /= x.max()
-        return (x - 0.5) * 2
+    pipeline = transform_pipeline(
+        ['Load', 'ToTensor', {'Normalize': {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5]}}])
+    img_a1 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_1.jpg")
+    img_a2 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_2.jpg")
+    img_a3 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_3.jpg")
+    img_b1 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_1.jpg")
+    img_b2 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_2.jpg")
+    img_b3 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_3.jpg")
 
-    x1 = norm(torch.randn(5, 3, 256, 256))
-    x2 = norm(x1 * 2 + 1)
-    x3 = norm(torch.randn(5, 3, 256, 256))
-    x4 = norm(torch.exp(x3))
-    print(mg(x1, x1), mg(x1, x2), mg(x1, x3), mg(x1, x4))
+    img_a1.requires_grad_(True)
+    img_a2.requires_grad_(True)
+    img_a3.requires_grad_(True)
+
+    # print("MyLoss")
+    # l1 = my(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
+    # l2 = my(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
+    # l3 = my(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
+    # l = (l1+l2+l3)/3
+    # l.backward()
+    # print(img_a1.grad[0][0][0:10])
+    # print(img_a2.grad[0][0][0:10])
+    # print(img_a3.grad[0][0][0:10])
+    #
+    # img_a1.grad = None
+    # img_a2.grad = None
+    # img_a3.grad = None
+    #
+    # print("---")
+    # l = my(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
+    # l.backward()
+    # print(img_a1.grad[0][0][0:10])
+    # print(img_a2.grad[0][0][0:10])
+    # print(img_a3.grad[0][0][0:10])
+    # img_a1.grad = None
+    # img_a2.grad = None
+    # img_a3.grad = None
+
+    print("MGCLoss")
+    l1 = mg(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
+    l2 = mg(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
+    l3 = mg(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
+    l = (l1 + l2 + l3) / 3
+    l.backward()
+    print(img_a1.grad[0][0][0:10])
+    print(img_a2.grad[0][0][0:10])
+    print(img_a3.grad[0][0][0:10])
+
+    img_a1.grad = None
+    img_a2.grad = None
+    img_a3.grad = None
+
+    print("---")
+    l = mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
+    l.backward()
+    print(img_a1.grad[0][0][0:10])
+    print(img_a2.grad[0][0][0:10])
+    print(img_a3.grad[0][0][0:10])
+
+    # print("\nMGCLoss")
+    # mg(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
+    # mg(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
+    # mg(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
+    #
+    # print("---")
+    # mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
+    #
+    # import pprofile
+    #
+    # profiler = pprofile.Profile()
+    # with profiler:
+    #     iter_times = 1000
+    #     for _ in range(iter_times):
+    #         mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
+    #     for _ in range(iter_times):
+    #         my(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
+    #     for _ in range(iter_times):
+    #         imy(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
+    # profiler.print_stats()
diff --git a/model/__init__.py b/model/__init__.py
index 83051a6..386a519 100644
--- a/model/__init__.py
+++ b/model/__init__.py
@@ -1,3 +1,4 @@
 from model.registry import MODEL, NORMALIZATION
 import model.base.normalization
-import model.image_translation
+import model.image_translation.UGATIT
+import model.image_translation.CycleGAN
diff --git a/model/base/module.py b/model/base/module.py
index 7674256..91e4055 100644
--- a/model/base/module.py
+++ b/model/base/module.py
@@ -59,7 +59,12 @@ class Conv2dBlock(nn.Module):
         self.activation_type = activation_type
         self.pre_activation = pre_activation
 
-        conv = nn.ConvTranspose2d if use_transpose_conv else nn.Conv2d
+        if use_transpose_conv:
+            # Only "zeros" padding mode is supported for ConvTranspose2d
+            conv_kwargs["padding_mode"] = "zeros"
+            conv = nn.ConvTranspose2d
+        else:
+            conv = nn.Conv2d
 
         if pre_activation:
             self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
diff --git a/model/image_translation/CycleGAN.py b/model/image_translation/CycleGAN.py
index bf0484d..587977d 100644
--- a/model/image_translation/CycleGAN.py
+++ b/model/image_translation/CycleGAN.py
@@ -21,7 +21,7 @@ class Encoder(nn.Module):
             multiple_now = min(2 ** i, 2 ** max_down_sampling_multiple)
             sequence.append(Conv2dBlock(
                 multiple_prev * base_channels, multiple_now * base_channels,
-                kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode=padding_mode,
+                kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode="zeros",
                 activation_type=activation_type, norm_type=down_conv_norm_type
             ))
         self.out_channels = multiple_now * base_channels
@@ -62,7 +62,7 @@ class Decoder(nn.Module):
         for i in range(num_up_sampling):
             if use_transpose_conv:
                 sequence.append(Conv2dBlock(
-                    channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
+                    channels, channels // 2, kernel_size=up_conv_kernel_size, stride=2,
                     padding=padding, output_padding=padding,
                     padding_mode=padding_mode, activation_type=activation_type, norm_type=up_conv_norm_type,
@@ -90,7 +90,7 @@
 @MODEL.register_module("CycleGAN-Generator")
 class Generator(nn.Module):
     def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, activation_type="ReLU",
-                 padding_mode='reflect', norm_type="IN", pre_activation=True, use_transpose_conv=True):
+                 padding_mode='reflect', norm_type="IN", pre_activation=False, use_transpose_conv=True):
         super().__init__()
         self.encoder = Encoder(in_channels, base_channels, num_conv=2, num_res=num_blocks,
                                padding_mode=padding_mode, activation_type=activation_type,
@@ -106,7 +106,7 @@
 @MODEL.register_module("PatchDiscriminator")
 class PatchDiscriminator(nn.Module):
-    def __int__(self, in_channels, base_channels=64, num_conv=4, need_intermediate_feature=False,
+    def __init__(self, in_channels, base_channels=64, num_conv=4, need_intermediate_feature=False,
                 norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
         super().__init__()
         self.need_intermediate_feature = need_intermediate_feature
@@ -118,7 +118,7 @@
         )]
 
         multiple_now = 1
-        for i in range(1, num_conv + 1):
+        for i in range(1, num_conv):
             multiple_prev = multiple_now
             multiple_now = min(2 ** i, 2 ** 3)
             stride = 1 if i == num_conv - 1 else 2
@@ -143,3 +143,9 @@ class PatchDiscriminator(nn.Module):
             return tuple(intermediate_feature)
         else:
             return self.sequence(x)
+
+
+if __name__ == '__main__':
+    g = Generator(in_channels=3, out_channels=3)
+    print(g)
+    pd = PatchDiscriminator(in_channels=3, base_channels=64, num_conv=4)
+    print(pd)
\ No newline at end of file
diff --git a/model/image_translation/UGATIT.py b/model/image_translation/UGATIT.py
index 9e4a7c1..290432d 100644
--- a/model/image_translation/UGATIT.py
+++ b/model/image_translation/UGATIT.py
@@ -6,19 +6,6 @@ from model.base.module import Conv2dBlock, LinearBlock
 from model.image_translation.CycleGAN import Encoder, Decoder
 
 
-class RhoClipper(object):
-    def __init__(self, clip_min, clip_max):
-        self.clip_min = clip_min
-        self.clip_max = clip_max
-        assert clip_min < clip_max
-
-    def __call__(self, module):
-        if hasattr(module, 'rho'):
-            w = module.rho.data
-            w = w.clamp(self.clip_min, self.clip_max)
-            module.rho.data = w
-
-
 class CAMClassifier(nn.Module):
     def __init__(self, in_channels, activation_type="ReLU"):
         super(CAMClassifier, self).__init__()
diff --git a/util/registry.py b/util/registry.py
index f6d6a1b..c9c1a28 100644
--- a/util/registry.py
+++ b/util/registry.py
@@ -1,8 +1,10 @@
 import inspect
-from omegaconf.dictconfig import DictConfig
-from omegaconf import OmegaConf
-from types import ModuleType
 import warnings
+from types import ModuleType
+
+from omegaconf import OmegaConf
+from omegaconf.dictconfig import DictConfig
+
 
 class _Registry:
     def __init__(self, name):
@@ -136,8 +138,11 @@ class Registry(_Registry):
         if module_name is None:
             module_name = module_class.__name__
         if not force and module_name in self._module_dict:
-            raise KeyError(f'{module_name} is already registered '
-                           f'in {self.name}')
+            if self._module_dict[module_name] == module_class:
+                warnings.warn(f'{module_name} is already registered in {self.name}, but is the same class')
+                return
+            raise KeyError(f'{module_name}:{self._module_dict[module_name]} is already registered in {self.name}, '
+                           f'so {module_class} cannot be registered')
         self._module_dict[module_name] = module_class
 
     def register_module(self, name=None, force=False, module=None):