diff --git a/.idea/misc.xml b/.idea/misc.xml
index 1eef74e..1b9173d 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -1,4 +1,4 @@
-
+
\ No newline at end of file
diff --git a/.idea/raycv.iml b/.idea/raycv.iml
index a25e5bf..9781a97 100644
--- a/.idea/raycv.iml
+++ b/.idea/raycv.iml
@@ -2,7 +2,7 @@
-
+
diff --git a/configs/synthesizers/TAFG.yml b/configs/synthesizers/TAFG.yml
new file mode 100644
index 0000000..f269560
--- /dev/null
+++ b/configs/synthesizers/TAFG.yml
@@ -0,0 +1,145 @@
+name: TAFG
+engine: TAFG
+result_dir: ./result
+max_pairs: 1000000
+
+misc:
+ random_seed: 324
+
+checkpoint:
+ epoch_interval: 1 # one checkpoint every 1 epoch
+ n_saved: 2
+
+interval:
+ print_per_iteration: 10 # print once per 10 iteration
+ tensorboard:
+ scalar: 100
+ image: 2
+
+model:
+ generator:
+ _type: TAHG-Generator
+ _bn_to_sync_bn: False
+ style_in_channels: 3
+ content_in_channels: 1
+ num_blocks: 4
+ discriminator:
+ _type: MultiScaleDiscriminator
+ num_scale: 2
+ discriminator_cfg:
+ _type: base-PatchDiscriminator
+ in_channels: 3
+ base_channels: 64
+ use_spectral: True
+ need_intermediate_feature: True
+
+loss:
+ gan:
+ loss_type: hinge
+ real_label_val: 1.0
+ fake_label_val: 0.0
+ weight: 1.0
+ perceptual:
+ layer_weights:
+ "1": 0.03125
+ "6": 0.0625
+ "11": 0.125
+ "20": 0.25
+ "29": 1
+ criterion: 'L1'
+ style_loss: False
+ perceptual_loss: True
+ weight: 1
+ style:
+ layer_weights:
+ "1": 0.03125
+ "6": 0.0625
+ "11": 0.125
+ "20": 0.25
+ "29": 1
+ criterion: 'L2'
+ style_loss: True
+ perceptual_loss: False
+ weight: 0
+ fm:
+ level: 1
+ weight: 1
+ recon:
+ level: 1
+ weight: 1
+
+optimizers:
+ generator:
+ _type: Adam
+ lr: 0.0001
+ betas: [ 0, 0.9 ]
+ weight_decay: 0.0001
+ discriminator:
+ _type: Adam
+ lr: 4e-4
+ betas: [ 0, 0.9 ]
+ weight_decay: 0.0001
+
+data:
+ train:
+ scheduler:
+ start_proportion: 0.5
+ target_lr: 0
+ buffer_size: 50
+ dataloader:
+ batch_size: 256
+ shuffle: True
+ num_workers: 2
+ pin_memory: True
+ drop_last: True
+ dataset:
+ _type: GenerationUnpairedDatasetWithEdge
+ root_a: "/data/i2i/VoxCeleb2Anime/trainA"
+ root_b: "/data/i2i/VoxCeleb2Anime/trainB"
+ edges_path: "/data/i2i/VoxCeleb2Anime/edges"
+ edge_type: "hed"
+ size: [128, 128]
+ random_pair: True
+ pipeline:
+ - Load
+ - Resize:
+ size: [128, 128]
+ - ToTensor
+ - Normalize:
+ mean: [ 0.5, 0.5, 0.5 ]
+ std: [ 0.5, 0.5, 0.5 ]
+ test:
+ dataloader:
+ batch_size: 8
+ shuffle: False
+ num_workers: 1
+ pin_memory: False
+ drop_last: False
+ dataset:
+ _type: GenerationUnpairedDatasetWithEdge
+ root_a: "/data/i2i/VoxCeleb2Anime/testA"
+ root_b: "/data/i2i/VoxCeleb2Anime/testB"
+ edges_path: "/data/i2i/VoxCeleb2Anime/edges"
+ edge_type: "hed"
+ random_pair: False
+ size: [128, 128]
+ pipeline:
+ - Load
+ - Resize:
+ size: [128, 128]
+ - ToTensor
+ - Normalize:
+ mean: [ 0.5, 0.5, 0.5 ]
+ std: [ 0.5, 0.5, 0.5 ]
+ video_dataset:
+ _type: SingleFolderDataset
+ root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
+ with_path: True
+ pipeline:
+ - Load
+ - Resize:
+ size: [ 256, 256 ]
+ - ToTensor
+ - Normalize:
+ mean: [ 0.5, 0.5, 0.5 ]
+ std: [ 0.5, 0.5, 0.5 ]
diff --git a/configs/synthesizers/TAHG.yml b/configs/synthesizers/TAHG.yml
index 797bbf2..539b9a3 100644
--- a/configs/synthesizers/TAHG.yml
+++ b/configs/synthesizers/TAHG.yml
@@ -3,10 +3,6 @@ engine: TAHG
result_dir: ./result
max_pairs: 1000000
-distributed:
- model:
- # broadcast_buffers: False
-
misc:
random_seed: 324
@@ -23,6 +19,7 @@ interval:
model:
generator:
_type: TAHG-Generator
+ _bn_to_sync_bn: False
style_in_channels: 3
content_in_channels: 1
num_blocks: 4
diff --git a/engine/TAFG.py b/engine/TAFG.py
new file mode 100644
index 0000000..be13eeb
--- /dev/null
+++ b/engine/TAFG.py
@@ -0,0 +1,133 @@
+from itertools import chain
+from math import ceil
+
+from omegaconf import read_write, OmegaConf
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import ignite.distributed as idist
+
+import data
+from engine.base.i2i import get_trainer, EngineKernel, build_model
+from model.weight_init import generation_init_weights
+
+from loss.I2I.perceptual_loss import PerceptualLoss
+from loss.gan import GANLoss
+
+
+class TAFGEngineKernel(EngineKernel):
+ def __init__(self, config, logger):
+ super().__init__(config, logger)
+ perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
+ perceptual_loss_cfg.pop("weight")
+ self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
+
+ gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
+ gan_loss_cfg.pop("weight")
+ self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
+
+ self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
+ self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
+
+ def build_models(self) -> (dict, dict):
+ generators = dict(
+ main=build_model(self.config.model.generator)
+ )
+ discriminators = dict(
+ a=build_model(self.config.model.discriminator),
+ b=build_model(self.config.model.discriminator)
+ )
+ self.logger.debug(discriminators["a"])
+ self.logger.debug(generators["main"])
+
+ for m in chain(generators.values(), discriminators.values()):
+ generation_init_weights(m)
+
+ return generators, discriminators
+
+ def setup_before_d(self):
+ for discriminator in self.discriminators.values():
+ discriminator.requires_grad_(True)
+
+ def setup_before_g(self):
+ for discriminator in self.discriminators.values():
+ discriminator.requires_grad_(False)
+
+ def forward(self, batch, inference=False) -> dict:
+ generator = self.generators["main"]
+ with torch.set_grad_enabled(not inference):
+ fake = dict(
+ a=generator(content_img=batch["edge_a"], style_img=batch["a"], which_decoder="a"),
+ b=generator(content_img=batch["edge_a"], style_img=batch["b"], which_decoder="b"),
+ )
+ return fake
+
+ def criterion_generators(self, batch, generated) -> dict:
+ loss = dict()
+ loss["perceptual"], _, = self.perceptual_loss(generated["b"], batch["b"]) * self.config.loss.perceptual.weight
+ for phase in "ab":
+ pred_fake = self.discriminators[phase](generated[phase])
+ for i, sub_pred_fake in enumerate(pred_fake):
+ # last output is actual prediction
+ loss[f"gan_{phase}_sub_{i}"] = self.gan_loss(sub_pred_fake[-1], True)
+
+ if self.config.loss.fm.weight > 0 and phase == "b":
+ pred_real = self.discriminators[phase](batch[phase])
+ loss_fm = 0
+ num_scale_discriminator = len(pred_fake)
+ for i in range(num_scale_discriminator):
+ # last output is the final prediction, so we exclude it
+ num_intermediate_outputs = len(pred_fake[i]) - 1
+ for j in range(num_intermediate_outputs):
+ loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
+ loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
+ loss["recon"] = self.recon_loss(generated["a"], batch["a"]) * self.config.loss.recon.weight
+ return loss
+
+ def criterion_discriminators(self, batch, generated) -> dict:
+ loss = dict()
+ for phase in self.discriminators.keys():
+ pred_real = self.discriminators[phase](batch[phase])
+ pred_fake = self.discriminators[phase](generated[phase].detach())
+ loss[f"gan_{phase}"] = 0
+ for i in range(len(pred_fake)):
+ loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
+ + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
+ return loss
+
+ def intermediate_images(self, batch, generated) -> dict:
+ """
+ returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
+ :param batch:
+ :param generated: dict of images
+ :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
+ """
+ return dict(
+ a=[batch[f"edge_a"].expand(-1, 3, -1, -1).detach(), batch["a"].detach(), generated["a"].detach()],
+ b=[batch["b"].detach(), generated["b"].detach()]
+ )
+
+
+def run(task, config, logger):
+ assert torch.backends.cudnn.enabled
+ torch.backends.cudnn.benchmark = True
+ logger.info(f"start task {task}")
+ with read_write(config):
+ config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
+
+ if task == "train":
+ train_dataset = data.DATASET.build_with(config.data.train.dataset)
+ logger.info(f"train with dataset:\n{train_dataset}")
+ train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
+ trainer = get_trainer(config, TAFGEngineKernel(config, logger), len(train_data_loader))
+ if idist.get_rank() == 0:
+ test_dataset = data.DATASET.build_with(config.data.test.dataset)
+ trainer.state.test_dataset = test_dataset
+ try:
+ trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
+ except Exception:
+ import traceback
+ print(traceback.format_exc())
+ else:
+ return NotImplemented(f"invalid task: {task}")
diff --git a/engine/base/__init__.py b/engine/base/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/engine/base/i2i.py b/engine/base/i2i.py
new file mode 100644
index 0000000..55ea67b
--- /dev/null
+++ b/engine/base/i2i.py
@@ -0,0 +1,187 @@
+from itertools import chain
+from math import ceil
+from pathlib import Path
+import logging
+
+import torch
+
+import ignite.distributed as idist
+from ignite.engine import Events, Engine
+from ignite.metrics import RunningAverage
+from ignite.utils import convert_tensor
+from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
+from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
+
+from model import MODEL
+from util.image import make_2d_grid
+from util.handler import setup_common_handlers, setup_tensorboard_handler
+from util.build import build_optimizer
+
+from omegaconf import OmegaConf
+
+
+def build_model(cfg):
+ cfg = OmegaConf.to_container(cfg)
+ bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
+ model = MODEL.build_with(cfg)
+ if bn_to_sync_bn:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ return idist.auto_model(model)
+
+
+def build_lr_schedulers(optimizers, config):
+ # TODO: support more scheduler type
+ g_milestones_values = [
+ (0, config.optimizers.generator.lr),
+ (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
+ (config.max_iteration, config.data.train.scheduler.target_lr)
+ ]
+ d_milestones_values = [
+ (0, config.optimizers.discriminator.lr),
+ (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
+ (config.max_iteration, config.data.train.scheduler.target_lr)
+ ]
+ return dict(
+ g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
+ d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
+ )
+
+
+class EngineKernel(object):
+ def __init__(self, config, logger):
+ self.config = config
+ self.logger = logger
+ self.generators, self.discriminators = self.build_models()
+
+ def build_models(self) -> (dict, dict):
+ raise NotImplemented
+
+ def to_save(self):
+ to_save = {}
+ to_save.update({f"generator_{k}": self.generators[k] for k in self.generators})
+ to_save.update({f"discriminator_{k}": self.discriminators[k] for k in self.discriminators})
+ return to_save
+
+ def setup_before_d(self):
+ raise NotImplemented
+
+ def setup_before_g(self):
+ raise NotImplemented
+
+ def forward(self, batch, inference=False) -> dict:
+ raise NotImplemented
+
+ def criterion_generators(self, batch, generated) -> dict:
+ raise NotImplemented
+
+ def criterion_discriminators(self, batch, generated) -> dict:
+ raise NotImplemented
+
+ def intermediate_images(self, batch, generated) -> dict:
+ """
+ returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
+ :param batch:
+ :param generated: dict of images
+ :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
+ """
+ raise NotImplemented
+
+
+def get_trainer(config, ek: EngineKernel, iter_per_epoch):
+ logger = logging.getLogger(config.name)
+ generators, discriminators = ek.generators, ek.discriminators
+ optimizers = dict(
+ g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
+ d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
+ )
+ logger.info("build optimizers", optimizers)
+
+ lr_schedulers = build_lr_schedulers(optimizers, config)
+ logger.info(f"build lr_schedulers:\n{lr_schedulers}")
+
+ def _step(engine, batch):
+ batch = convert_tensor(batch, idist.device())
+
+ generated = ek.forward(batch)
+
+ ek.setup_before_g()
+ optimizers["g"].zero_grad()
+ loss_g = ek.criterion_generators(batch, generated)
+ sum(loss_g.values()).backward()
+ optimizers["g"].step()
+
+ ek.setup_before_d()
+ optimizers["d"].zero_grad()
+ loss_d = ek.criterion_discriminators(batch, generated)
+ sum(loss_d.values()).backward()
+ optimizers["d"].step()
+
+ return {
+ "loss": dict(g=loss_g, d=loss_d),
+ "img": ek.intermediate_images(batch, generated)
+ }
+
+ trainer = Engine(_step)
+ trainer.logger = logger
+ for lr_shd in lr_schedulers.values():
+ trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
+
+ RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
+ RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
+ to_save = dict(trainer=trainer)
+ to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
+ to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
+ to_save.update(ek.to_save())
+ setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
+ end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
+
+ def output_transform(output):
+ loss = dict()
+ for tl in output["loss"]:
+ if isinstance(output["loss"][tl], dict):
+ for l in output["loss"][tl]:
+ loss[f"{tl}_{l}"] = output["loss"][tl][l]
+ else:
+ loss[tl] = output["loss"][tl]
+ return loss
+
+ tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
+ if tensorboard_handler is not None:
+ tensorboard_handler.attach(
+ trainer,
+ log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
+ event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
+ )
+
+ @trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
+ def show_images(engine):
+ output = engine.state.output
+ test_images = {}
+ for k in output["img"]:
+ image_list = output["img"][k]
+ tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list), engine.state.iteration)
+ test_images[k] = []
+ for i in range(len(image_list)):
+ test_images[k].append([])
+
+ with torch.no_grad():
+ g = torch.Generator()
+ g.manual_seed(config.misc.random_seed)
+ random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
+ for i in range(random_start, random_start + 10):
+ batch = convert_tensor(engine.state.test_dataset[i], idist.device())
+ for k in batch:
+ batch[k] = batch[k].view(1, *batch[k].size())
+ generated = ek.forward(batch)
+ images = ek.intermediate_images(batch, generated)
+
+ for k in test_images:
+ for j in range(len(images[k])):
+ test_images[k][j].append(images[k][j])
+ for k in test_images:
+ tensorboard_handler.writer.add_image(
+ f"test/{k}",
+ make_2d_grid([torch.cat(ti) for ti in test_images[k]]),
+ engine.state.iteration
+ )
+ return trainer
diff --git a/model/GAN/base.py b/model/GAN/base.py
new file mode 100644
index 0000000..bd70ac2
--- /dev/null
+++ b/model/GAN/base.py
@@ -0,0 +1,61 @@
+import math
+
+import torch.nn as nn
+
+from model.normalization import select_norm_layer
+from model import MODEL
+
+
+# based SPADE or pix2pixHD Discriminator
+@MODEL.register_module("base-PatchDiscriminator")
+class PatchDiscriminator(nn.Module):
+ def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN",
+ need_intermediate_feature=False):
+ super().__init__()
+ self.need_intermediate_feature = need_intermediate_feature
+
+ kernel_size = 4
+ padding = math.ceil((kernel_size - 1.0) / 2)
+ norm_layer = select_norm_layer(norm_type)
+ use_bias = norm_type == "IN"
+ padding_mode = "zeros"
+
+ sequence = [nn.Sequential(
+ nn.Conv2d(in_channels, base_channels, kernel_size, stride=2, padding=padding),
+ nn.LeakyReLU(0.2, False)
+ )]
+ multiple_now = 1
+ for i in range(1, num_conv):
+ multiple_prev = multiple_now
+ multiple_now = min(2 ** i, 2 ** 3)
+ stride = 1 if i == num_conv - 1 else 2
+ sequence.append(nn.Sequential(
+ self.build_conv2d(use_spectral, base_channels * multiple_prev, base_channels * multiple_now,
+ kernel_size, stride, padding, bias=use_bias, padding_mode=padding_mode),
+ norm_layer(base_channels * multiple_now),
+ nn.LeakyReLU(0.2, inplace=False),
+ ))
+ multiple_now = min(2 ** num_conv, 8)
+ sequence.append(nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding,
+ padding_mode=padding_mode))
+ self.conv_blocks = nn.ModuleList(sequence)
+
+ @staticmethod
+ def build_conv2d(use_spectral, in_channels: int, out_channels: int, kernel_size, stride, padding,
+ bias=True, padding_mode: str = 'zeros'):
+ conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias, padding_mode=padding_mode)
+ if not use_spectral:
+ return conv
+ return nn.utils.spectral_norm(conv)
+
+ def forward(self, x):
+ if self.need_intermediate_feature:
+ intermediate_feature = []
+ for layer in self.conv_blocks:
+ x = layer(x)
+ intermediate_feature.append(x)
+ return tuple(intermediate_feature)
+ else:
+ for layer in self.conv_blocks:
+ x = layer(x)
+ return x
diff --git a/model/GAN/wrapper.py b/model/GAN/wrapper.py
new file mode 100644
index 0000000..f5b7538
--- /dev/null
+++ b/model/GAN/wrapper.py
@@ -0,0 +1,25 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+from model import MODEL
+
+
+@MODEL.register_module()
+class MultiScaleDiscriminator(nn.Module):
+ def __init__(self, num_scale, discriminator_cfg):
+ super().__init__()
+
+ self.discriminator_list = nn.ModuleList([
+ MODEL.build_with(discriminator_cfg) for _ in range(num_scale)
+ ])
+
+ @staticmethod
+ def down_sample(x):
+ return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
+
+ def forward(self, x):
+ results = []
+ for discriminator in self.discriminator_list:
+ results.append(discriminator(x))
+ x = self.down_sample(x)
+ return results
diff --git a/model/__init__.py b/model/__init__.py
index 08e1dfe..2b43540 100644
--- a/model/__init__.py
+++ b/model/__init__.py
@@ -3,3 +3,5 @@ import model.GAN.residual_generator
import model.GAN.TAHG
import model.GAN.UGATIT
import model.fewshot
+import model.GAN.wrapper
+import model.GAN.base
diff --git a/util/registry.py b/util/registry.py
index 6fd4a75..f6d6a1b 100644
--- a/util/registry.py
+++ b/util/registry.py
@@ -2,7 +2,7 @@ import inspect
from omegaconf.dictconfig import DictConfig
from omegaconf import OmegaConf
from types import ModuleType
-
+import warnings
class _Registry:
def __init__(self, name):
@@ -51,6 +51,12 @@ class _Registry:
else:
raise TypeError(f'cfg must be a dict or a str, but got {type(cfg)}')
+ for k in args:
+ assert isinstance(k, str)
+ if k.startswith("_"):
+ warnings.warn(f"got param start with `_`: {k}, will remove it")
+ args.pop(k)
+
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
f'but got {type(default_args)}')