diff --git a/.idea/deployment.xml b/.idea/deployment.xml index d56324a..8ccfb5e 100644 --- a/.idea/deployment.xml +++ b/.idea/deployment.xml @@ -1,6 +1,6 @@ - + diff --git a/configs/synthesizers/CyCleGAN.yml b/configs/synthesizers/CyCleGAN.yml index 21e5eb4..f72bc6c 100644 --- a/configs/synthesizers/CyCleGAN.yml +++ b/configs/synthesizers/CyCleGAN.yml @@ -1,4 +1,4 @@ -name: selfie2anime-cycleGAN +name: huawei-cycylegan-7 engine: CycleGAN result_dir: ./result max_pairs: 1000000 @@ -27,18 +27,33 @@ model: out_channels: 3 base_channels: 64 num_blocks: 9 + use_transpose_conv: False + pre_activation: True +# discriminator: +# _type: MultiScaleDiscriminator +# _add_spectral_norm: True +# num_scale: 2 +# down_sample_method: "bilinear" +# discriminator_cfg: +# _type: PatchDiscriminator +# in_channels: 3 +# base_channels: 64 +# num_conv: 4 +# need_intermediate_feature: True discriminator: _type: PatchDiscriminator _add_spectral_norm: True in_channels: 3 base_channels: 64 num_conv: 4 + need_intermediate_feature: False + loss: gan: - loss_type: lsgan + loss_type: hinge weight: 1.0 - real_label_val: 1.0 + real_label_val: 1 fake_label_val: 0.0 cycle: level: 1 @@ -47,17 +62,22 @@ loss: level: 1 weight: 10.0 mgc: - weight: 5 + weight: 1 + fm: + weight: 0 + edge: + weight: 0 + hed_pretrained_model_path: ./network-bsds500.pytorch optimizers: generator: _type: Adam - lr: 0.0001 + lr: 1e-4 betas: [ 0.5, 0.999 ] weight_decay: 0.0001 discriminator: _type: Adam - lr: 1e-4 + lr: 4e-4 betas: [ 0.5, 0.999 ] weight_decay: 0.0001 @@ -75,10 +95,21 @@ data: drop_last: True dataset: _type: GenerationUnpairedDataset - root_a: "/data/i2i/selfie2anime/trainA" - root_b: "/data/i2i/selfie2anime/trainB" + root_a: "/data/face2cartoon/all_face" + root_b: "/data/selfie2anime/trainB/" random_pair: True - pipeline: + pipeline_a: + - Load + - RandomCrop: + size: [ 178, 178 ] + - Resize: + size: [ 256, 256 ] + - RandomHorizontalFlip + - ToTensor + - Normalize: + mean: [ 0.5, 0.5, 0.5 ] + std: [ 0.5, 0.5, 0.5 ] + pipeline_b: - Load - Resize: size: [ 286, 286 ] @@ -99,10 +130,18 @@ data: drop_last: False dataset: _type: GenerationUnpairedDataset - root_a: "/data/i2i/selfie2anime/testA" - root_b: "/data/i2i/selfie2anime/testB" - random_pair: False - pipeline: + root_a: "/data/face2cartoon/test/human" + root_b: "/data/face2cartoon/test/anime" + random_pair: True + pipeline_a: + - Load + - Resize: + size: [ 256, 256 ] + - ToTensor + - Normalize: + mean: [ 0.5, 0.5, 0.5 ] + std: [ 0.5, 0.5, 0.5 ] + pipeline_b: - Load - Resize: size: [ 256, 256 ] diff --git a/data/dataset/image_translation.py b/data/dataset/image_translation.py index be22098..1614b71 100644 --- a/data/dataset/image_translation.py +++ b/data/dataset/image_translation.py @@ -38,9 +38,9 @@ class SingleFolderDataset(Dataset): @DATASET.register_module() class GenerationUnpairedDataset(Dataset): - def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False): - self.A = SingleFolderDataset(root_a, pipeline, with_path) - self.B = SingleFolderDataset(root_b, pipeline, with_path) + def __init__(self, root_a, root_b, random_pair, pipeline_a, pipeline_b, with_path=False): + self.A = SingleFolderDataset(root_a, pipeline_a, with_path) + self.B = SingleFolderDataset(root_b, pipeline_b, with_path) self.with_path = with_path self.random_pair = random_pair diff --git a/engine/CycleGAN.py b/engine/CycleGAN.py index 72e484d..cd2b817 100644 --- a/engine/CycleGAN.py +++ b/engine/CycleGAN.py @@ -1,11 +1,13 @@ from itertools import chain +import ignite.distributed as idist import torch from engine.base.i2i import EngineKernel, run_kernel from engine.util.build import build_model from engine.util.container import GANImageBuffer, LossContainer -from engine.util.loss import pixel_loss, gan_loss +from engine.util.loss import pixel_loss, gan_loss, feature_match_loss +from loss.I2I.edge_loss import EdgeLoss from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss from model.weight_init import generation_init_weights @@ -17,7 +19,10 @@ class CycleGANEngineKernel(EngineKernel): self.gan_loss = gan_loss(config.loss.gan) self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level)) self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level)) - self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss()) + self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite")) + self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "same")) + self.edge_loss = LossContainer(config.loss.edge.weight, EdgeLoss( + "HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(idist.device())) self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in self.discriminators.keys()} @@ -64,8 +69,12 @@ class CycleGANEngineKernel(EngineKernel): loss[f"cycle_{ph}"] = self.cycle_loss(generated["a2b2a" if ph == "a" else "b2a2b"], batch[ph]) loss[f"id_{ph}"] = self.id_loss(generated[f"{ph}2{ph}"], batch[ph]) loss[f"mgc_{ph}"] = self.mgc_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph]) - loss[f"gan_{ph}"] = self.config.loss.gan.weight * self.gan_loss( - self.discriminators[ph](generated["a2b" if ph == "b" else "b2a"]), True) + prediction_fake = self.discriminators[ph](generated["a2b" if ph == "b" else "b2a"]) + loss[f"gan_{ph}"] = self.config.loss.gan.weight * self.gan_loss(prediction_fake, True) + if self.fm_loss.weight > 0: + prediction_real = self.discriminators[ph](batch[ph]) + loss[f"feature_match_{ph}"] = self.fm_loss(prediction_fake, prediction_real) + loss[f"edge_{ph}"] = self.edge_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph], gt_is_edge=False) return loss def criterion_discriminators(self, batch, generated) -> dict: diff --git a/engine/base/i2i.py b/engine/base/i2i.py index d9af31f..8fef969 100644 --- a/engine/base/i2i.py +++ b/engine/base/i2i.py @@ -101,9 +101,12 @@ class EngineKernel(object): def _remove_no_grad_loss(loss_dict): + need_to_pop = [] for k in loss_dict: if not isinstance(loss_dict[k], torch.Tensor): - loss_dict.pop(k) + need_to_pop.append(k) + for k in need_to_pop: + loss_dict.pop(k) return loss_dict diff --git a/engine/util/loss.py b/engine/util/loss.py index 94e5e5d..70f3c84 100644 --- a/engine/util/loss.py +++ b/engine/util/loss.py @@ -23,3 +23,19 @@ def mse_loss(x, target_flag): def bce_loss(x, target_flag): return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x)) + + +def feature_match_loss(level, weight_policy): + compare_loss = pixel_loss(level) + assert weight_policy in ["same", "exponential_decline"] + + def fm_loss(generated_features, target_features): + num_scale = len(generated_features) + loss = 0 + for s_i in range(num_scale): + for i in range(len(generated_features[s_i]) - 1): + weight = 1 if weight_policy == "same" else 2 ** i + loss += weight * compare_loss(generated_features[s_i][i], target_features[s_i][i].detach()) / num_scale + return loss + + return fm_loss diff --git a/loss/I2I/minimal_geometry_distortion_constraint_loss.py b/loss/I2I/minimal_geometry_distortion_constraint_loss.py index c972990..ffc3e65 100644 --- a/loss/I2I/minimal_geometry_distortion_constraint_loss.py +++ b/loss/I2I/minimal_geometry_distortion_constraint_loss.py @@ -105,10 +105,12 @@ class MGCLoss(nn.Module): Minimal Geometry-Distortion Constraint Loss from https://openreview.net/forum?id=R5M7Mxl1xZ """ - def __init__(self, beta=0.5, lambda_=0.05, device=idist.device()): + def __init__(self, mi_to_loss_way="opposite", beta=0.5, lambda_=0.05, device=idist.device()): super().__init__() self.beta = beta self.lambda_ = lambda_ + assert mi_to_loss_way in ["opposite", "reciprocal"] + self.mi_to_loss_way = mi_to_loss_way mu_y, mu_x = torch.meshgrid([torch.arange(-1, 1.25, 0.25), torch.arange(-1, 1.25, 0.25)]) self.mu_x = mu_x.flatten().to(device) self.mu_y = mu_y.flatten().to(device) @@ -134,6 +136,8 @@ class MGCLoss(nn.Module): def forward(self, fake, real): rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_, self.R) + if self.mi_to_loss_way == "reciprocal": + return 1/rSMI.mean() return -rSMI.mean() diff --git a/loss/gan.py b/loss/gan.py index 5e30bc4..e8f05c0 100644 --- a/loss/gan.py +++ b/loss/gan.py @@ -1,5 +1,6 @@ import torch.nn as nn import torch.nn.functional as F +import torch class GANLoss(nn.Module): @@ -10,7 +11,7 @@ class GANLoss(nn.Module): self.fake_label_val = fake_label_val self.loss_type = loss_type - def forward(self, prediction, target_is_real: bool, is_discriminator=False): + def single_forward(self, prediction, target_is_real: bool, is_discriminator=False): """ gan loss forward :param prediction: network prediction @@ -37,3 +38,14 @@ class GANLoss(nn.Module): return loss else: raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.') + + def forward(self, prediction, target_is_real: bool, is_discriminator=False): + if isinstance(prediction, torch.Tensor): + # origin + return self.single_forward(prediction, target_is_real, is_discriminator) + elif isinstance(prediction, list): + # for multi scale discriminator, e.g. MultiScaleDiscriminator + loss = 0 + for p in prediction: + loss += self.single_forward(p[-1], target_is_real, is_discriminator) + return loss diff --git a/model/__init__.py b/model/__init__.py index 386a519..79ffd63 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -2,3 +2,4 @@ from model.registry import MODEL, NORMALIZATION import model.base.normalization import model.image_translation.UGATIT import model.image_translation.CycleGAN +import model.image_translation.pix2pixHD diff --git a/model/image_translation/pix2pixHD.py b/model/image_translation/pix2pixHD.py new file mode 100644 index 0000000..d2be64d --- /dev/null +++ b/model/image_translation/pix2pixHD.py @@ -0,0 +1,29 @@ +import torch.nn as nn +import torch.nn.functional as F + +from model import MODEL + + +@MODEL.register_module() +class MultiScaleDiscriminator(nn.Module): + def __init__(self, num_scale, discriminator_cfg, down_sample_method="avg"): + super().__init__() + assert down_sample_method in ["avg", "bilinear"] + self.down_sample_method = down_sample_method + + self.discriminator_list = nn.ModuleList([ + MODEL.build_with(discriminator_cfg) for _ in range(num_scale) + ]) + + def down_sample(self, x): + if self.down_sample_method == "avg": + return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False) + if self.down_sample_method == "bilinear": + return F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True) + + def forward(self, x): + results = [] + for discriminator in self.discriminator_list: + results.append(discriminator(x)) + x = self.down_sample(x) + return results diff --git a/util/registry.py b/util/registry.py index c9c1a28..5461787 100644 --- a/util/registry.py +++ b/util/registry.py @@ -53,11 +53,9 @@ class _Registry: else: raise TypeError(f'cfg must be a dict or a str, but got {type(cfg)}') - for k in args: - assert isinstance(k, str) - if k.startswith("_"): - warnings.warn(f"got param start with `_`: {k}, will remove it") - args.pop(k) + for invalid_key in [k for k in args.keys() if k.startswith("_")]: + warnings.warn(f"got param start with `_`: {invalid_key}, will remove it") + args.pop(invalid_key) if not (isinstance(default_args, dict) or default_args is None): raise TypeError('default_args must be a dict or None, '