from itertools import chain
from math import ceil
from typing import Tuple

import torch
import torch.nn as nn
import ignite.distributed as idist
from omegaconf import OmegaConf, read_write

import data
from engine.base.i2i import EngineKernel, build_model, get_trainer
from loss.gan import GANLoss
from loss.I2I.perceptual_loss import PerceptualLoss
from model.weight_init import generation_init_weights


class TAFGEngineKernel(EngineKernel):
    def __init__(self, config, logger):
        super().__init__(config, logger)

        # "weight" entries are scaling factors applied in the criteria below,
        # not constructor arguments, so they are popped before instantiation.
        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
        perceptual_loss_cfg.pop("weight")
        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())

        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())

        # level == 1 selects an L1 criterion; any other level falls back to MSE.
        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
        self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()

    def build_models(self) -> Tuple[dict, dict]:
        generators = dict(main=build_model(self.config.model.generator))
        # One discriminator per domain: "a" (reconstruction) and "b" (translation).
        discriminators = dict(
            a=build_model(self.config.model.discriminator),
            b=build_model(self.config.model.discriminator),
        )
        self.logger.debug(discriminators["a"])
        self.logger.debug(generators["main"])
        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)
        return generators, discriminators

    def setup_before_d(self):
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        # Freeze discriminators so the generator step does not update them.
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        generator = self.generators["main"]
        with torch.set_grad_enabled(not inference):
            fake = dict(
                a=generator(content_img=batch["edge_a"], style_img=batch["a"], which_decoder="a"),
                b=generator(content_img=batch["edge_a"], style_img=batch["b"], which_decoder="b"),
            )
        return fake

    def criterion_generators(self, batch, generated) -> dict:
        loss = dict()

        # The perceptual loss returns a tuple; the weight scales only the
        # first (perceptual) term.
        perceptual, _ = self.perceptual_loss(generated["b"], batch["b"])
        loss["perceptual"] = perceptual * self.config.loss.perceptual.weight

        for phase in "ab":
            pred_fake = self.discriminators[phase](generated[phase])
            for i, sub_pred_fake in enumerate(pred_fake):
                # last output is the actual prediction
                loss[f"gan_{phase}_sub_{i}"] = self.gan_loss(sub_pred_fake[-1], True)

            # Feature-matching loss over intermediate discriminator activations,
            # applied only to the translation domain "b".
            if self.config.loss.fm.weight > 0 and phase == "b":
                pred_real = self.discriminators[phase](batch[phase])
                loss_fm = 0
                num_scale_discriminator = len(pred_fake)
                for i in range(num_scale_discriminator):
                    # last output is the final prediction, so we exclude it
                    num_intermediate_outputs = len(pred_fake[i]) - 1
                    for j in range(num_intermediate_outputs):
                        loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
                loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm

        loss["recon"] = self.recon_loss(generated["a"], batch["a"]) * self.config.loss.recon.weight
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        loss = dict()
        for phase in self.discriminators.keys():
            pred_real = self.discriminators[phase](batch[phase])
            pred_fake = self.discriminators[phase](generated[phase].detach())
            loss[f"gan_{phase}"] = 0
            for i in range(len(pred_fake)):
                loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
                                         + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """Collect images to visualize for each domain.

        :param batch: input batch
        :param generated: dict of generated images
        :return: dict like {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        return dict(
            # The single-channel edge map is expanded to 3 channels for display.
            a=[batch["edge_a"].expand(-1, 3, -1, -1).detach(), batch["a"].detach(), generated["a"].detach()],
            b=[batch["b"].detach(), generated["b"].detach()],
        )


def run(task, config, logger):
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")

    # Derive the iteration budget from the requested number of training pairs.
    with read_write(config):
        config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)

    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
        trainer = get_trainer(config, TAFGEngineKernel(config, logger), len(train_data_loader))
        # Only rank 0 holds the test dataset for evaluation/visualization.
        if idist.get_rank() == 0:
            test_dataset = data.DATASET.build_with(config.data.test.dataset)
            trainer.state.test_dataset = test_dataset
        try:
            trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
        except Exception:
            import traceback
            logger.error(traceback.format_exc())
    else:
        raise NotImplementedError(f"invalid task: {task}")