diff --git a/configs/synthesizers/TAFG.yml b/configs/synthesizers/TAFG.yml
new file mode 100644
index 0000000..f269560
--- /dev/null
+++ b/configs/synthesizers/TAFG.yml
@@ -0,0 +1,145 @@
+name: TAFG
+engine: TAFG
+result_dir: ./result
+max_pairs: 1000000
+
+misc:
+  random_seed: 324
+
+checkpoint:
+  epoch_interval: 1  # one checkpoint every epoch
+  n_saved: 2
+
+interval:
+  print_per_iteration: 10  # print once every 10 iterations
+  tensorboard:
+    scalar: 100
+    image: 2
+
+model:
+  generator:
+    _type: TAHG-Generator
+    _bn_to_sync_bn: False
+    style_in_channels: 3
+    content_in_channels: 1
+    num_blocks: 4
+  discriminator:
+    _type: MultiScaleDiscriminator
+    num_scale: 2
+    discriminator_cfg:
+      _type: base-PatchDiscriminator
+      in_channels: 3
+      base_channels: 64
+      use_spectral: True
+      need_intermediate_feature: True
+
+loss:
+  gan:
+    loss_type: hinge
+    real_label_val: 1.0
+    fake_label_val: 0.0
+    weight: 1.0
+  perceptual:
+    layer_weights:
+      "1": 0.03125
+      "6": 0.0625
+      "11": 0.125
+      "20": 0.25
+      "29": 1
+    criterion: 'L1'
+    style_loss: False
+    perceptual_loss: True
+    weight: 1
+  style:
+    layer_weights:
+      "1": 0.03125
+      "6": 0.0625
+      "11": 0.125
+      "20": 0.25
+      "29": 1
+    criterion: 'L2'
+    style_loss: True
+    perceptual_loss: False
+    weight: 0
+  fm:
+    level: 1
+    weight: 1
+  recon:
+    level: 1
+    weight: 1
+
+optimizers:
+  generator:
+    _type: Adam
+    lr: 1e-4
+    betas: [ 0, 0.9 ]
+    weight_decay: 0.0001
+  discriminator:
+    _type: Adam
+    lr: 4e-4
+    betas: [ 0, 0.9 ]
+    weight_decay: 0.0001
+
+data:
+  train:
+    scheduler:
+      start_proportion: 0.5
+      target_lr: 0
+    buffer_size: 50
+    dataloader:
+      batch_size: 256
+      shuffle: True
+      num_workers: 2
+      pin_memory: True
+      drop_last: True
+    dataset:
+      _type: GenerationUnpairedDatasetWithEdge
+      root_a: "/data/i2i/VoxCeleb2Anime/trainA"
+      root_b: "/data/i2i/VoxCeleb2Anime/trainB"
+      edges_path: "/data/i2i/VoxCeleb2Anime/edges"
+      edge_type: "hed"
+      size: [128, 128]
+      random_pair: True
+      pipeline:
+        - Load
+        - Resize:
+            size: [128, 128]
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+  test:
+    dataloader:
+      batch_size: 8
+      shuffle: False
+      num_workers: 1
+      pin_memory: False
+      drop_last: False
+    dataset:
+      _type: GenerationUnpairedDatasetWithEdge
+      root_a: "/data/i2i/VoxCeleb2Anime/testA"
+      root_b: "/data/i2i/VoxCeleb2Anime/testB"
+      edges_path: "/data/i2i/VoxCeleb2Anime/edges"
+      edge_type: "hed"
+      random_pair: False
+      size: [128, 128]
+      pipeline:
+        - Load
+        - Resize:
+            size: [128, 128]
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+    video_dataset:
+      _type: SingleFolderDataset
+      root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
+      with_path: True
+      pipeline:
+        - Load
+        - Resize:
+            size: [ 256, 256 ]
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
diff --git a/configs/synthesizers/TAHG.yml b/configs/synthesizers/TAHG.yml
index 797bbf2..539b9a3 100644
--- a/configs/synthesizers/TAHG.yml
+++ b/configs/synthesizers/TAHG.yml
@@ -3,10 +3,6 @@ engine: TAHG
 result_dir: ./result
 max_pairs: 1000000
 
-distributed:
-  model:
-    # broadcast_buffers: False
-
 misc:
   random_seed: 324
 
@@ -23,6 +19,7 @@ interval:
 model:
   generator:
     _type: TAHG-Generator
+    _bn_to_sync_bn: False
     style_in_channels: 3
     content_in_channels: 1
     num_blocks: 4
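Nothing in the patch shows how these YAML files are consumed end to end, but `run(task, config, logger)` below and the OmegaConf usage throughout suggest a loader along these lines. A minimal sketch; the dotlist override is illustrative, not a documented CLI of this repo:

    # sketch only, assuming OmegaConf is the loader (it is used throughout this patch)
    from omegaconf import OmegaConf

    config = OmegaConf.load("configs/synthesizers/TAFG.yml")
    # hypothetical command-line overrides, e.g. for a smoke-test run
    overrides = OmegaConf.from_dotlist(["data.train.dataloader.batch_size=8", "max_pairs=1000"])
    config = OmegaConf.merge(config, overrides)
    print(config.engine, config.model.generator["_type"])  # TAFG TAHG-Generator

Keys prefixed with `_` (`_type`, `_bn_to_sync_bn`) are consumed by the build helpers, not by module constructors; see `build_model` in engine/base/i2i.py and the new guard in util/registry.py below.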
diff --git a/engine/TAFG.py b/engine/TAFG.py
new file mode 100644
index 0000000..be13eeb
--- /dev/null
+++ b/engine/TAFG.py
@@ -0,0 +1,133 @@
+from itertools import chain
+from math import ceil
+
+from omegaconf import read_write, OmegaConf
+
+import torch
+import torch.nn as nn
+import ignite.distributed as idist
+
+import data
+from engine.base.i2i import get_trainer, EngineKernel, build_model
+from model.weight_init import generation_init_weights
+
+from loss.I2I.perceptual_loss import PerceptualLoss
+from loss.gan import GANLoss
+
+
+class TAFGEngineKernel(EngineKernel):
+    def __init__(self, config, logger):
+        super().__init__(config, logger)
+        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
+        perceptual_loss_cfg.pop("weight")
+        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
+
+        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
+        gan_loss_cfg.pop("weight")
+        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
+
+        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
+        self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
+
+    def build_models(self) -> (dict, dict):
+        generators = dict(
+            main=build_model(self.config.model.generator)
+        )
+        discriminators = dict(
+            a=build_model(self.config.model.discriminator),
+            b=build_model(self.config.model.discriminator)
+        )
+        self.logger.debug(discriminators["a"])
+        self.logger.debug(generators["main"])
+
+        for m in chain(generators.values(), discriminators.values()):
+            generation_init_weights(m)
+
+        return generators, discriminators
+
+    def setup_before_d(self):
+        for discriminator in self.discriminators.values():
+            discriminator.requires_grad_(True)
+
+    def setup_before_g(self):
+        for discriminator in self.discriminators.values():
+            discriminator.requires_grad_(False)
+
+    def forward(self, batch, inference=False) -> dict:
+        generator = self.generators["main"]
+        with torch.set_grad_enabled(not inference):
+            fake = dict(
+                a=generator(content_img=batch["edge_a"], style_img=batch["a"], which_decoder="a"),
+                b=generator(content_img=batch["edge_a"], style_img=batch["b"], which_decoder="b"),
+            )
+        return fake
+
+    def criterion_generators(self, batch, generated) -> dict:
+        loss = dict()
+        # PerceptualLoss returns a (perceptual, style) pair; unpack before scaling
+        loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["b"])
+        loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
+        for phase in "ab":
+            pred_fake = self.discriminators[phase](generated[phase])
+            for i, sub_pred_fake in enumerate(pred_fake):
+                # the last output of each sub-discriminator is the actual prediction;
+                # apply the configured gan weight (popped from the loss cfg in __init__)
+                loss[f"gan_{phase}_sub_{i}"] = self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
+
+            if self.config.loss.fm.weight > 0 and phase == "b":
+                pred_real = self.discriminators[phase](batch[phase])
+                loss_fm = 0
+                num_scale_discriminator = len(pred_fake)
+                for i in range(num_scale_discriminator):
+                    # the last output is the final prediction, so exclude it
+                    num_intermediate_outputs = len(pred_fake[i]) - 1
+                    for j in range(num_intermediate_outputs):
+                        loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
+                loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
+        loss["recon"] = self.recon_loss(generated["a"], batch["a"]) * self.config.loss.recon.weight
+        return loss
+
+    def criterion_discriminators(self, batch, generated) -> dict:
+        loss = dict()
+        for phase in self.discriminators.keys():
+            pred_real = self.discriminators[phase](batch[phase])
+            pred_fake = self.discriminators[phase](generated[phase].detach())
+            loss[f"gan_{phase}"] = 0
+            for i in range(len(pred_fake)):
+                loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
+                                         + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
+        return loss
+
+    def intermediate_images(self, batch, generated) -> dict:
+        """
+        :param batch: input batch
+        :param generated: dict of generated images
+        :return: dict like {"a": [img1, img2, ...], "b": [img3, img4, ...]}
+        """
+        return dict(
+            a=[batch["edge_a"].expand(-1, 3, -1, -1).detach(), batch["a"].detach(), generated["a"].detach()],
+            b=[batch["b"].detach(), generated["b"].detach()]
+        )
+
+
+def run(task, config, logger):
+    assert torch.backends.cudnn.enabled
+    torch.backends.cudnn.benchmark = True
+    logger.info(f"start task {task}")
+    with read_write(config):
+        config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
+
+    if task == "train":
+        train_dataset = data.DATASET.build_with(config.data.train.dataset)
+        logger.info(f"train with dataset:\n{train_dataset}")
+        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
+        trainer = get_trainer(config, TAFGEngineKernel(config, logger), len(train_data_loader))
+        if idist.get_rank() == 0:
+            test_dataset = data.DATASET.build_with(config.data.test.dataset)
+            trainer.state.test_dataset = test_dataset
+        try:
+            trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
+        except Exception:
+            import traceback
+            print(traceback.format_exc())
+    else:
+        raise NotImplementedError(f"invalid task: {task}")
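Since the engine goes through `ignite.distributed` everywhere (`auto_model`, `auto_dataloader`, `get_rank`), it can presumably be launched single-process or under a spawned process group. A sketch of a launcher, not part of this patch; the backend and `nproc_per_node` values are assumptions:

    # hypothetical launcher; not part of this patch
    import logging

    import ignite.distributed as idist
    from omegaconf import OmegaConf

    from engine import TAFG


    def training(local_rank, config):
        logger = logging.getLogger(config.name)
        TAFG.run("train", config, logger)


    if __name__ == "__main__":
        config = OmegaConf.load("configs/synthesizers/TAFG.yml")
        # idist.auto_model / auto_dataloader inside the engine adapt to whatever
        # process group this sets up (or to plain single-device if backend=None)
        with idist.Parallel(backend="nccl", nproc_per_node=2) as parallel:
            parallel.run(training, config)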
diff --git a/engine/base/__init__.py b/engine/base/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/engine/base/i2i.py b/engine/base/i2i.py
new file mode 100644
index 0000000..55ea67b
--- /dev/null
+++ b/engine/base/i2i.py
@@ -0,0 +1,187 @@
+from itertools import chain
+import logging
+
+import torch
+
+import ignite.distributed as idist
+from ignite.engine import Events, Engine
+from ignite.metrics import RunningAverage
+from ignite.utils import convert_tensor
+from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
+from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
+
+from model import MODEL
+from util.image import make_2d_grid
+from util.handler import setup_common_handlers, setup_tensorboard_handler
+from util.build import build_optimizer
+
+from omegaconf import OmegaConf
+
+
+def build_model(cfg):
+    cfg = OmegaConf.to_container(cfg)
+    bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
+    model = MODEL.build_with(cfg)
+    if bn_to_sync_bn:
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+    return idist.auto_model(model)
+
+
+def build_lr_schedulers(optimizers, config):
+    # TODO: support more scheduler types
+    g_milestones_values = [
+        (0, config.optimizers.generator.lr),
+        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
+        (config.max_iteration, config.data.train.scheduler.target_lr)
+    ]
+    d_milestones_values = [
+        (0, config.optimizers.discriminator.lr),
+        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
+        (config.max_iteration, config.data.train.scheduler.target_lr)
+    ]
+    return dict(
+        g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
+        d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
+    )
+
+
+class EngineKernel(object):
+    def __init__(self, config, logger):
+        self.config = config
+        self.logger = logger
+        self.generators, self.discriminators = self.build_models()
+
+    def build_models(self) -> (dict, dict):
+        raise NotImplementedError
+
+    def to_save(self):
+        to_save = {}
+        to_save.update({f"generator_{k}": self.generators[k] for k in self.generators})
+        to_save.update({f"discriminator_{k}": self.discriminators[k] for k in self.discriminators})
+        return to_save
+
+    def setup_before_d(self):
+        raise NotImplementedError
+
+    def setup_before_g(self):
+        raise NotImplementedError
+
+    def forward(self, batch, inference=False) -> dict:
+        raise NotImplementedError
+
+    def criterion_generators(self, batch, generated) -> dict:
+        raise NotImplementedError
+
+    def criterion_discriminators(self, batch, generated) -> dict:
+        raise NotImplementedError
+
+    def intermediate_images(self, batch, generated) -> dict:
+        """
+        :param batch: input batch
+        :param generated: dict of generated images
+        :return: dict like {"a": [img1, img2, ...], "b": [img3, img4, ...]}
+        """
+        raise NotImplementedError
+
+
+def get_trainer(config, ek: EngineKernel, iter_per_epoch):
+    logger = logging.getLogger(config.name)
+    generators, discriminators = ek.generators, ek.discriminators
+    optimizers = dict(
+        g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
+        d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
+    )
+    logger.info(f"build optimizers:\n{optimizers}")
+
+    lr_schedulers = build_lr_schedulers(optimizers, config)
+    logger.info(f"build lr_schedulers:\n{lr_schedulers}")
+
+    def _step(engine, batch):
+        batch = convert_tensor(batch, idist.device())
+
+        generated = ek.forward(batch)
+
+        ek.setup_before_g()
+        optimizers["g"].zero_grad()
+        loss_g = ek.criterion_generators(batch, generated)
+        sum(loss_g.values()).backward()
+        optimizers["g"].step()
+
+        ek.setup_before_d()
+        optimizers["d"].zero_grad()
+        loss_d = ek.criterion_discriminators(batch, generated)
+        sum(loss_d.values()).backward()
+        optimizers["d"].step()
+
+        return {
+            "loss": dict(g=loss_g, d=loss_d),
+            "img": ek.intermediate_images(batch, generated)
+        }
+
+    trainer = Engine(_step)
+    trainer.logger = logger
+    for lr_shd in lr_schedulers.values():
+        trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
+
+    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
+    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
+    to_save = dict(trainer=trainer)
+    to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
+    to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
+    to_save.update(ek.to_save())
+    setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
+                          end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
+
+    def output_transform(output):
+        loss = dict()
+        for tl in output["loss"]:
+            if isinstance(output["loss"][tl], dict):
+                for name in output["loss"][tl]:
+                    loss[f"{tl}_{name}"] = output["loss"][tl][name]
+            else:
+                loss[tl] = output["loss"][tl]
+        return loss
+
+    tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
+    if tensorboard_handler is not None:
+        tensorboard_handler.attach(
+            trainer,
+            log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
+            event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
+        )
+
+        # only attached when the tensorboard handler exists (e.g. on rank 0),
+        # since it writes images and reads trainer.state.test_dataset
+        @trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
+        def show_images(engine):
+            output = engine.state.output
+            test_images = {}
+            for k in output["img"]:
+                image_list = output["img"][k]
+                tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list), engine.state.iteration)
+                # one slot per image column, filled from the test samples below
+                test_images[k] = [[] for _ in range(len(image_list))]
+
+            with torch.no_grad():
+                g = torch.Generator()
+                g.manual_seed(config.misc.random_seed)
+                random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
+                for i in range(random_start, random_start + 10):
+                    batch = convert_tensor(engine.state.test_dataset[i], idist.device())
+                    for k in batch:
+                        batch[k] = batch[k].view(1, *batch[k].size())
+                    generated = ek.forward(batch, inference=True)
+                    images = ek.intermediate_images(batch, generated)
+
+                    for k in test_images:
+                        for j in range(len(images[k])):
+                            test_images[k][j].append(images[k][j])
+            for k in test_images:
+                tensorboard_handler.writer.add_image(
+                    f"test/{k}",
+                    make_2d_grid([torch.cat(ti) for ti in test_images[k]]),
+                    engine.state.iteration
+                )
+    return trainer
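The milestone triples in `build_lr_schedulers` keep the learning rate flat for the first `start_proportion` of training and then decay it linearly, the usual CycleGAN/pix2pixHD schedule. A standalone sanity check of the shape; the numbers are assumptions standing in for max_iteration=1000, lr=1e-4, start_proportion=0.5, target_lr=0:

    # sanity check of the schedule shape; all numbers here are assumptions
    from ignite.contrib.handlers.param_scheduler import PiecewiseLinear

    values = PiecewiseLinear.simulate_values(
        num_events=1001,
        param_name="lr",
        milestones_values=[(0, 1e-4), (500, 1e-4), (1000, 0.0)],
    )
    # each entry is [event_index, lr]: flat at 1e-4 up to event 500,
    # then linear decay towards 0 (roughly 5e-5 around event 750)
    print(values[0], values[500], values[750], values[1000])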
diff --git a/model/GAN/base.py b/model/GAN/base.py
new file mode 100644
index 0000000..bd70ac2
--- /dev/null
+++ b/model/GAN/base.py
@@ -0,0 +1,61 @@
+import math
+
+import torch.nn as nn
+
+from model.normalization import select_norm_layer
+from model import MODEL
+
+
+# based on the SPADE / pix2pixHD patch discriminator
+@MODEL.register_module("base-PatchDiscriminator")
+class PatchDiscriminator(nn.Module):
+    def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN",
+                 need_intermediate_feature=False):
+        super().__init__()
+        self.need_intermediate_feature = need_intermediate_feature
+
+        kernel_size = 4
+        padding = math.ceil((kernel_size - 1.0) / 2)
+        norm_layer = select_norm_layer(norm_type)
+        use_bias = norm_type == "IN"
+        padding_mode = "zeros"
+
+        sequence = [nn.Sequential(
+            nn.Conv2d(in_channels, base_channels, kernel_size, stride=2, padding=padding),
+            nn.LeakyReLU(0.2, inplace=False)
+        )]
+        multiple_now = 1
+        for i in range(1, num_conv):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 8)
+            stride = 1 if i == num_conv - 1 else 2
+            sequence.append(nn.Sequential(
+                self.build_conv2d(use_spectral, base_channels * multiple_prev, base_channels * multiple_now,
+                                  kernel_size, stride, padding, bias=use_bias, padding_mode=padding_mode),
+                norm_layer(base_channels * multiple_now),
+                nn.LeakyReLU(0.2, inplace=False),
+            ))
+        # multiple_now already matches the channel multiple of the last block above;
+        # recomputing it from num_conv would mismatch for num_conv < 4
+        sequence.append(nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding,
+                                  padding_mode=padding_mode))
+        self.conv_blocks = nn.ModuleList(sequence)
+
+    @staticmethod
+    def build_conv2d(use_spectral, in_channels: int, out_channels: int, kernel_size, stride, padding,
+                     bias=True, padding_mode: str = 'zeros'):
+        # bias must be passed by keyword: the sixth positional argument of nn.Conv2d is dilation
+        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias,
+                         padding_mode=padding_mode)
+        if not use_spectral:
+            return conv
+        return nn.utils.spectral_norm(conv)
+
+    def forward(self, x):
+        if self.need_intermediate_feature:
+            intermediate_feature = []
+            for layer in self.conv_blocks:
+                x = layer(x)
+                intermediate_feature.append(x)
+            return tuple(intermediate_feature)
+        else:
+            for layer in self.conv_blocks:
+                x = layer(x)
+            return x
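A quick smoke test for the discriminator above; the printed shapes simply follow from kernel size 4, padding 2, and the 2/2/2/1 stride pattern plus the stride-1 head. The 128x128 input size is an arbitrary choice:

    # smoke test; the input size is an arbitrary assumption
    import torch
    from model.GAN.base import PatchDiscriminator

    d = PatchDiscriminator(in_channels=3, base_channels=64, num_conv=4,
                           use_spectral=True, need_intermediate_feature=True)
    features = d(torch.randn(1, 3, 128, 128))
    for f in features:
        # five tensors for num_conv=4; the last is the 1-channel patch map
        print(tuple(f.shape))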
diff --git a/model/GAN/wrapper.py b/model/GAN/wrapper.py
new file mode 100644
index 0000000..f5b7538
--- /dev/null
+++ b/model/GAN/wrapper.py
@@ -0,0 +1,25 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+from model import MODEL
+
+
+@MODEL.register_module()
+class MultiScaleDiscriminator(nn.Module):
+    def __init__(self, num_scale, discriminator_cfg):
+        super().__init__()
+
+        self.discriminator_list = nn.ModuleList([
+            MODEL.build_with(discriminator_cfg) for _ in range(num_scale)
+        ])
+
+    @staticmethod
+    def down_sample(x):
+        return F.avg_pool2d(x, kernel_size=3, stride=2, padding=1, count_include_pad=False)
+
+    def forward(self, x):
+        # run each discriminator on a progressively downsampled input
+        results = []
+        for discriminator in self.discriminator_list:
+            results.append(discriminator(x))
+            x = self.down_sample(x)
+        return results
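Because the wrapper is built from a nested config, the whole registry path can be exercised directly. A sketch, assuming `MODEL.build_with` dispatches on the `_type` key as the configs imply:

    # sketch; assumes build_with resolves classes from the "_type" key
    import torch
    from model import MODEL

    cfg = dict(
        _type="MultiScaleDiscriminator",
        num_scale=2,
        discriminator_cfg=dict(
            _type="base-PatchDiscriminator",
            in_channels=3,
            base_channels=64,
            use_spectral=True,
            need_intermediate_feature=True,
        ),
    )
    d = MODEL.build_with(cfg)
    outs = d(torch.randn(1, 3, 128, 128))
    # one entry per scale; each is a tuple of per-layer features whose
    # last element is the 1-channel patch prediction map
    print(len(outs), [scale[-1].shape for scale in outs])

This nested-list structure (scales on the outside, layer features on the inside) is exactly what `criterion_generators` in engine/TAFG.py walks for the GAN and feature-matching losses.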
diff --git a/model/__init__.py b/model/__init__.py
index 08e1dfe..2b43540 100644
--- a/model/__init__.py
+++ b/model/__init__.py
@@ -3,3 +3,5 @@ import model.GAN.residual_generator
 import model.GAN.TAHG
 import model.GAN.UGATIT
 import model.fewshot
+import model.GAN.wrapper
+import model.GAN.base
diff --git a/util/registry.py b/util/registry.py
index 6fd4a75..f6d6a1b 100644
--- a/util/registry.py
+++ b/util/registry.py
@@ -2,7 +2,7 @@ import inspect
 from omegaconf.dictconfig import DictConfig
 from omegaconf import OmegaConf
 from types import ModuleType
-
+import warnings
 
 class _Registry:
     def __init__(self, name):
@@ -51,6 +51,12 @@ class _Registry:
         else:
             raise TypeError(f'cfg must be a dict or a str, but got {type(cfg)}')
 
+        # iterate over a copy: popping from a dict while iterating it raises RuntimeError
+        for k in list(args):
+            assert isinstance(k, str)
+            if k.startswith("_"):
+                warnings.warn(f"got param starting with `_`: {k}; removing it")
+                args.pop(k)
+
         if not (isinstance(default_args, dict) or default_args is None):
             raise TypeError('default_args must be a dict or None, '
                             f'but got {type(default_args)}')
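The new guard means stray build-time options like `_bn_to_sync_bn` degrade to a warning if they ever reach `build_with` unconsumed, instead of a TypeError in the constructor. An illustrative sketch; `_not_a_real_option` is a made-up key, and it assumes the cfg keys flow into `args` as the surrounding code suggests:

    # illustrative; "_not_a_real_option" is a made-up key nothing consumes
    from model import MODEL

    cfg = dict(
        _type="base-PatchDiscriminator",
        _not_a_real_option=True,  # dropped with a warning rather than crashing
        in_channels=3,
        base_channels=64,
    )
    d = MODEL.build_with(cfg)  # UserWarning: got param starting with `_`: _not_a_real_option; removing it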