From 0bec02bf6daea4651befcea0e90406ef33f9c984 Mon Sep 17 00:00:00 2001
From: budui
Date: Fri, 23 Oct 2020 16:14:37 +0800
Subject: [PATCH] Add GauGAN engine, config, and model registration

Add a GauGAN training config and engine kernel. Move the discriminator
output dispatch from the engine-side gan_loss wrapper into
GANLoss.forward, add a perceptual_loss builder, register SPADEGenerator
with the MODEL registry, simplify its output converter, and guard
BatchNorm weight init for parameter-free (affine=False) layers.
---
 configs/synthesizers/GauGAN.yml   | 167 ++++++++++++++++++++++++++++++
 engine/GauGAN.py                  |  86 +++++++++++++++
 engine/util/loss.py               |  25 ++---
 loss/gan.py                       |  20 +++-
 model/__init__.py                 |   1 +
 model/image_translation/GauGAN.py |  11 +-
 model/weight_init.py              |   3 +-
 7 files changed, 287 insertions(+), 26 deletions(-)
 create mode 100644 configs/synthesizers/GauGAN.yml
 create mode 100644 engine/GauGAN.py

diff --git a/configs/synthesizers/GauGAN.yml b/configs/synthesizers/GauGAN.yml
new file mode 100644
index 0000000..b4ddb1a
--- /dev/null
+++ b/configs/synthesizers/GauGAN.yml
@@ -0,0 +1,167 @@
+name: huawei-GauGAN-3
+engine: GauGAN
+result_dir: ./result
+max_pairs: 1000000
+
+misc:
+  random_seed: 324
+
+handler:
+  clear_cuda_cache: True
+  set_epoch_for_dist_sampler: True
+  checkpoint:
+    epoch_interval: 1  # checkpoint once every `epoch_interval` epochs
+    n_saved: 2
+  tensorboard:
+    scalar: 100  # log scalars `scalar` times per epoch
+    image: 4  # log images `image` times per epoch
+  test:
+    random: True
+    images: 10
+
+model:
+  generator:
+    _type: SPADEGenerator
+    _add_spectral_norm: True
+    in_channels: 3
+    out_channels: 3
+    num_blocks: 7
+    use_vae: False
+    num_z_dim: 256
+#  discriminator:
+#    _type: MultiScaleDiscriminator
+#    _add_spectral_norm: True
+#    num_scale: 2
+#    down_sample_method: "bilinear"
+#    discriminator_cfg:
+#      _type: PatchDiscriminator
+#      in_channels: 3
+#      base_channels: 64
+#      num_conv: 4
+#      need_intermediate_feature: True
+  discriminator:
+    _type: PatchDiscriminator
+    _add_spectral_norm: True
+    in_channels: 3
+    base_channels: 64
+    num_conv: 4
+    need_intermediate_feature: True
+
+
+loss:
+  gan:
+    loss_type: hinge
+    weight: 1.0
+    real_label_val: 1
+    fake_label_val: 0.0
+  perceptual:
+    layer_weights:
+      "1": 0.03125
+      "6": 0.0625
+      "11": 0.125
+      "20": 0.25
+      "29": 1
+    criterion: 'L1'
+    style_loss: False
+    perceptual_loss: True
+    weight: 2
+  mgc:
+    weight: 5
+  fm:
+    weight: 5
+  edge:
+    weight: 0
+    hed_pretrained_model_path: ./network-bsds500.pytorch
+
+optimizers:
+  generator:
+    _type: Adam
+    lr: 1e-4
+    betas: [ 0, 0.9 ]
+    weight_decay: 0.0001
+  discriminator:
+    _type: Adam
+    lr: 4e-4
+    betas: [ 0, 0.9 ]
+    weight_decay: 0.0001
+
+data:
+  train:
+    scheduler:
+      start_proportion: 0.5
+      target_lr: 0
+    buffer_size: 50
+    dataloader:
+      batch_size: 1
+      shuffle: True
+      num_workers: 2
+      pin_memory: True
+      drop_last: True
+    dataset:
+      _type: GenerationUnpairedDataset
+      root_a: "/data/face2cartoon/all_face"
+      root_b: "/data/selfie2anime/trainB/"
+      random_pair: True
+      pipeline_a:
+        - Load
+        - RandomCrop:
+            size: [ 178, 178 ]
+        - Resize:
+            size: [ 256, 256 ]
+        - RandomHorizontalFlip
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+      pipeline_b:
+        - Load
+        - Resize:
+            size: [ 286, 286 ]
+        - RandomCrop:
+            size: [ 256, 256 ]
+        - RandomHorizontalFlip
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+  test:
+    which: video_dataset
+    dataloader:
+      batch_size: 1
+      shuffle: False
+      num_workers: 1
+      pin_memory: False
+      drop_last: False
+    dataset:
+      _type: GenerationUnpairedDataset
+      root_a: "/data/face2cartoon/test/human"
+      root_b: "/data/face2cartoon/test/anime"
+      random_pair: True
+      pipeline_a:
+        - Load
+        - Resize:
+            size: [ 256, 256 ]
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+      pipeline_b:
+        - Load
+        - Resize:
+            size: [ 256, 256 ]
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+    video_dataset:
+      _type: SingleFolderDataset
+      root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
+      with_path: True
+      pipeline:
+        - Load
+        - Resize:
+            size: [ 256, 256 ]
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
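The loss helpers below call OmegaConf.to_container on sub-trees of this config, so the file is meant to be consumed as an OmegaConf tree. A minimal sketch of loading and querying it with the standard OmegaConf.load API (the keys come from the file above; bracket access is used for the underscore-prefixed keys):

    from omegaconf import OmegaConf

    config = OmegaConf.load("configs/synthesizers/GauGAN.yml")
    print(config.loss.gan.weight)           # 1.0
    print(config.model.generator["_type"])  # SPADEGenerator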
diff --git a/engine/GauGAN.py b/engine/GauGAN.py
new file mode 100644
index 0000000..2c8b762
--- /dev/null
+++ b/engine/GauGAN.py
@@ -0,0 +1,86 @@
+from itertools import chain
+
+import torch
+
+from engine.base.i2i import EngineKernel, run_kernel
+from engine.util.build import build_model
+from engine.util.container import GANImageBuffer, LossContainer
+from engine.util.loss import gan_loss, feature_match_loss, perceptual_loss
+from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
+from model.weight_init import generation_init_weights
+
+
+class GauGANEngineKernel(EngineKernel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.gan_loss = gan_loss(config.loss.gan)
+        self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
+        self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "exponential_decline"))
+        self.perceptual_loss = LossContainer(config.loss.perceptual.weight, perceptual_loss(config.loss.perceptual))
+
+        self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50)
+                              for k in self.discriminators.keys()}
+
+    def build_models(self) -> (dict, dict):
+        generators = dict(
+            main=build_model(self.config.model.generator)
+        )
+        discriminators = dict(
+            b=build_model(self.config.model.discriminator)
+        )
+        self.logger.debug(discriminators["b"])
+        self.logger.debug(generators["main"])
+
+        for m in chain(generators.values(), discriminators.values()):
+            generation_init_weights(m)
+
+        return generators, discriminators
+
+    def setup_after_g(self):
+        for discriminator in self.discriminators.values():
+            discriminator.requires_grad_(True)
+
+    def setup_before_g(self):
+        for discriminator in self.discriminators.values():
+            discriminator.requires_grad_(False)
+
+    def forward(self, batch, inference=False) -> dict:
+        images = dict()
+        with torch.set_grad_enabled(not inference):
+            images["a2b"] = self.generators["main"](batch["a"])
+        return images
+
+    def criterion_generators(self, batch, generated) -> dict:
+        loss = dict()
+        prediction_fake = self.discriminators["b"](generated["a2b"])
+        loss["gan"] = self.config.loss.gan.weight * self.gan_loss(prediction_fake, True)
+        loss["mgc"] = self.mgc_loss(generated["a2b"], batch["a"])
+        loss["perceptual"] = self.perceptual_loss(generated["a2b"], batch["a"])
+        if self.fm_loss.weight > 0:
+            prediction_real = self.discriminators["b"](batch["b"])
+            loss["feature_match"] = self.fm_loss(prediction_fake, prediction_real)
+        return loss
+
+    def criterion_discriminators(self, batch, generated) -> dict:
+        loss = dict()
+        generated_image = self.image_buffers["b"].query(generated["a2b"].detach())
+        loss["b"] = (self.gan_loss(self.discriminators["b"](generated_image), False, is_discriminator=True)
+                     + self.gan_loss(self.discriminators["b"](batch["b"]), True, is_discriminator=True)) / 2
+        return loss
+
+    def intermediate_images(self, batch, generated) -> dict:
+        """
+        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
+        :param batch:
+        :param generated: dict of images
+        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
+        """
+        return dict(
+            a=[batch["a"].detach(), generated["a2b"].detach()],
+        )
+
+
+def run(task, config, _):
+    kernel = GauGANEngineKernel(config)
+    run_kernel(task, config, kernel)
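setup_before_g and setup_after_g bracket the generator update: discriminator parameters are frozen while the generator's GAN loss backpropagates through the discriminator, then unfrozen for the discriminator's own step. A minimal sketch of that pattern with stand-in nn.Linear modules (not the engine's actual training loop):

    import torch
    import torch.nn as nn

    g, d = nn.Linear(8, 8), nn.Linear(8, 1)  # stand-ins for generator / discriminator

    d.requires_grad_(False)                  # setup_before_g: freeze D
    loss_g = d(g(torch.randn(2, 8))).mean()  # generator loss through the frozen D
    loss_g.backward()                        # gradients reach g only
    d.requires_grad_(True)                   # setup_after_g: unfreeze D

    assert g.weight.grad is not None and d.weight.grad is None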
diff --git a/engine/util/loss.py b/engine/util/loss.py
index 5559ce4..1b161ad 100644
--- a/engine/util/loss.py
+++ b/engine/util/loss.py
@@ -4,29 +4,20 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from omegaconf import OmegaConf
 
+from loss.I2I.perceptual_loss import PerceptualLoss
 from loss.gan import GANLoss
 
 
 def gan_loss(config):
     gan_loss_cfg = OmegaConf.to_container(config)
     gan_loss_cfg.pop("weight")
-    gl = GANLoss(**gan_loss_cfg).to(idist.device())
-
-    def gan_loss_fn(prediction, target_is_real: bool, is_discriminator=False):
-        if isinstance(prediction, torch.Tensor):
-            # origin
-            return gl(prediction, target_is_real, is_discriminator)
-        elif isinstance(prediction, list) and isinstance(prediction[0], list):
-            # for multi scale discriminator, e.g. MultiScaleDiscriminator
-            loss = 0
-            for p in prediction:
-                loss += gl(p[-1], target_is_real, is_discriminator)
-            return loss
-        elif isinstance(prediction, list) and isinstance(prediction[0], torch.Tensor):
-            # for discriminator set `need_intermediate_feature` true
-            return gl(prediction[-1], target_is_real, is_discriminator)
-        else:
-            raise NotImplementedError("not support discriminator output")
-
-    return gan_loss_fn
+    return GANLoss(**gan_loss_cfg).to(idist.device())
+
+
+def perceptual_loss(config):
+    perceptual_loss_cfg = OmegaConf.to_container(config)
+    perceptual_loss_cfg.pop("weight")
+    return PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
 
 
 def pixel_loss(level):
diff --git a/loss/gan.py b/loss/gan.py
index 5e30bc4..438998d 100644
--- a/loss/gan.py
+++ b/loss/gan.py
@@ -1,4 +1,5 @@
 import torch.nn as nn
+import torch
 import torch.nn.functional as F
 
 
@@ -10,7 +11,7 @@ class GANLoss(nn.Module):
         self.fake_label_val = fake_label_val
         self.loss_type = loss_type
 
-    def forward(self, prediction, target_is_real: bool, is_discriminator=False):
+    def single_forward(self, prediction, target_is_real: bool, is_discriminator=False):
         """
         gan loss forward
         :param prediction: network prediction
@@ -37,3 +38,20 @@ class GANLoss(nn.Module):
             return loss
         else:
             raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')
+
+    def forward(self, prediction, target_is_real: bool, is_discriminator=False):
+        if isinstance(prediction, torch.Tensor):
+            # a plain discriminator output
+            return self.single_forward(prediction, target_is_real, is_discriminator)
+        elif isinstance(prediction, list):
+            # a multi-scale discriminator, e.g. MultiScaleDiscriminator
+            loss = 0
+            for p in prediction:
+                loss += self.single_forward(p[-1], target_is_real, is_discriminator)
+            return loss
+        elif isinstance(prediction, tuple):
+            # a single discriminator with `need_intermediate_feature` set to True
+            return self.single_forward(prediction[-1], target_is_real, is_discriminator)
+        else:
+            raise NotImplementedError(f"unsupported discriminator output: {prediction}")
+
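With the dispatch moved into GANLoss.forward, callers now pass the raw discriminator output and the loss resolves its shape: a plain tensor, a list of per-scale outputs whose last element is each scale's prediction map, or a tuple of intermediate features ending in the prediction map. A sketch exercising all three cases, assuming the constructor keywords that gan_loss() forwards from the config above; the tensor shapes are illustrative only:

    import torch
    from loss.gan import GANLoss

    criterion = GANLoss(loss_type="hinge", real_label_val=1, fake_label_val=0.0)

    patch = torch.randn(1, 1, 30, 30)            # plain PatchGAN map
    feats = (torch.randn(1, 64, 64, 64), patch)  # need_intermediate_feature: True
    scales = [feats, (torch.randn(1, 64, 32, 32), torch.randn(1, 1, 15, 15))]

    loss_g = criterion(patch, True)              # generator wants "real"
    loss_d = criterion(scales, False, is_discriminator=True)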
diff --git a/model/__init__.py b/model/__init__.py
index 79ffd63..c825e4a 100644
--- a/model/__init__.py
+++ b/model/__init__.py
@@ -3,3 +3,4 @@ import model.base.normalization
 import model.image_translation.UGATIT
 import model.image_translation.CycleGAN
 import model.image_translation.pix2pixHD
+import model.image_translation.GauGAN
\ No newline at end of file
diff --git a/model/image_translation/GauGAN.py b/model/image_translation/GauGAN.py
index 3b1b7e5..67e152a 100644
--- a/model/image_translation/GauGAN.py
+++ b/model/image_translation/GauGAN.py
@@ -7,7 +7,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
 from model.base.module import ResidualBlock, Conv2dBlock, LinearBlock
-
+from model import MODEL
 
 class StyleEncoder(nn.Module):
     def __init__(self, in_channels, style_dim, num_conv, end_size=(4, 4), base_channels=64,
@@ -122,7 +122,7 @@ class ImprovedSPADEGenerator(nn.Module):
     def forward(self, seg, style=None):
         pass
 
-
+@MODEL.register_module()
 class SPADEGenerator(nn.Module):
     def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4),
                  base_channels=64, padding_mode='reflect', activation_type="LeakyReLU"):
@@ -156,11 +156,8 @@ class SPADEGenerator(nn.Module):
             )
         ))
         self.sequence = nn.Sequential(*sequence)
-        self.output_converter = nn.Sequential(
-            ReverseConv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
-                               padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"),
-            nn.Tanh()
-        )
+        self.output_converter = Conv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
+                                            padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE")
 
     def forward(self, seg, z=None):
         if self.use_vae:
diff --git a/model/weight_init.py b/model/weight_init.py
index 6a64a4c..8eafece 100644
--- a/model/weight_init.py
+++ b/model/weight_init.py
@@ -65,7 +65,8 @@ def generation_init_weights(module, init_type='normal', init_gain=0.02):
         elif classname.find('BatchNorm2d') != -1:
             # BatchNorm Layer's weight is not a matrix;
             # only normal distribution applies.
-            normal_init(m, 1.0, init_gain)
+            if m.weight is not None:
+                normal_init(m, 1.0, init_gain)
 
     assert isinstance(module, nn.Module)
     module.apply(init_func)
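The weight_init.py guard matters because norm layers constructed with affine=False have weight set to None, which normal_init would otherwise try to initialize. A small sketch of the case the check covers (hypothetical two-layer net; generation_init_weights is the function patched above):

    import torch.nn as nn
    from model.weight_init import generation_init_weights

    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8, affine=False))
    assert net[1].weight is None  # parameter-free BatchNorm
    generation_init_weights(net)  # skips the parameter-free norm layer instead of crashing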