from itertools import chain
from typing import Tuple

import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf

from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from model.weight_init import generation_init_weights


class TAFGEngineKernel(EngineKernel):
    """Training/inference kernel for a TAFG-style image-to-image model.

    Two image domains are involved, labelled "a" and "b".  The generator
    encodes (edge, image) pairs into content/style codes and decodes them
    back; one discriminator per domain judges the results.  Once the epoch
    exceeds ``config.misc.add_new_loss_epoch``, additional cross-domain
    ("a2b"), reconstruction and cycle losses are switched on.
    """

    def __init__(self, config):
        super().__init__(config)
        device = idist.device()
        # Perceptual and style losses share the PerceptualLoss implementation;
        # only their configuration differs.
        self.perceptual_loss = PerceptualLoss(**self._loss_kwargs(config.loss.perceptual)).to(device)
        self.style_loss = PerceptualLoss(**self._loss_kwargs(config.loss.style)).to(device)
        self.gan_loss = GANLoss(**self._loss_kwargs(config.loss.gan)).to(device)
        # Pixel-level losses: L1 when level == 1, otherwise MSE.
        self.recon_loss = self._pixel_loss(config.loss.recon)
        self.style_recon_loss = self._pixel_loss(config.loss.style_recon)
        self.content_recon_loss = self._pixel_loss(config.loss.content_recon)
        self.cycle_loss = self._pixel_loss(config.loss.cycle)
        self.edge_loss = EdgeLoss(
            "HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path
        ).to(device)

    @staticmethod
    def _loss_kwargs(cfg_node) -> dict:
        """Return a loss config node as plain kwargs with ``weight`` removed.

        The weight is applied in the criterion methods, not inside the loss
        module itself.
        """
        kwargs = OmegaConf.to_container(cfg_node)
        kwargs.pop("weight")
        return kwargs

    @staticmethod
    def _pixel_loss(cfg_node) -> nn.Module:
        """Pick the pixel loss: L1 for ``level == 1``, MSE otherwise."""
        return nn.L1Loss() if cfg_node.level == 1 else nn.MSELoss()

    def _process_batch(self, batch, inference=False):
        # Currently a no-op hook; kept so batch preprocessing can be
        # reintroduced in a single place.
        # batch["b"] = batch["b"] if inference else batch["b"][0].expand(batch["a"].size())
        return batch

    def build_models(self) -> Tuple[dict, dict]:
        """Build the generator and one discriminator per domain.

        :return: ``(generators, discriminators)`` dicts keyed by name/domain.
        """
        generators = dict(main=build_model(self.config.model.generator))
        discriminators = dict(
            a=build_model(self.config.model.discriminator),
            b=build_model(self.config.model.discriminator),
        )
        self.logger.debug(discriminators["a"])
        self.logger.debug(generators["main"])
        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)
        return generators, discriminators

    def setup_after_g(self):
        """Re-enable discriminator gradients after the generator step."""
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        """Freeze discriminators while the generator is being optimised."""
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run the generator on one batch.

        :param batch: dict with keys "a"/"b", each holding "edge" and "img" tensors.
        :param inference: disables gradient tracking when True.
        :return: dict with "styles", "contents" and "images" sub-dicts.
        """
        generator = self.generators["main"]
        batch = self._process_batch(batch, inference)
        styles, contents, images = {}, {}, {}
        with torch.set_grad_enabled(not inference):
            contents["a"], styles["a"] = generator.encode(batch["a"]["edge"], batch["a"]["img"], "a", "a")
            contents["b"], styles["b"] = generator.encode(batch["b"]["edge"], batch["b"]["img"], "b", "b")
            # Within-domain reconstructions.
            for ph in "ab":
                images[f"{ph}2{ph}"] = generator.decode(contents[ph], styles[ph], ph)
            if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
                # Translate "a" content with a random "b" style, then re-encode
                # the result for content/style reconstruction and cycle terms.
                # randn_like already matches the device/dtype of styles["b"].
                styles["random_b"] = torch.randn_like(styles["b"])
                images["a2b"] = generator.decode(contents["a"], styles["random_b"], "b")
                contents["recon_a"], styles["recon_b"] = generator.encode(
                    self.edge_loss.edge_extractor(images["a2b"]), images["a2b"], "b", "b")
                images["cycle_b"] = generator.decode(contents["b"], styles["recon_b"], "b")
                images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
        return dict(styles=styles, contents=contents, images=images)

    def criterion_generators(self, batch, generated) -> dict:
        """Compute all generator-side losses for one batch.

        :param batch: input batch (same layout as in :meth:`forward`).
        :param generated: output of :meth:`forward`.
        :return: dict of named scalar loss tensors, already weighted.
        """
        batch = self._process_batch(batch)
        loss = dict()
        for ph in "ab":
            loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
                generated["images"][f"{ph}2{ph}"], batch[ph]["img"])
            pred_fake = self.discriminators[ph](generated["images"][f"{ph}2{ph}"])
            loss[f"gan_{ph}"] = 0
            for sub_pred_fake in pred_fake:
                # last output is actual prediction
                loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
        if self.engine.state.epoch == self.config.misc.add_new_loss_epoch:
            # At the switch-over epoch, freeze the style pathway once.
            self.generators["main"].style_converters.requires_grad_(False)
            self.generators["main"].style_encoders.requires_grad_(False)
        if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
            # a2b images live in domain "b", so discriminator "b" judges them
            # (the original relied on the leaked loop variable ``ph`` == "b").
            pred_fake = self.discriminators["b"](generated["images"]["a2b"])
            loss["gan_a2b"] = 0
            for sub_pred_fake in pred_fake:
                # last output is actual prediction
                loss["gan_a2b"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
            loss["recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
                generated["contents"]["a"], generated["contents"]["recon_a"]
            )
            loss["recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
                generated["styles"]["random_b"], generated["styles"]["recon_b"]
            )
            if self.config.loss.perceptual.weight > 0:
                loss["perceptual_a"] = self.config.loss.perceptual.weight * self.perceptual_loss(
                    batch["a"]["img"], generated["images"]["a2b"]
                )
            if self.config.loss.cycle.weight > 0:
                loss["cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
                    batch["a"]["img"], generated["images"]["cycle_a"]
                )
            if self.config.loss.edge.weight > 0:
                # Only the first edge channel is used as the target.
                loss["edge_a"] = self.config.loss.edge.weight * self.edge_loss(
                    generated["images"]["a2b"], batch["a"]["edge"][:, 0:1, :, :]
                )
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Compute discriminator losses: real images vs. detached fakes.

        After ``add_new_loss_epoch`` the translated "a2b" images are included
        as an additional fake sample; each per-scale loss is averaged over its
        real/fake terms.
        """
        include_a2b = self.engine.state.epoch > self.config.misc.add_new_loss_epoch
        loss = dict()
        for phase, discriminator in self.discriminators.items():
            pred_real = discriminator(batch[phase]["img"])
            fake_preds = [discriminator(generated["images"][f"{phase}2{phase}"].detach())]
            if include_a2b:
                # NOTE(review): as in the original, *both* discriminators are
                # shown the a2b fakes — confirm this is intended for "a".
                fake_preds.append(discriminator(generated["images"]["a2b"].detach()))
            loss[f"gan_{phase}"] = 0
            for preds in zip(pred_real, *fake_preds):
                # last output of each sub-prediction is the actual prediction
                real, fakes = preds[0], preds[1:]
                terms = [self.gan_loss(f[-1], False, is_discriminator=True) for f in fakes]
                terms.append(self.gan_loss(real[-1], True, is_discriminator=True))
                loss[f"gan_{phase}"] += sum(terms) / len(terms)
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """Collect images for visualisation.

        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}

        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        batch = self._process_batch(batch)
        # Single-channel edge maps are expanded to 3 channels for display.
        a_images = [
            batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
            batch["a"]["img"].detach(),
            generated["images"]["a2a"].detach(),
        ]
        b_images = [
            batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
            batch["b"]["img"].detach(),
            generated["images"]["b2b"].detach(),
        ]
        if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
            # The translation/cycle outputs only exist after the switch epoch.
            a_images += [
                generated["images"]["a2b"].detach(),
                generated["images"]["cycle_a"].detach(),
            ]
            b_images.append(generated["images"]["cycle_b"].detach())
        return dict(a=a_images, b=b_images)

    def change_engine(self, config, trainer):
        # Extension hook for attaching extra event handlers to the trainer
        # (e.g. a mid-training config change); intentionally a no-op for now.
        pass


def run(task, config, _):
    """Entry point: build the TAFG kernel for ``config`` and execute ``task``."""
    kernel = TAFGEngineKernel(config)
    run_kernel(task, config, kernel)