From 2ff4a91057a835d4d7026c4e1262c3949dcd14e9 Mon Sep 17 00:00:00 2001 From: budui Date: Mon, 14 Sep 2020 22:30:05 +0800 Subject: [PATCH] add MUNIT --- .idea/deployment.xml | 4 +- configs/synthesizers/MUNIT.yml | 132 ++++++++++++++++++++++++++++ engine/MUNIT.py | 154 +++++++++++++++++++++++++++++++++ environment.yml | 1 - model/GAN/MUNIT.py | 154 +++++++++++++++++++++++++++++++++ model/GAN/base.py | 68 ++++++++++++++- model/__init__.py | 3 +- 7 files changed, 510 insertions(+), 6 deletions(-) create mode 100644 configs/synthesizers/MUNIT.yml create mode 100644 engine/MUNIT.py create mode 100644 model/GAN/MUNIT.py diff --git a/.idea/deployment.xml b/.idea/deployment.xml index fd2b0b6..8ccfb5e 100644 --- a/.idea/deployment.xml +++ b/.idea/deployment.xml @@ -1,11 +1,11 @@ - + - + diff --git a/configs/synthesizers/MUNIT.yml b/configs/synthesizers/MUNIT.yml new file mode 100644 index 0000000..3de00d2 --- /dev/null +++ b/configs/synthesizers/MUNIT.yml @@ -0,0 +1,132 @@ +name: MUNIT-edges2shoes +engine: MUNIT +result_dir: ./result +max_pairs: 1000000 + +handler: + clear_cuda_cache: True + set_epoch_for_dist_sampler: True + checkpoint: + epoch_interval: 1 # checkpoint once per `epoch_interval` epoch + n_saved: 2 + tensorboard: + scalar: 100 # log scalar `scalar` times per epoch + image: 2 # log image `image` times per epoch + + +misc: + random_seed: 324 + +model: + generator: + _type: MUNIT-Generator + in_channels: 3 + out_channels: 3 + base_channels: 64 + num_sampling: 2 + num_style_dim: 8 + num_style_conv: 4 + num_content_res_blocks: 4 + num_decoder_res_blocks: 4 + num_fusion_dim: 256 + num_fusion_blocks: 3 + + discriminator: + _type: MultiScaleDiscriminator + num_scale: 2 + discriminator_cfg: + _type: PatchDiscriminator + in_channels: 3 + base_channels: 64 + use_spectral: True + need_intermediate_feature: True + +loss: + gan: + loss_type: lsgan + real_label_val: 1.0 + fake_label_val: 0.0 + weight: 1.0 + perceptual: + layer_weights: + "1": 0.03125 + "6": 0.0625 + "11": 0.125 + "20": 0.25 + "29": 1 + criterion: 'L1' + style_loss: False + perceptual_loss: True + weight: 0 + recon: + level: 1 + style: + weight: 1 + content: + weight: 1 + image: + weight: 10 + cycle: + weight: 0 + +optimizers: + generator: + _type: Adam + lr: 0.0001 + betas: [ 0.5, 0.999 ] + weight_decay: 0.0001 + discriminator: + _type: Adam + lr: 4e-4 + betas: [ 0.5, 0.999 ] + weight_decay: 0.0001 + +data: + train: + scheduler: + start_proportion: 0.5 + target_lr: 0 + buffer_size: 50 + dataloader: + batch_size: 1 + shuffle: True + num_workers: 1 + pin_memory: True + drop_last: True + dataset: + _type: GenerationUnpairedDataset + root_a: "/data/i2i/edges2shoes/trainA" + root_b: "/data/i2i/edges2shoes/trainB" + random_pair: True + pipeline: + - Load + - Resize: + size: [ 286, 286 ] + - RandomCrop: + size: [ 256, 256 ] + - RandomHorizontalFlip + - ToTensor + - Normalize: + mean: [ 0.5, 0.5, 0.5 ] + std: [ 0.5, 0.5, 0.5 ] + test: + which: dataset + dataloader: + batch_size: 8 + shuffle: False + num_workers: 1 + pin_memory: False + drop_last: False + dataset: + _type: GenerationUnpairedDataset + root_a: "/data/i2i/edges2shoes/testA" + root_b: "/data/i2i/edges2shoes/testB" + random_pair: False + pipeline: + - Load + - Resize: + size: [ 256, 256 ] + - ToTensor + - Normalize: + mean: [ 0.5, 0.5, 0.5 ] + std: [ 0.5, 0.5, 0.5 ] diff --git a/engine/MUNIT.py b/engine/MUNIT.py new file mode 100644 index 0000000..a0ae713 --- /dev/null +++ b/engine/MUNIT.py @@ -0,0 +1,154 @@ +import ignite.distributed as idist +import torch +import torch.nn as nn +import torch.nn.functional as F +from omegaconf import OmegaConf + +from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel +from engine.util.build import build_model +from loss.I2I.perceptual_loss import PerceptualLoss +from loss.gan import GANLoss + + +def mse_loss(x, target_flag): + return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x)) + + +def bce_loss(x, target_flag): + return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x)) + + +class MUNITEngineKernel(EngineKernel): + def __init__(self, config): + super().__init__(config) + + perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual) + perceptual_loss_cfg.pop("weight") + self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device()) + + gan_loss_cfg = OmegaConf.to_container(config.loss.gan) + gan_loss_cfg.pop("weight") + self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) + self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss() + self.train_generator_first = False + + def build_models(self) -> (dict, dict): + generators = dict( + a=build_model(self.config.model.generator), + b=build_model(self.config.model.generator) + ) + discriminators = dict( + a=build_model(self.config.model.discriminator), + b=build_model(self.config.model.discriminator) + ) + self.logger.debug(discriminators["a"]) + self.logger.debug(generators["a"]) + + return generators, discriminators + + def setup_after_g(self): + for discriminator in self.discriminators.values(): + discriminator.requires_grad_(True) + + def setup_before_g(self): + for discriminator in self.discriminators.values(): + discriminator.requires_grad_(False) + + def forward(self, batch, inference=False) -> dict: + styles = dict() + contents = dict() + images = dict() + + for phase in "ab": + contents[phase], styles[phase] = self.generators[phase].encode(batch[phase]) + images["{0}2{0}".format(phase)] = self.generators[phase].decode(contents[phase], styles[phase]) + styles[f"random_{phase}"] = torch.randn_like(styles[phase]).to(idist.device()) + + for phase in ("a2b", "b2a"): + # images["a2b"] = Gb.decode(content_a, random_style_b) + images[phase] = self.generators[phase[-1]].decode(contents[phase[0]], styles[f"random_{phase[-1]}"]) + # contents["a2b"], styles["a2b"] = Gb.encode(images["a2b"]) + contents[phase], styles[phase] = self.generators[phase[-1]].encode(images[phase]) + if self.config.loss.recon.cycle.weight > 0: + images[f"{phase}2{phase[0]}"] = self.generators[phase[0]].decode(contents[phase], styles[phase[0]]) + return dict(styles=styles, contents=contents, images=images) + + def criterion_generators(self, batch, generated) -> dict: + loss = dict() + for phase in "ab": + loss[f"recon_image_{phase}"] = self.config.loss.recon.image.weight * self.recon_loss( + batch[phase], generated["images"]["{0}2{0}".format(phase)]) + loss[f"recon_content_{phase}"] = self.config.loss.recon.content.weight * self.recon_loss( + generated["contents"][phase], generated["contents"]["a2b" if phase == "a" else "b2a"]) + loss[f"recon_style_{phase}"] = self.config.loss.recon.style.weight * self.recon_loss( + generated["styles"][f"random_{phase}"], generated["styles"]["b2a" if phase == "a" else "a2b"]) + pred_fake = self.discriminators[phase](generated["images"]["b2a" if phase == "a" else "a2b"]) + loss[f"gan_{phase}"] = 0 + for sub_pred_fake in pred_fake: + # last output is actual prediction + loss[f"gan_{phase}"] += self.gan_loss(sub_pred_fake[-1], True) + if self.config.loss.recon.cycle.weight > 0: + loss[f"recon_cycle_{phase}"] = self.config.loss.recon.cycle.weight * self.recon_loss( + batch[phase], generated["images"]["a2b2a" if phase == "a" else "b2a2b"]) + if self.config.loss.perceptual.weight > 0: + loss[f"perceptual_{phase}"] = self.config.loss.perceptual.weight * self.perceptual_loss( + batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"]) + return loss + + def criterion_discriminators(self, batch, generated) -> dict: + loss = dict() + for phase in ("a2b", "b2a"): + pred_real = self.discriminators[phase[-1]](batch[phase[-1]]) + pred_fake = self.discriminators[phase[-1]](generated["images"][phase].detach()) + loss[f"gan_{phase[-1]}"] = 0 + for i in range(len(pred_fake)): + loss[f"gan_{phase[-1]}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True) + + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2 + return loss + + def intermediate_images(self, batch, generated) -> dict: + """ + returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]} + :param batch: + :param generated: dict of images + :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]} + """ + generated = {img: generated["images"][img].detach() for img in generated["images"]} + images = dict() + for phase in "ab": + images[phase] = [batch[phase].detach(), generated["{0}2{0}".format(phase)], + generated["a2b" if phase == "a" else "b2a"]] + if self.config.loss.recon.cycle.weight > 0: + images[phase].append(generated["a2b2a" if phase == "a" else "b2a2b"]) + return images + + +class MUNITTestEngineKernel(TestEngineKernel): + def __init__(self, config): + super().__init__(config) + + def build_generators(self) -> dict: + generators = dict( + a=build_model(self.config.model.generator), + b=build_model(self.config.model.generator) + ) + return generators + + def to_load(self): + return {f"generator_{k}": self.generators[k] for k in self.generators} + + def inference(self, batch): + with torch.no_grad(): + fake, _, _ = self.generators["a2b"](batch[0]) + return fake.detach() + + +def run(task, config, _): + if task == "train": + kernel = MUNITEngineKernel(config) + run_kernel(task, config, kernel) + elif task == "test": + kernel = MUNITTestEngineKernel(config) + run_kernel(task, config, kernel) + else: + raise NotImplemented diff --git a/environment.yml b/environment.yml index a623f79..155be02 100644 --- a/environment.yml +++ b/environment.yml @@ -6,7 +6,6 @@ channels: dependencies: - python=3.8 - numpy - - ipython - tqdm - pyyaml - pytorch=1.6.* diff --git a/model/GAN/MUNIT.py b/model/GAN/MUNIT.py new file mode 100644 index 0000000..0c2e221 --- /dev/null +++ b/model/GAN/MUNIT.py @@ -0,0 +1,154 @@ +import torch +import torch.nn as nn + +from model import MODEL +from model.GAN.base import Conv2dBlock, ResBlock +from model.normalization import select_norm_layer + + +class StyleEncoder(nn.Module): + def __init__(self, in_channels, out_dim, num_conv, base_channels=64, use_spectral_norm=False, + padding_mode='reflect', activation_type="ReLU", norm_type="NONE"): + super(StyleEncoder, self).__init__() + + sequence = [Conv2dBlock( + in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode, + use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type + )] + + multiple_now = 1 + for i in range(1, num_conv + 1): + multiple_prev = multiple_now + multiple_now = min(2 ** i, 2 ** 2) + sequence.append(Conv2dBlock( + multiple_prev * base_channels, multiple_now * base_channels, + kernel_size=4, stride=2, padding=1, padding_mode=padding_mode, + use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type + )) + sequence.append(nn.AdaptiveAvgPool2d(1)) + # conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code + sequence.append(nn.Conv2d(multiple_now * base_channels, out_dim, kernel_size=1, stride=1, padding=0)) + self.model = nn.Sequential(*sequence) + + def forward(self, x): + return self.model(x).view(x.size(0), -1) + + +class ContentEncoder(nn.Module): + def __init__(self, in_channels, num_down_sampling, num_res_blocks, base_channels=64, use_spectral_norm=False, + padding_mode='reflect', activation_type="ReLU", norm_type="IN"): + super().__init__() + sequence = [Conv2dBlock( + in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode, + use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type + )] + + for i in range(num_down_sampling): + sequence.append(Conv2dBlock( + base_channels * (2 ** i), base_channels * (2 ** (i + 1)), + kernel_size=4, stride=2, padding=1, padding_mode=padding_mode, + use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type + )) + + for _ in range(num_res_blocks): + sequence.append( + ResBlock(base_channels * (2 ** num_down_sampling), use_spectral_norm, padding_mode, norm_type, + activation_type) + ) + + self.sequence = nn.Sequential(*sequence) + + def forward(self, x): + return self.sequence(x) + + +class Decoder(nn.Module): + def __init__(self, in_channels, out_channels, num_up_sampling, num_res_blocks, + use_spectral_norm=False, res_norm_type="AdaIN", norm_type="LN", activation_type="ReLU", + padding_mode='reflect'): + super(Decoder, self).__init__() + self.res_norm_type = res_norm_type + self.res_blocks = nn.ModuleList([ + ResBlock(in_channels, use_spectral_norm, padding_mode, res_norm_type, activation_type=activation_type) + for _ in range(num_res_blocks) + ]) + sequence = list() + channels = in_channels + for i in range(num_up_sampling): + sequence.append(nn.Sequential( + nn.Upsample(scale_factor=2), + Conv2dBlock(channels, channels // 2, + kernel_size=5, stride=1, padding=2, padding_mode=padding_mode, + use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type + ), + )) + channels = channels // 2 + sequence.append( + Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect", + use_spectral_norm=use_spectral_norm, activation_type="Tanh", norm_type="NONE")) + self.sequence = nn.Sequential(*sequence) + + def forward(self, x): + for blk in self.res_blocks: + x = blk(x) + return self.sequence(x) + + +class Fusion(nn.Module): + def __init__(self, in_features, out_features, base_features, n_blocks, norm_type="NONE"): + super().__init__() + norm_layer = select_norm_layer(norm_type) + self.start_fc = nn.Sequential( + nn.Linear(in_features, base_features), + norm_layer(base_features), + nn.ReLU(True), + ) + self.fcs = nn.Sequential(*[ + nn.Sequential( + nn.Linear(base_features, base_features), + norm_layer(base_features), + nn.ReLU(True), + ) for _ in range(n_blocks - 2) + ]) + self.end_fc = nn.Sequential( + nn.Linear(base_features, out_features), + ) + + def forward(self, x): + x = self.start_fc(x) + x = self.fcs(x) + return self.end_fc(x) + + +@MODEL.register_module("MUNIT-Generator") +class Generator(nn.Module): + def __init__(self, in_channels, out_channels, base_channels, num_sampling, num_style_dim, num_style_conv, + num_content_res_blocks, num_decoder_res_blocks, num_fusion_dim, num_fusion_blocks, + use_spectral_norm=False, activation_type="ReLU", padding_mode='reflect'): + super().__init__() + self.num_decoder_res_blocks = num_decoder_res_blocks + self.content_encoder = ContentEncoder(in_channels, num_sampling, num_content_res_blocks, base_channels, + use_spectral_norm, padding_mode, activation_type, norm_type="IN") + self.style_encoder = StyleEncoder(in_channels, num_style_dim, num_style_conv, base_channels, use_spectral_norm, + padding_mode, activation_type, norm_type="NONE") + content_channels = base_channels * (2 ** 2) + self.decoder = Decoder(content_channels, out_channels, num_sampling, + num_decoder_res_blocks, use_spectral_norm, "AdaIN", norm_type="LN", + activation_type=activation_type, padding_mode=padding_mode) + self.fusion = Fusion(num_style_dim, num_decoder_res_blocks * 2 * content_channels * 2, + base_features=num_fusion_dim, n_blocks=num_fusion_blocks, norm_type="NONE") + + def encode(self, x): + return self.content_encoder(x), self.style_encoder(x) + + def decode(self, content, style): + as_param_style = torch.chunk(self.fusion(style), self.num_decoder_res_blocks * 2, dim=1) + # set style for decoder + for i, blk in enumerate(self.decoder.res_blocks): + blk.conv1.normalization.set_style(as_param_style[2 * i]) + blk.conv2.normalization.set_style(as_param_style[2 * i + 1]) + return self.decoder(content) + + def forward(self, x): + content, style = self.encode(x) + return self.decode(content, style) diff --git a/model/GAN/base.py b/model/GAN/base.py index 52a351a..fb73169 100644 --- a/model/GAN/base.py +++ b/model/GAN/base.py @@ -1,10 +1,11 @@ -import math +from functools import partial +import math import torch import torch.nn as nn -from model.normalization import select_norm_layer from model import MODEL +from model.normalization import select_norm_layer class GANImageBuffer(object): @@ -137,3 +138,66 @@ class ResidualBlock(nn.Module): x = self.relu1(self.norm1(self.conv1(x))) x = self.norm2(self.conv2(x)) return x + res + + +_DO_NO_THING_FUNC = lambda x: x + + +def select_activation(t): + if t == "ReLU": + return partial(nn.ReLU, inplace=True) + elif t == "LeakyReLU": + return partial(nn.LeakyReLU, negative_slope=0.2, inplace=True) + elif t == "Tanh": + return partial(nn.Tanh) + elif t == "NONE": + return _DO_NO_THING_FUNC + else: + raise NotImplemented + + +def _use_bias_checker(norm_type): + return norm_type not in ["IN", "BN", "AdaIN"] + + +class Conv2dBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int, use_spectral_norm=False, activation_type="ReLU", + bias=None, norm_type="NONE", **conv_kwargs): + super().__init__() + self.norm_type = norm_type + self.activation_type = activation_type + conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias + conv = nn.Conv2d(in_channels, out_channels, **conv_kwargs) + self.convolution = nn.utils.spectral_norm(conv) if use_spectral_norm else conv + if norm_type != "NONE": + self.normalization = select_norm_layer(norm_type)(out_channels) + if activation_type != "NONE": + self.activation = select_activation(activation_type)() + + def forward(self, x): + x = self.convolution(x) + if self.norm_type != "NONE": + x = self.normalization(x) + if self.activation_type != "NONE": + x = self.activation(x) + return x + + +class ResBlock(nn.Module): + def __init__(self, num_channels, use_spectral_norm=False, padding_mode='reflect', + norm_type="IN", activation_type="relu", use_bias=None): + super().__init__() + self.norm_type = norm_type + if use_bias is None: + # bias will be canceled after channel wise normalization + use_bias = _use_bias_checker(norm_type) + + self.conv1 = Conv2dBlock(num_channels, num_channels, use_spectral_norm, + kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias, + norm_type=norm_type, activation_type=activation_type) + self.conv2 = Conv2dBlock(num_channels, num_channels, use_spectral_norm, + kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias, + norm_type=norm_type, activation_type="NONE") + + def forward(self, x): + return self.conv2(self.conv1(x)) + x diff --git a/model/__init__.py b/model/__init__.py index 53029b2..c4c6cf5 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -4,4 +4,5 @@ import model.GAN.TAFG import model.GAN.UGATIT import model.GAN.wrapper import model.GAN.base -import model.GAN.TSIT \ No newline at end of file +import model.GAN.TSIT +import model.GAN.MUNIT \ No newline at end of file