import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf

from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss


def mse_loss(x, target_flag):
    """Least-squares GAN objective: MSE of *x* against an all-ones target
    when ``target_flag`` is truthy, otherwise against all-zeros."""
    return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))


def bce_loss(x, target_flag):
    """Vanilla GAN objective: BCE-with-logits of *x* against an all-ones target
    when ``target_flag`` is truthy, otherwise against all-zeros."""
    return F.binary_cross_entropy_with_logits(
        x, torch.ones_like(x) if target_flag else torch.zeros_like(x))


class MUNITEngineKernel(EngineKernel):
    """Training kernel for MUNIT-style unpaired image-to-image translation.

    Two generators (one per domain "a"/"b") each expose ``encode``/``decode``;
    two multi-scale discriminators judge translated images. Loss weights live
    in ``config.loss`` and are applied in the criterion methods, not inside
    the loss modules themselves.
    """

    def __init__(self, config):
        super().__init__(config)
        # "weight" is consumed by the criterion methods, so strip it before
        # forwarding the remaining kwargs to the loss constructors.
        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
        perceptual_loss_cfg.pop("weight")
        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
        # L1 for recon level 1, MSE otherwise.
        self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
        # Discriminators take the first optimization step each iteration.
        self.train_generator_first = False

    def build_models(self) -> (dict, dict):
        """Build one generator and one discriminator per domain.

        :return: ``(generators, discriminators)``, each keyed by "a"/"b".
        """
        generators = dict(
            a=build_model(self.config.model.generator),
            b=build_model(self.config.model.generator)
        )
        discriminators = dict(
            a=build_model(self.config.model.discriminator),
            b=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["a"])
        self.logger.debug(generators["a"])
        return generators, discriminators

    def setup_after_g(self):
        """Re-enable discriminator gradients after the generator step."""
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        """Freeze discriminators while the generators are being optimized."""
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run within-domain reconstruction, cross-domain translation and
        (optionally) cycle reconstruction.

        :param batch: dict with image tensors under keys "a" and "b".
        :param inference: disable gradient tracking when True.
        :return: dict of ``styles``, ``contents`` and ``images`` keyed by
            phase names such as "a2a", "a2b", "a2b2a", "random_a", ...
        """
        styles = dict()
        contents = dict()
        images = dict()
        with torch.set_grad_enabled(not inference):
            for phase in "ab":
                # Within-domain auto-encoding: x -> (c, s) -> x_hat.
                contents[phase], styles[phase] = self.generators[phase].encode(batch[phase])
                images["{0}2{0}".format(phase)] = self.generators[phase].decode(
                    contents[phase], styles[phase])
                # Random style code used for cross-domain translation below.
                styles[f"random_{phase}"] = torch.randn_like(styles[phase]).to(idist.device())
            for phase in ("a2b", "b2a"):
                # images["a2b"] = Gb.decode(content_a, random_style_b)
                images[phase] = self.generators[phase[-1]].decode(
                    contents[phase[0]], styles[f"random_{phase[-1]}"])
                # contents["a2b"], styles["a2b"] = Gb.encode(images["a2b"])
                contents[phase], styles[phase] = self.generators[phase[-1]].encode(images[phase])
                if self.config.loss.recon.cycle.weight > 0:
                    # Cycle back with the original domain style: a2b2a / b2a2b.
                    images[f"{phase}2{phase[0]}"] = self.generators[phase[0]].decode(
                        contents[phase], styles[phase[0]])
        return dict(styles=styles, contents=contents, images=images)

    def criterion_generators(self, batch, generated) -> dict:
        """Compute all weighted generator losses.

        :return: dict of scalar loss terms keyed by name and domain suffix.
        """
        loss = dict()
        for phase in "ab":
            # Image reconstruction: x vs. decode(encode(x)).
            loss[f"recon_image_{phase}"] = self.config.loss.recon.image.weight * self.recon_loss(
                batch[phase], generated["images"]["{0}2{0}".format(phase)])
            # Content code should survive translation.
            loss[f"recon_content_{phase}"] = self.config.loss.recon.content.weight * self.recon_loss(
                generated["contents"][phase], generated["contents"]["a2b" if phase == "a" else "b2a"])
            # Style code recovered from the translated image should match the
            # random style that produced it.
            loss[f"recon_style_{phase}"] = self.config.loss.recon.style.weight * self.recon_loss(
                generated["styles"][f"random_{phase}"], generated["styles"]["b2a" if phase == "a" else "a2b"])
            pred_fake = self.discriminators[phase](generated["images"]["b2a" if phase == "a" else "a2b"])
            # NOTE(review): config.loss.gan.weight is popped in __init__ but
            # never multiplied in here — confirm whether GANLoss applies it.
            loss[f"gan_{phase}"] = 0
            for sub_pred_fake in pred_fake:
                # last output is actual prediction
                loss[f"gan_{phase}"] += self.gan_loss(sub_pred_fake[-1], True)
            if self.config.loss.recon.cycle.weight > 0:
                loss[f"recon_cycle_{phase}"] = self.config.loss.recon.cycle.weight * self.recon_loss(
                    batch[phase], generated["images"]["a2b2a" if phase == "a" else "b2a2b"])
            if self.config.loss.perceptual.weight > 0:
                loss[f"perceptual_{phase}"] = self.config.loss.perceptual.weight * self.perceptual_loss(
                    batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Compute per-domain discriminator losses on real vs. detached fakes.

        :return: dict with keys ``gan_a`` / ``gan_b``.
        """
        loss = dict()
        for phase in ("a2b", "b2a"):
            # Target-domain discriminator sees real target images and the
            # detached translations into that domain.
            pred_real = self.discriminators[phase[-1]](batch[phase[-1]])
            pred_fake = self.discriminators[phase[-1]](generated["images"][phase].detach())
            loss[f"gan_{phase[-1]}"] = 0
            for i in range(len(pred_fake)):
                # Average the real/fake terms per discriminator scale;
                # [-1] is the actual prediction of each sub-discriminator.
                loss[f"gan_{phase[-1]}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
                                             + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """Collect images for visualization.

        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}

        :param batch: input batch with keys "a" and "b"
        :param generated: dict holding an "images" dict of generated tensors
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        generated = {img: generated["images"][img].detach() for img in generated["images"]}
        images = dict()
        for phase in "ab":
            # Order: input, self-reconstruction, translation (, cycle).
            images[phase] = [batch[phase].detach(),
                             generated["{0}2{0}".format(phase)],
                             generated["a2b" if phase == "a" else "b2a"]]
            if self.config.loss.recon.cycle.weight > 0:
                images[phase].append(generated["a2b2a" if phase == "a" else "b2a2b"])
        return images


class MUNITTestEngineKernel(TestEngineKernel):
    """Inference-only kernel: loads the two trained generators and translates."""

    def __init__(self, config):
        super().__init__(config)

    def build_generators(self) -> dict:
        """Build one generator per domain, keyed "a" and "b"."""
        generators = dict(
            a=build_model(self.config.model.generator),
            b=build_model(self.config.model.generator)
        )
        return generators

    def to_load(self):
        """Map checkpoint entries ("generator_a"/"generator_b") to models."""
        return {f"generator_{k}": self.generators[k] for k in self.generators}

    def inference(self, batch):
        """Translate the first element of *batch* without tracking gradients."""
        with torch.no_grad():
            # NOTE(review): build_generators only creates keys "a"/"b", so
            # self.generators["a2b"] raises KeyError — confirm the intended
            # generator (likely "b" for a->b translation) and fix upstream.
            fake, _, _ = self.generators["a2b"](batch[0])
        return fake.detach()


def run(task, config, _):
    """Entry point: dispatch to the train or test kernel.

    :raises NotImplementedError: for any task other than "train"/"test".
    """
    if task == "train":
        kernel = MUNITEngineKernel(config)
        run_kernel(task, config, kernel)
    elif task == "test":
        kernel = MUNITTestEngineKernel(config)
        run_kernel(task, config, kernel)
    else:
        # Bug fix: `raise NotImplemented` raises the NotImplemented singleton,
        # which itself fails with TypeError; NotImplementedError is the
        # exception class intended here.
        raise NotImplementedError