from itertools import chain
from typing import Tuple

import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf

from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.gan import GANLoss
from model.GAN.base import GANImageBuffer
from model.weight_init import generation_init_weights


class TAFGEngineKernel(EngineKernel):
    """Two-domain image translation kernel with cycle consistency.

    Holds two generators (``"a2b"``, ``"b2a"``) and two discriminators
    (``"a"``, ``"b"``). Generator losses are adversarial, cycle-consistency and,
    when ``config.loss.id.weight > 0``, an identity term.
    """

    def __init__(self, config):
        super().__init__(config)
        # resolve=True so any ${...} interpolations reach GANLoss as concrete
        # values instead of unresolved strings.
        gan_loss_cfg = OmegaConf.to_container(config.loss.gan, resolve=True)
        # "weight" is not a GANLoss argument; it scales the generator GAN terms
        # in criterion_generators instead.
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
        # level == 1 selects L1 loss; any other level selects MSE loss.
        self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
        self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()

        # Fall back to 50 only when the size is unset; an explicit 0 is passed
        # through as configured rather than being treated as "missing".
        # NOTE(review): assumes GANImageBuffer accepts 0 (e.g. as "no buffering") -- confirm.
        buffer_size = config.data.train.buffer_size
        if buffer_size is None:
            buffer_size = 50
        # One independent fake-image buffer per discriminator. self.discriminators
        # is provided by EngineKernel (presumably from build_models) -- not visible here.
        self.image_buffers = {k: GANImageBuffer(buffer_size) for k in self.discriminators.keys()}

    def build_models(self) -> Tuple[dict, dict]:
        """Build both generators and both discriminators and initialise their weights.

        :return: ``(generators, discriminators)``; generators are keyed
            ``"a2b"``/``"b2a"``, discriminators ``"a"``/``"b"``.
        """
        # Both generators share one config, and both discriminators share one config.
        generators = dict(
            a2b=build_model(self.config.model.generator),
            b2a=build_model(self.config.model.generator)
        )
        discriminators = dict(
            a=build_model(self.config.model.discriminator),
            b=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["a"])
        self.logger.debug(generators["a2b"])

        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)

        return generators, discriminators

    def setup_after_g(self):
        """Re-enable gradients on every discriminator."""
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        """Freeze every discriminator's parameters (``requires_grad=False``)."""
        # Presumably called before the generator step so the generator loss does
        # not accumulate gradients in the discriminators -- confirm with EngineKernel.
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run both translation directions and their cycle reconstructions.

        :param batch: mapping with domain images under ``"a"`` and ``"b"``.
        :param inference: when True, autograd is disabled for all generator calls.
        :return: dict with ``"a2b"``, ``"b2a"``, ``"a2b2a"``, ``"b2a2b"`` and, if
            ``config.loss.id.weight > 0``, ``"a2a"`` (b2a generator applied to
            ``batch["a"]``) and ``"b2b"`` (a2b generator applied to ``batch["b"]``).
        """
        images = dict()
        with torch.set_grad_enabled(not inference):
            images["a2b"] = self.generators["a2b"](batch["a"])
            images["b2a"] = self.generators["b2a"](batch["b"])
            images["a2b2a"] = self.generators["b2a"](images["a2b"])
            images["b2a2b"] = self.generators["a2b"](images["b2a"])
            if self.config.loss.id.weight > 0:
                # Identity mapping: the generator targeting a domain should leave
                # images already in that domain unchanged.
                images["a2a"] = self.generators["b2a"](batch["a"])
                images["b2b"] = self.generators["a2b"](batch["b"])
        return images

    def criterion_generators(self, batch, generated) -> dict:
        """Compute weighted generator losses.

        :param batch: real images under ``"a"`` and ``"b"``.
        :param generated: output of :meth:`forward`.
        :return: dict with ``cycle_a``, ``gan_a2b``, ``cycle_b``, ``gan_b2a`` and,
            if the identity weight is positive, ``id_a`` and ``id_b``.
        """
        loss = dict()
        for phase in ["a2b", "b2a"]:
            # phase[0] is the source domain and phase[-1] the target domain.
            loss[f"cycle_{phase[0]}"] = self.config.loss.cycle.weight * self.cycle_loss(
                generated[f"{phase}2{phase[0]}"], batch[phase[0]])
            # The target-domain discriminator should classify the translation as real.
            loss[f"gan_{phase}"] = self.config.loss.gan.weight * self.gan_loss(
                self.discriminators[phase[-1]](generated[phase]), True)
            if self.config.loss.id.weight > 0:
                loss[f"id_{phase[0]}"] = self.config.loss.id.weight * self.id_loss(
                    generated[f"{phase[0]}2{phase[0]}"], batch[phase[0]])
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Compute discriminator losses ``gan_a`` and ``gan_b``.

        Each loss is the mean of the fake term (a buffered, detached translation
        into that domain, labelled False) and the real term (labelled True).
        """
        loss = dict()
        for phase in "ab":
            # detach() keeps the discriminator loss from back-propagating into the
            # generators; the buffer may swap in previously generated images.
            generated_image = self.image_buffers[phase].query(
                generated["b2a" if phase == "a" else "a2b"].detach())
            loss[f"gan_{phase}"] = (self.gan_loss(self.discriminators[phase](generated_image), False,
                                                  is_discriminator=True) +
                                    self.gan_loss(self.discriminators[phase](batch[phase]), True,
                                                  is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        # Per domain: real input, its translation, and its cycle reconstruction.
        return dict(
            a=[batch["a"].detach(), generated["a2b"].detach(), generated["a2b2a"].detach()],
            b=[batch["b"].detach(), generated["b2a"].detach(), generated["b2a2b"].detach()],
        )


def run(task, config, _):
    """Build a :class:`TAFGEngineKernel` from ``config`` and hand it to ``run_kernel``."""
    kernel = TAFGEngineKernel(config)
    run_kernel(task, config, kernel)