From ab545843bf5a5b019d1cd4416a4869667793052d Mon Sep 17 00:00:00 2001 From: Ray Wong Date: Sun, 6 Sep 2020 10:34:52 +0800 Subject: [PATCH] almost 0.1 --- configs/few-shot/crossdomain.yml | 51 ---- .../{cyclegan.yml => CyCleGAN.yml} | 58 ++-- configs/synthesizers/TAFG.yml | 4 +- engine/CyCleGAN.py | 101 +++++++ engine/TAFG.py | 13 +- engine/U-GAT-IT.py | 14 +- engine/base/i2i.py | 67 +++-- engine/cyclegan.py | 268 ------------------ environment.yml | 2 +- model/GAN/CycleGAN.py | 62 ++++ model/GAN/UGATIT.py | 1 - model/GAN/base.py | 57 +++- model/GAN/residual_generator.py | 182 ------------ model/__init__.py | 3 +- model/fewshot.py | 105 ------- 15 files changed, 308 insertions(+), 680 deletions(-) delete mode 100644 configs/few-shot/crossdomain.yml rename configs/synthesizers/{cyclegan.yml => CyCleGAN.yml} (63%) create mode 100644 engine/CyCleGAN.py delete mode 100644 engine/cyclegan.py create mode 100644 model/GAN/CycleGAN.py delete mode 100644 model/GAN/residual_generator.py delete mode 100644 model/fewshot.py diff --git a/configs/few-shot/crossdomain.yml b/configs/few-shot/crossdomain.yml deleted file mode 100644 index bafd121..0000000 --- a/configs/few-shot/crossdomain.yml +++ /dev/null @@ -1,51 +0,0 @@ -name: cross-domain-1 -engine: crossdomain -result_dir: ./result - -distributed: - model: - # broadcast_buffers: False - -misc: - random_seed: 1004 - -checkpoints: - interval: 2000 - -log: - logger: - level: 20 # DEBUG(10) INFO(20) - -model: - _type: resnet10 - -baseline: - plusplus: False - optimizers: - _type: Adam - data: - dataloader: - batch_size: 1200 - shuffle: True - num_workers: 16 - pin_memory: True - drop_last: True - dataset: - train: - path: /data/few-shot/mini_imagenet_full_size/train - lmdb_path: /data/few-shot/lmdb/mini-ImageNet/train.lmdb - pipeline: - - Load - - RandomResizedCrop: - size: [224, 224] - - ColorJitter: - brightness: 0.4 - contrast: 0.4 - saturation: 0.4 - - RandomHorizontalFlip - - ToTensor - - Normalize: - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] - - diff --git a/configs/synthesizers/cyclegan.yml b/configs/synthesizers/CyCleGAN.yml similarity index 63% rename from configs/synthesizers/cyclegan.yml rename to configs/synthesizers/CyCleGAN.yml index 1798ebe..be3c4b8 100644 --- a/configs/synthesizers/cyclegan.yml +++ b/configs/synthesizers/CyCleGAN.yml @@ -1,40 +1,34 @@ -name: horse2zebra -engine: cyclegan +name: horse2zebra-CyCleGAN +engine: CyCleGAN result_dir: ./result -max_iteration: 16600 - -distributed: - model: - # broadcast_buffers: False +max_pairs: 266800 misc: random_seed: 324 -checkpoints: - interval: 2000 - -log: - logger: - level: 20 # DEBUG(10) INFO(20) +handler: + clear_cuda_cache: False + set_epoch_for_dist_sampler: True + checkpoint: + epoch_interval: 1 # checkpoint once per `epoch_interval` epoch + n_saved: 2 + tensorboard: + scalar: 100 # log scalar `scalar` times per epoch + image: 2 # log image `image` times per epoch model: generator: - _type: ResGenerator + _type: CyCle-Generator in_channels: 3 out_channels: 3 base_channels: 64 num_blocks: 9 padding_mode: reflect norm_type: IN - use_dropout: False discriminator: _type: PatchDiscriminator -# _distributed: -# bn_to_syncbn: False in_channels: 3 base_channels: 64 - num_conv: 3 - norm_type: IN loss: gan: @@ -53,19 +47,22 @@ optimizers: generator: _type: Adam lr: 2e-4 - betas: [0.5, 0.999] + betas: [ 0.5, 0.999 ] discriminator: _type: Adam lr: 2e-4 - betas: [0.5, 0.999] + betas: [ 0.5, 0.999 ] data: train: + scheduler: + start_proportion: 0.5 + target_lr: 0 buffer_size: 50 dataloader: - batch_size: 16 + batch_size: 6 shuffle: True - num_workers: 4 + num_workers: 2 pin_memory: True drop_last: True dataset: @@ -76,14 +73,14 @@ data: pipeline: - Load - Resize: - size: [286, 286] + size: [ 286, 286 ] - RandomCrop: - size: [256, 256] + size: [ 256, 256 ] - RandomHorizontalFlip - ToTensor - scheduler: - start: 8300 - target_lr: 0 + - Normalize: + mean: [ 0.5, 0.5, 0.5 ] + std: [ 0.5, 0.5, 0.5 ] test: dataloader: batch_size: 4 @@ -99,5 +96,8 @@ data: pipeline: - Load - Resize: - size: [256, 256] + size: [ 256, 256 ] - ToTensor + - Normalize: + mean: [ 0.5, 0.5, 0.5 ] + std: [ 0.5, 0.5, 0.5 ] diff --git a/configs/synthesizers/TAFG.yml b/configs/synthesizers/TAFG.yml index 0436434..df2aff7 100644 --- a/configs/synthesizers/TAFG.yml +++ b/configs/synthesizers/TAFG.yml @@ -1,7 +1,7 @@ name: TAFG engine: TAFG result_dir: ./result -max_pairs: 1000000 +max_pairs: 1500000 handler: clear_cuda_cache: True @@ -28,7 +28,7 @@ model: _type: MultiScaleDiscriminator num_scale: 2 discriminator_cfg: - _type: pix2pixHD-PatchDiscriminator + _type: PatchDiscriminator in_channels: 3 base_channels: 64 use_spectral: True diff --git a/engine/CyCleGAN.py b/engine/CyCleGAN.py new file mode 100644 index 0000000..96d6db4 --- /dev/null +++ b/engine/CyCleGAN.py @@ -0,0 +1,101 @@ +from itertools import chain + +import ignite.distributed as idist +import torch +import torch.nn as nn +from omegaconf import OmegaConf + +from engine.base.i2i import EngineKernel, run_kernel +from engine.util.build import build_model +from loss.gan import GANLoss +from model.GAN.base import GANImageBuffer +from model.weight_init import generation_init_weights + + +class TAFGEngineKernel(EngineKernel): + def __init__(self, config): + super().__init__(config) + + gan_loss_cfg = OmegaConf.to_container(config.loss.gan) + gan_loss_cfg.pop("weight") + self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) + self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss() + self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss() + self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in + self.discriminators.keys()} + + def build_models(self) -> (dict, dict): + generators = dict( + a2b=build_model(self.config.model.generator), + b2a=build_model(self.config.model.generator) + ) + discriminators = dict( + a=build_model(self.config.model.discriminator), + b=build_model(self.config.model.discriminator) + ) + self.logger.debug(discriminators["a"]) + self.logger.debug(generators["a2b"]) + + for m in chain(generators.values(), discriminators.values()): + generation_init_weights(m) + + return generators, discriminators + + def setup_after_g(self): + for discriminator in self.discriminators.values(): + discriminator.requires_grad_(True) + + def setup_before_g(self): + for discriminator in self.discriminators.values(): + discriminator.requires_grad_(False) + + def forward(self, batch, inference=False) -> dict: + images = dict() + with torch.set_grad_enabled(not inference): + images["a2b"] = self.generators["a2b"](batch["a"]) + images["b2a"] = self.generators["b2a"](batch["b"]) + images["a2b2a"] = self.generators["b2a"](images["a2b"]) + images["b2a2b"] = self.generators["a2b"](images["b2a"]) + if self.config.loss.id.weight > 0: + images["a2a"] = self.generators["b2a"](batch["a"]) + images["b2b"] = self.generators["a2b"](batch["b"]) + return images + + def criterion_generators(self, batch, generated) -> dict: + loss = dict() + for phase in ["a2b", "b2a"]: + loss[f"cycle_{phase[0]}"] = self.config.loss.cycle.weight * self.cycle_loss( + generated[f"{phase}2{phase[0]}"], batch[phase[0]]) + loss[f"gan_{phase}"] = self.config.loss.gan.weight * self.gan_loss( + self.discriminators[phase[-1]](generated[phase]), True) + if self.config.loss.id.weight > 0: + loss[f"id_{phase[0]}"] = self.config.loss.id.weight * self.id_loss( + generated[f"{phase[0]}2{phase[0]}"], batch[phase[0]]) + return loss + + def criterion_discriminators(self, batch, generated) -> dict: + loss = dict() + for phase in "ab": + generated_image = self.image_buffers[phase].query(generated["b2a" if phase == "a" else "a2b"].detach()) + loss[f"gan_{phase}"] = (self.gan_loss(self.discriminators[phase](generated_image), False, + is_discriminator=True) + + self.gan_loss(self.discriminators[phase](batch[phase]), True, + is_discriminator=True)) / 2 + return loss + + def intermediate_images(self, batch, generated) -> dict: + """ + returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]} + :param batch: + :param generated: dict of images + :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]} + """ + return dict( + a=[batch["a"].detach(), generated["a2b"].detach(), generated["a2b2a"].detach()], + b=[batch["b"].detach(), generated["b2a"].detach(), generated["b2a2b"].detach()], + ) + + +def run(task, config, _): + kernel = TAFGEngineKernel(config) + run_kernel(task, config, kernel) diff --git a/engine/TAFG.py b/engine/TAFG.py index fec01de..77d3646 100644 --- a/engine/TAFG.py +++ b/engine/TAFG.py @@ -5,6 +5,9 @@ from omegaconf import OmegaConf import torch import torch.nn as nn import ignite.distributed as idist +from ignite.engine import Events + +from omegaconf import read_write, OmegaConf from model.weight_init import generation_init_weights from loss.I2I.perceptual_loss import PerceptualLoss @@ -49,7 +52,7 @@ class TAFGEngineKernel(EngineKernel): return generators, discriminators - def setup_before_d(self): + def setup_after_g(self): for discriminator in self.discriminators.values(): discriminator.requires_grad_(True) @@ -89,7 +92,7 @@ class TAFGEngineKernel(EngineKernel): for j in range(num_intermediate_outputs): loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm - loss["recon"] = self.recon_loss(generated["a"], batch["a"]) * self.config.loss.recon.weight + loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"]) # loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss( # self.generators["main"].module.style_encoders["b"](batch["b"]), # self.generators["main"].module.style_encoders["b"](generated["b"]) @@ -122,6 +125,12 @@ class TAFGEngineKernel(EngineKernel): generated["b"].detach()] ) + def change_engine(self, config, trainer): + @trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3))) + def change_config(engine): + with read_write(config): + config.loss.perceptual.weight = 5 + def run(task, config, _): kernel = TAFGEngineKernel(config) diff --git a/engine/U-GAT-IT.py b/engine/U-GAT-IT.py index 4eb0886..9b7178c 100644 --- a/engine/U-GAT-IT.py +++ b/engine/U-GAT-IT.py @@ -1,5 +1,3 @@ -from itertools import chain - from omegaconf import OmegaConf import torch @@ -7,10 +5,9 @@ import torch.nn as nn import torch.nn.functional as F import ignite.distributed as idist -from model.weight_init import generation_init_weights from loss.gan import GANLoss from model.GAN.UGATIT import RhoClipper -from model.GAN.residual_generator import GANImageBuffer +from model.GAN.base import GANImageBuffer from util.image import attention_colored_map from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel from engine.util.build import build_model @@ -36,6 +33,7 @@ class UGATITEngineKernel(EngineKernel): self.rho_clipper = RhoClipper(0, 1) self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in self.discriminators.keys()} + self.train_generator_first = False def build_models(self) -> (dict, dict): generators = dict( @@ -51,12 +49,9 @@ class UGATITEngineKernel(EngineKernel): self.logger.debug(discriminators["ga"]) self.logger.debug(generators["a2b"]) - for m in chain(generators.values(), discriminators.values()): - generation_init_weights(m) - return generators, discriminators - def setup_before_d(self): + def setup_after_g(self): for generator in self.generators.values(): generator.apply(self.rho_clipper) for discriminator in self.discriminators.values(): @@ -101,8 +96,7 @@ class UGATITEngineKernel(EngineKernel): loss = dict() for phase in "ab": for level in "gl": - generated_image = self.image_buffers[level + phase].query( - generated["images"]["a2b" if phase == "b" else "b2a"]) + generated_image = generated["images"]["b2a" if phase == "a" else "a2b"].detach() pred_fake, cam_fake_pred = self.discriminators[level + phase](generated_image) pred_real, cam_real_pred = self.discriminators[level + phase](batch[phase]) loss[f"gan_{phase}_{level}"] = self.gan_loss(pred_real, True, is_discriminator=True) + self.gan_loss( diff --git a/engine/base/i2i.py b/engine/base/i2i.py index f7177b2..9dbba6e 100644 --- a/engine/base/i2i.py +++ b/engine/base/i2i.py @@ -1,23 +1,21 @@ -from itertools import chain import logging +from itertools import chain from pathlib import Path +import ignite.distributed as idist import torch import torchvision - -import ignite.distributed as idist +from ignite.contrib.handlers.param_scheduler import PiecewiseLinear from ignite.engine import Events, Engine from ignite.metrics import RunningAverage from ignite.utils import convert_tensor -from ignite.contrib.handlers.param_scheduler import PiecewiseLinear - +from math import ceil from omegaconf import read_write, OmegaConf -from util.image import make_2d_grid -from engine.util.handler import setup_common_handlers, setup_tensorboard_handler -from engine.util.build import build_optimizer - import data +from engine.util.build import build_optimizer +from engine.util.handler import setup_common_handlers, setup_tensorboard_handler +from util.image import make_2d_grid def build_lr_schedulers(optimizers, config): @@ -59,6 +57,7 @@ class EngineKernel(object): self.config = config self.logger = logging.getLogger(config.name) self.generators, self.discriminators = self.build_models() + self.train_generator_first = True def build_models(self) -> (dict, dict): raise NotImplemented @@ -69,7 +68,7 @@ class EngineKernel(object): to_save.update({f"discriminator_{k}": self.discriminators[k] for k in self.discriminators}) return to_save - def setup_before_d(self): + def setup_after_g(self): raise NotImplemented def setup_before_g(self): @@ -93,6 +92,9 @@ class EngineKernel(object): """ raise NotImplemented + def change_engine(self, config, engine: Engine): + pass + def get_trainer(config, kernel: EngineKernel): logger = logging.getLogger(config.name) @@ -106,26 +108,37 @@ def get_trainer(config, kernel: EngineKernel): lr_schedulers = build_lr_schedulers(optimizers, config) logger.info(f"build lr_schedulers:\n{lr_schedulers}") - image_per_iteration = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1) + iteration_per_image = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1) + + def train_generators(batch, generated): + kernel.setup_before_g() + optimizers["g"].zero_grad() + loss_g = kernel.criterion_generators(batch, generated) + sum(loss_g.values()).backward() + optimizers["g"].step() + kernel.setup_after_g() + return loss_g + + def train_discriminators(batch, generated): + optimizers["d"].zero_grad() + loss_d = kernel.criterion_discriminators(batch, generated) + sum(loss_d.values()).backward() + optimizers["d"].step() + return loss_d def _step(engine, batch): batch = convert_tensor(batch, idist.device()) generated = kernel.forward(batch) - kernel.setup_before_g() - optimizers["g"].zero_grad() - loss_g = kernel.criterion_generators(batch, generated) - sum(loss_g.values()).backward() - optimizers["g"].step() + if kernel.train_generator_first: + loss_g = train_generators(batch, generated) + loss_d = train_discriminators(batch, generated) + else: + loss_d = train_discriminators(batch, generated) + loss_g = train_generators(batch, generated) - kernel.setup_before_d() - optimizers["d"].zero_grad() - loss_d = kernel.criterion_discriminators(batch, generated) - sum(loss_d.values()).backward() - optimizers["d"].step() - - if engine.state.iteration % image_per_iteration == 0: + if engine.state.iteration % iteration_per_image == 0: return { "loss": dict(g=loss_g, d=loss_d), "img": kernel.intermediate_images(batch, generated) @@ -137,6 +150,8 @@ def get_trainer(config, kernel: EngineKernel): for lr_shd in lr_schedulers.values(): trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd) + kernel.change_engine(config, trainer) + RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g") RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d") to_save = dict(trainer=trainer) @@ -150,7 +165,7 @@ def get_trainer(config, kernel: EngineKernel): tensorboard_handler = setup_tensorboard_handler(trainer, config, optimizers, step_type="item") if tensorboard_handler is not None: basic_image_event = Events.ITERATION_COMPLETED( - every=image_per_iteration) + every=iteration_per_image) pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size() @trainer.on(basic_image_event) @@ -227,7 +242,7 @@ def run_kernel(task, config, kernel): logger = logging.getLogger(config.name) with read_write(config): real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size() - config.max_iteration = config.max_pairs // real_batch_size + 1 + config.max_iteration = ceil(config.max_pairs / real_batch_size) if task == "train": train_dataset = data.DATASET.build_with(config.data.train.dataset) @@ -243,7 +258,7 @@ def run_kernel(task, config, kernel): test_dataset = data.DATASET.build_with(config.data.test.dataset) trainer.state.test_dataset = test_dataset try: - trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1) + trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader))) except Exception: import traceback print(traceback.format_exc()) diff --git a/engine/cyclegan.py b/engine/cyclegan.py deleted file mode 100644 index 2e56d52..0000000 --- a/engine/cyclegan.py +++ /dev/null @@ -1,268 +0,0 @@ -import itertools -from pathlib import Path - -import torch -import torch.nn as nn -import torchvision.utils - -import ignite.distributed as idist -from ignite.engine import Events, Engine -from ignite.contrib.handlers.param_scheduler import PiecewiseLinear -from ignite.metrics import RunningAverage -from ignite.contrib.handlers import ProgressBar -from ignite.utils import convert_tensor -from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler - -from omegaconf import OmegaConf - -import data -from loss.gan import GANLoss -from model.weight_init import generation_init_weights -from model.GAN.residual_generator import GANImageBuffer -from util.image import make_2d_grid -from util.handler import setup_common_handlers -from util.build import build_model, build_optimizer - - -def get_trainer(config, logger): - generator_a = build_model(config.model.generator, config.distributed.model) - generator_b = build_model(config.model.generator, config.distributed.model) - discriminator_a = build_model(config.model.discriminator, config.distributed.model) - discriminator_b = build_model(config.model.discriminator, config.distributed.model) - for m in [generator_b, generator_a, discriminator_b, discriminator_a]: - generation_init_weights(m) - logger.info(discriminator_a) - logger.info(generator_a) - - optimizer_g = build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()), - config.optimizers.generator) - optimizer_d = build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()), - config.optimizers.discriminator) - - milestones_values = [ - (0, config.optimizers.generator.lr), - (100, config.optimizers.generator.lr), - (200, config.data.train.scheduler.target_lr) - ] - lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values) - - milestones_values = [ - (0, config.optimizers.discriminator.lr), - (100, config.optimizers.discriminator.lr), - (200, config.data.train.scheduler.target_lr) - ] - lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values) - - gan_loss_cfg = OmegaConf.to_container(config.loss.gan) - gan_loss_cfg.pop("weight") - gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) - cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss() - id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss() - - image_buffer_a = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50) - image_buffer_b = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50) - - def _step(engine, batch): - batch = convert_tensor(batch, idist.device()) - real_a, real_b = batch["a"], batch["b"] - - fake_b = generator_a(real_a) # G_A(A) - rec_a = generator_b(fake_b) # G_B(G_A(A)) - fake_a = generator_b(real_b) # G_B(B) - rec_b = generator_a(fake_a) # G_A(G_B(B)) - - optimizer_g.zero_grad() - discriminator_a.requires_grad_(False) - discriminator_b.requires_grad_(False) - loss_g = dict( - cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a), - cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b), - gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True), - gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True) - ) - if config.loss.id.weight > 0: - loss_g["id_a"] = config.loss.id.weight * id_loss(generator_a(real_b), real_b), # G_A(B) - loss_g["id_b"] = config.loss.id.weight * id_loss(generator_b(real_a), real_a), # G_B(A) - sum(loss_g.values()).backward() - optimizer_g.step() - - discriminator_a.requires_grad_(True) - discriminator_b.requires_grad_(True) - optimizer_d.zero_grad() - loss_d_a = dict( - real=gan_loss(discriminator_a(real_b), True, is_discriminator=True), - fake=gan_loss(discriminator_a(image_buffer_a.query(fake_b.detach())), False, is_discriminator=True), - ) - loss_d_b = dict( - real=gan_loss(discriminator_b(real_a), True, is_discriminator=True), - fake=gan_loss(discriminator_b(image_buffer_b.query(fake_a.detach())), False, is_discriminator=True), - ) - (sum(loss_d_a.values()) * 0.5).backward() - (sum(loss_d_b.values()) * 0.5).backward() - optimizer_d.step() - - return { - "loss": { - "g": {ln: loss_g[ln].mean().item() for ln in loss_g}, - "d_a": {ln: loss_d_a[ln].mean().item() for ln in loss_d_a}, - "d_b": {ln: loss_d_b[ln].mean().item() for ln in loss_d_b}, - }, - "img": [ - real_a.detach(), - fake_b.detach(), - rec_a.detach(), - real_b.detach(), - fake_a.detach(), - rec_b.detach() - ] - } - - trainer = Engine(_step) - trainer.logger = logger - trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_g) - trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_d) - - RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g") - RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a") - RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b") - - to_save = dict( - generator_a=generator_a, generator_b=generator_b, discriminator_a=discriminator_a, - discriminator_b=discriminator_b, optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer, - lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g - ) - - setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5, - filename_prefix=config.name, to_save=to_save, - print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED, - metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"], - save_interval_event=Events.ITERATION_COMPLETED( - every=config.checkpoints.interval) | Events.COMPLETED) - - @trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration)) - def terminate(engine): - engine.terminate() - - if idist.get_rank() == 0: - # Create a logger - tb_logger = TensorboardLogger(log_dir=config.output_dir) - tb_writer = tb_logger.writer - - # Attach the logger to the trainer to log training loss at each iteration - def global_step_transform(*args, **kwargs): - return trainer.state.iteration - - def output_transform(output): - loss = dict() - for tl in output["loss"]: - if isinstance(output["loss"][tl], dict): - for l in output["loss"][tl]: - loss[f"{tl}_{l}"] = output["loss"][tl][l] - else: - loss[tl] = output["loss"][tl] - return loss - - tb_logger.attach( - trainer, - log_handler=OutputHandler( - tag="loss", - metric_names=["loss_g", "loss_d_a", "loss_d_b"], - global_step_transform=global_step_transform, - output_transform=output_transform - ), - event_name=Events.ITERATION_COMPLETED(every=50) - ) - tb_logger.attach( - trainer, - log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"), - event_name=Events.ITERATION_STARTED(every=50) - ) - - @trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval)) - def show_images(engine): - tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration) - - @trainer.on(Events.COMPLETED) - @idist.one_rank_only() - def _(): - # We need to close the logger with we are done - tb_logger.close() - - return trainer - - -def get_tester(config, logger): - generator_a = build_model(config.model.generator, config.distributed.model) - generator_b = build_model(config.model.generator, config.distributed.model) - - def _step(engine, batch): - batch = convert_tensor(batch, idist.device()) - real_a, real_b = batch["a"], batch["b"] - with torch.no_grad(): - fake_b = generator_a(real_a) # G_A(A) - rec_a = generator_b(fake_b) # G_B(G_A(A)) - fake_a = generator_b(real_b) # G_B(B) - rec_b = generator_a(fake_a) # G_A(G_B(B)) - return [ - real_a.detach(), - fake_b.detach(), - rec_a.detach(), - real_b.detach(), - fake_a.detach(), - rec_b.detach() - ] - - tester = Engine(_step) - tester.logger = logger - if idist.get_rank == 0: - ProgressBar(ncols=0).attach(tester) - to_load = dict(generator_a=generator_a, generator_b=generator_b) - setup_common_handlers(tester, use_profiler=False, to_save=to_load, resume_from=config.resume_from) - - @tester.on(Events.STARTED) - @idist.one_rank_only() - def mkdir(engine): - img_output_dir = Path(config.output_dir) / "test_images" - if not img_output_dir.exists(): - engine.logger.info(f"mkdir {img_output_dir}") - img_output_dir.mkdir() - - @tester.on(Events.ITERATION_COMPLETED) - def save_images(engine): - img_tensors = engine.state.output - batch_size = img_tensors[0].size(0) - for i in range(batch_size): - torchvision.utils.save_image([img[i] for img in img_tensors], - Path(config.output_dir) / f"test_images/{engine.state.iteration}_{i}.jpg", - nrow=len(img_tensors)) - - return tester - - -def run(task, config, logger): - assert torch.backends.cudnn.enabled - torch.backends.cudnn.benchmark = True - logger.info(f"start task {task}") - if task == "train": - train_dataset = data.DATASET.build_with(config.data.train.dataset) - logger.info(f"train with dataset:\n{train_dataset}") - train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader) - trainer = get_trainer(config, logger) - try: - trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1) - except Exception: - import traceback - print(traceback.format_exc()) - elif task == "test": - assert config.resume_from is not None - test_dataset = data.DATASET.build_with(config.data.test.dataset) - logger.info(f"test with dataset:\n{test_dataset}") - test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader) - tester = get_tester(config, logger) - try: - tester.run(test_data_loader, max_epochs=1) - except Exception: - import traceback - print(traceback.format_exc()) - else: - return NotImplemented(f"invalid task: {task}") diff --git a/environment.yml b/environment.yml index 9d11461..a623f79 100644 --- a/environment.yml +++ b/environment.yml @@ -17,6 +17,6 @@ dependencies: - omegaconf - python-lmdb - fire - # - opencv + - opencv # - jupyterlab diff --git a/model/GAN/CycleGAN.py b/model/GAN/CycleGAN.py new file mode 100644 index 0000000..61cc3be --- /dev/null +++ b/model/GAN/CycleGAN.py @@ -0,0 +1,62 @@ +import torch.nn as nn + +from model.normalization import select_norm_layer +from model.registry import MODEL +from .base import ResidualBlock + + +@MODEL.register_module("CyCle-Generator") +class Generator(nn.Module): + def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect', + norm_type="IN"): + super(Generator, self).__init__() + assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.' + norm_layer = select_norm_layer(norm_type) + use_bias = norm_type == "IN" + + self.start_conv = nn.Sequential( + nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3, + bias=use_bias), + norm_layer(num_features=base_channels), + nn.ReLU(inplace=True) + ) + + # down sampling + submodules = [] + num_down_sampling = 2 + for i in range(num_down_sampling): + multiple = 2 ** i + submodules += [ + nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2, + kernel_size=3, stride=2, padding=1, bias=use_bias), + norm_layer(num_features=base_channels * multiple * 2), + nn.ReLU(inplace=True) + ] + self.encoder = nn.Sequential(*submodules) + + res_block_channels = num_down_sampling ** 2 * base_channels + self.resnet_middle = nn.Sequential( + *[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in + range(num_blocks)]) + + # up sampling + submodules = [] + for i in range(num_down_sampling): + multiple = 2 ** (num_down_sampling - i) + submodules += [ + nn.ConvTranspose2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=3, stride=2, + padding=1, output_padding=1, bias=use_bias), + norm_layer(num_features=base_channels * multiple // 2), + nn.ReLU(inplace=True), + ] + self.decoder = nn.Sequential(*submodules) + + self.end_conv = nn.Sequential( + nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode), + nn.Tanh() + ) + + def forward(self, x): + x = self.encoder(self.start_conv(x)) + x = self.resnet_middle(x) + return self.end_conv(self.decoder(x)) diff --git a/model/GAN/UGATIT.py b/model/GAN/UGATIT.py index f4d4bb0..f90375a 100644 --- a/model/GAN/UGATIT.py +++ b/model/GAN/UGATIT.py @@ -45,7 +45,6 @@ class Generator(nn.Module): # Down-Sampling Bottleneck mult = 2 ** n_down_sampling for i in range(num_blocks): - # TODO: change ResnetBlock to ResidualBlock, check use_bias param down_encoder += [ResidualBlock(base_channels * mult, use_bias=False)] self.down_encoder = nn.Sequential(*down_encoder) diff --git a/model/GAN/base.py b/model/GAN/base.py index 3e15c3f..52a351a 100644 --- a/model/GAN/base.py +++ b/model/GAN/base.py @@ -1,13 +1,68 @@ import math +import torch import torch.nn as nn from model.normalization import select_norm_layer from model import MODEL +class GANImageBuffer(object): + """This class implements an image buffer that stores previously + generated images. + This buffer allows us to update the discriminator using a history of + generated images rather than the ones produced by the latest generator + to reduce model oscillation. + Args: + buffer_size (int): The size of image buffer. If buffer_size = 0, + no buffer will be created. + buffer_ratio (float): The chance / possibility to use the images + previously stored in the buffer. + """ + + def __init__(self, buffer_size, buffer_ratio=0.5): + self.buffer_size = buffer_size + # create an empty buffer + if self.buffer_size > 0: + self.img_num = 0 + self.image_buffer = [] + self.buffer_ratio = buffer_ratio + + def query(self, images): + """Query current image batch using a history of generated images. + Args: + images (Tensor): Current image batch without history information. + """ + if self.buffer_size == 0: # if the buffer size is 0, do nothing + return images + return_images = [] + for image in images: + image = torch.unsqueeze(image.data, 0) + # if the buffer is not full, keep inserting current images + if self.img_num < self.buffer_size: + self.img_num = self.img_num + 1 + self.image_buffer.append(image) + return_images.append(image) + else: + use_buffer = torch.rand(1) < self.buffer_ratio + # by self.buffer_ratio, the buffer will return a previously + # stored image, and insert the current image into the buffer + if use_buffer: + random_id = torch.randint(0, self.buffer_size, (1,)).item() + image_tmp = self.image_buffer[random_id].clone() + self.image_buffer[random_id] = image + return_images.append(image_tmp) + # by (1 - self.buffer_ratio), the buffer will return the + # current image + else: + return_images.append(image) + # collect all the images and return + return_images = torch.cat(return_images, 0) + return return_images + + # based SPADE or pix2pixHD Discriminator -@MODEL.register_module("pix2pixHD-PatchDiscriminator") +@MODEL.register_module("PatchDiscriminator") class PatchDiscriminator(nn.Module): def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN", need_intermediate_feature=False): diff --git a/model/GAN/residual_generator.py b/model/GAN/residual_generator.py deleted file mode 100644 index 7a0f0d3..0000000 --- a/model/GAN/residual_generator.py +++ /dev/null @@ -1,182 +0,0 @@ -import torch -import torch.nn as nn -from model.registry import MODEL -from model.normalization import select_norm_layer - - -class GANImageBuffer(object): - """This class implements an image buffer that stores previously - generated images. - This buffer allows us to update the discriminator using a history of - generated images rather than the ones produced by the latest generator - to reduce model oscillation. - Args: - buffer_size (int): The size of image buffer. If buffer_size = 0, - no buffer will be created. - buffer_ratio (float): The chance / possibility to use the images - previously stored in the buffer. - """ - - def __init__(self, buffer_size, buffer_ratio=0.5): - self.buffer_size = buffer_size - # create an empty buffer - if self.buffer_size > 0: - self.img_num = 0 - self.image_buffer = [] - self.buffer_ratio = buffer_ratio - - def query(self, images): - """Query current image batch using a history of generated images. - Args: - images (Tensor): Current image batch without history information. - """ - if self.buffer_size == 0: # if the buffer size is 0, do nothing - return images - return_images = [] - for image in images: - image = torch.unsqueeze(image.data, 0) - # if the buffer is not full, keep inserting current images - if self.img_num < self.buffer_size: - self.img_num = self.img_num + 1 - self.image_buffer.append(image) - return_images.append(image) - else: - use_buffer = torch.rand(1) < self.buffer_ratio - # by self.buffer_ratio, the buffer will return a previously - # stored image, and insert the current image into the buffer - if use_buffer: - random_id = torch.randint(0, self.buffer_size, (1,)).item() - image_tmp = self.image_buffer[random_id].clone() - self.image_buffer[random_id] = image - return_images.append(image_tmp) - # by (1 - self.buffer_ratio), the buffer will return the - # current image - else: - return_images.append(image) - # collect all the images and return - return_images = torch.cat(return_images, 0) - return return_images - - -class ResidualBlock(nn.Module): - def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False, use_bias=None): - super(ResidualBlock, self).__init__() - - if use_bias is None: - # Only for IN, use bias since it does not have affine parameters. - use_bias = norm_type == "IN" - norm_layer = select_norm_layer(norm_type) - models = [nn.Sequential( - nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias), - norm_layer(num_channels), - nn.ReLU(inplace=True), - )] - if use_dropout: - models.append(nn.Dropout(0.5)) - models.append(nn.Sequential( - nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias), - norm_layer(num_channels), - )) - self.block = nn.Sequential(*models) - - def forward(self, x): - return x + self.block(x) - - -@MODEL.register_module() -class ResGenerator(nn.Module): - def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect', - norm_type="IN"): - super(ResGenerator, self).__init__() - assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.' - norm_layer = select_norm_layer(norm_type) - use_bias = norm_type == "IN" - - self.start_conv = nn.Sequential( - nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3, - bias=use_bias), - norm_layer(num_features=base_channels), - nn.ReLU(inplace=True) - ) - - # down sampling - submodules = [] - num_down_sampling = 2 - for i in range(num_down_sampling): - multiple = 2 ** i - submodules += [ - nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2, - kernel_size=3, stride=2, padding=1, bias=use_bias), - norm_layer(num_features=base_channels * multiple * 2), - nn.ReLU(inplace=True) - ] - self.encoder = nn.Sequential(*submodules) - - res_block_channels = num_down_sampling ** 2 * base_channels - self.resnet_middle = nn.Sequential( - *[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in - range(num_blocks)]) - - # up sampling - submodules = [] - for i in range(num_down_sampling): - multiple = 2 ** (num_down_sampling - i) - submodules += [ - nn.ConvTranspose2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=3, stride=2, - padding=1, output_padding=1, bias=use_bias), - norm_layer(num_features=base_channels * multiple // 2), - nn.ReLU(inplace=True), - ] - self.decoder = nn.Sequential(*submodules) - - self.end_conv = nn.Sequential( - nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode), - nn.Tanh() - ) - - def forward(self, x): - x = self.encoder(self.start_conv(x)) - x = self.resnet_middle(x) - return self.end_conv(self.decoder(x)) - - -@MODEL.register_module() -class PatchDiscriminator(nn.Module): - def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="IN"): - super(PatchDiscriminator, self).__init__() - assert num_conv >= 0, f'Number of conv blocks must be non-negative, but got {num_conv}.' - norm_layer = select_norm_layer(norm_type) - use_bias = norm_type == "IN" - - kernel_size = 4 - padding = 1 - sequence = [ - nn.Conv2d(in_channels, base_channels, kernel_size=kernel_size, stride=2, padding=padding), - nn.LeakyReLU(0.2, inplace=True), - ] - - # stacked intermediate layers, - # gradually increasing the number of filters - multiple_now = 1 - for n in range(1, num_conv): - multiple_prev = multiple_now - multiple_now = min(2 ** n, 8) - sequence += [ - nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=kernel_size, - padding=padding, stride=2, bias=use_bias), - norm_layer(base_channels * multiple_now), - nn.LeakyReLU(0.2, inplace=True) - ] - multiple_prev = multiple_now - multiple_now = min(2 ** num_conv, 8) - sequence += [ - nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size, stride=1, - padding=padding, bias=use_bias), - norm_layer(base_channels * multiple_now), - nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding) - ] - self.model = nn.Sequential(*sequence) - - def forward(self, x): - return self.model(x) diff --git a/model/__init__.py b/model/__init__.py index 6331c07..b3533b4 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -1,7 +1,6 @@ from model.registry import MODEL -import model.GAN.residual_generator +import model.GAN.CycleGAN import model.GAN.TAFG import model.GAN.UGATIT -import model.fewshot import model.GAN.wrapper import model.GAN.base diff --git a/model/fewshot.py b/model/fewshot.py deleted file mode 100644 index 851f6d9..0000000 --- a/model/fewshot.py +++ /dev/null @@ -1,105 +0,0 @@ -import math - -import torch.nn as nn - -from .registry import MODEL - - -# --- gaussian initialize --- -def init_layer(l): - # Initialization using fan-in - if isinstance(l, nn.Conv2d): - n = l.kernel_size[0] * l.kernel_size[1] * l.out_channels - l.weight.data.normal_(0, math.sqrt(2.0 / float(n))) - elif isinstance(l, nn.BatchNorm2d): - l.weight.data.fill_(1) - l.bias.data.fill_(0) - elif isinstance(l, nn.Linear): - l.bias.data.fill_(0) - - -class Flatten(nn.Module): - def __init__(self): - super(Flatten, self).__init__() - - def forward(self, x): - return x.view(x.size(0), -1) - - -class SimpleBlock(nn.Module): - def __init__(self, in_channels, out_channels, half_res, leakyrelu=False): - super(SimpleBlock, self).__init__() - self.in_channels = in_channels - self.out_channels = out_channels - - self.block = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True) if not leakyrelu else nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(out_channels), - ) - self.relu = nn.ReLU(inplace=True) if not leakyrelu else nn.LeakyReLU(0.2, inplace=True) - if in_channels != out_channels: - self.shortcut = nn.Sequential( - nn.Conv2d(in_channels, out_channels, 1, 2 if half_res else 1, bias=False), - nn.BatchNorm2d(out_channels) - ) - else: - self.shortcut = nn.Identity() - - def forward(self, x): - o = self.block(x) - return self.relu(o + self.shortcut(x)) - - -class ResNet(nn.Module): - def __init__(self, block, layers, dims, num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False): - super().__init__() - assert len(layers) == 4, 'Can have only four stages' - self.inplanes = 64 - - self.start = nn.Sequential( - nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False), - nn.BatchNorm2d(self.inplanes), - nn.ReLU(inplace=True) if not leakyrelu else nn.LeakyReLU(0.2, inplace=True), - nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - ) - - trunk = [] - in_channels = self.inplanes - for i in range(4): - for j in range(layers[i]): - half_res = i >= 1 and j == 0 - trunk.append(block(in_channels, dims[i], half_res, leakyrelu)) - in_channels = dims[i] - if flatten: - trunk.append(nn.AvgPool2d(7)) - trunk.append(Flatten()) - if num_classes is not None: - if classifier_type == "linear": - trunk.append(nn.Linear(in_channels, num_classes)) - elif classifier_type == "distlinear": - pass - else: - raise ValueError(f"invalid classifier_type:{classifier_type}") - self.trunk = nn.Sequential(*trunk) - self.apply(init_layer) - - def forward(self, x): - return self.trunk(self.start(x)) - - -@MODEL.register_module() -def resnet10(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False): - return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], num_classes, classifier_type, flatten, leakyrelu) - - -@MODEL.register_module() -def resnet18(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False): - return ResNet(SimpleBlock, [2, 2, 2, 2], [64, 128, 256, 512], num_classes, classifier_type, flatten, leakyrelu) - - -@MODEL.register_module() -def resnet34(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False): - return ResNet(SimpleBlock, [3, 4, 6, 3], [64, 128, 256, 512], num_classes, classifier_type, flatten, leakyrelu)