from itertools import chain

import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf

from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.I2I.context_loss import ContextLoss
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from model.weight_init import generation_init_weights


class TAEngineKernel(EngineKernel):
    def __init__(self, config):
        super().__init__(config)
        # Each loss config carries a "weight" entry that is applied when the
        # loss is computed, so it is popped before the remaining kwargs are
        # forwarded to the loss constructor.
        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
        perceptual_loss_cfg.pop("weight")
        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())

        style_loss_cfg = OmegaConf.to_container(config.loss.style)
        style_loss_cfg.pop("weight")
        self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())

        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())

        context_loss_cfg = OmegaConf.to_container(config.loss.context)
        context_loss_cfg.pop("weight")
        self.context_loss = ContextLoss(**context_loss_cfg).to(idist.device())

        self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
        self.edge_loss = EdgeLoss(
            "HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path
        ).to(idist.device())

    def build_models(self) -> (dict, dict):
        generators = dict(
            anime=build_model(self.config.model.anime_generator),
            face=build_model(self.config.model.face_generator)
        )
        discriminators = dict(
            anime=build_model(self.config.model.discriminator),
            face=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["face"])
        self.logger.debug(generators["face"])
        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)
        return generators, discriminators

    def setup_after_g(self):
        # Re-enable discriminator gradients for the discriminator update step.
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        # Freeze the discriminators while the generators are being optimized.
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        with torch.set_grad_enabled(not inference):
            # Stage 1: generate an anime image in the target pose from the
            # target-pose face and a horizontally flipped anime reference.
            target_pose_anime = self.generators["anime"](
                torch.cat([batch["face_1"], torch.flip(batch["anime_img"], dims=[3])], dim=1))
            # Stage 2: reconstruct the face from the grayscale anime result,
            # conditioned on the source face.
            target_pose_face = self.generators["face"](
                target_pose_anime.mean(dim=1, keepdim=True), batch["face_0"])
            return dict(fake_anime=target_pose_anime, fake_face=target_pose_face)

    def cal_gan_and_fm_loss(self, discriminator, generated_img, match_img=None):
        # The discriminator is multi-scale: it returns one list per scale whose
        # last element is the prediction and whose preceding elements are
        # intermediate feature maps used for feature matching.
        pred_fake = discriminator(generated_img)
        loss_gan = 0
        for sub_pred_fake in pred_fake:
            # last output is the actual prediction
            loss_gan += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)
        if match_img is None:
            # no matching real image, so skip the feature matching loss
            return loss_gan, 0

        pred_real = discriminator(match_img)
        loss_fm = 0
        num_scale_discriminator = len(pred_fake)
        for i in range(num_scale_discriminator):
            # last output is the final prediction, so we exclude it
            num_intermediate_outputs = len(pred_fake[i]) - 1
            for j in range(num_intermediate_outputs):
                loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
        loss_fm = self.config.loss.fm.weight * loss_fm
        return loss_gan, loss_fm

    def criterion_generators(self, batch, generated) -> dict:
        loss = dict()
        loss["face_style"] = self.config.loss.style.weight * self.style_loss(
            generated["fake_face"], batch["face_1"]
        )
loss["face_recon"] = self.config.loss.recon.weight * self.recon_loss( generated["fake_face"], batch["face_1"] ) loss["face_gan"], loss["face_fm"] = self.cal_gan_and_fm_loss( self.discriminators["face"], generated["fake_face"], batch["face_1"]) loss["anime_gan"], loss["anime_fm"] = self.cal_gan_and_fm_loss( self.discriminators["anime"], generated["fake_anime"], batch["anime_img"]) loss["anime_edge"] = self.config.loss.edge.weight * self.edge_loss( generated["fake_anime"], batch["face_1"], gt_is_edge=False, ) if self.config.loss.perceptual.weight > 0: loss["anime_perceptual"] = self.config.loss.perceptual.weight * self.perceptual_loss( generated["fake_anime"], batch["anime_img"] ) if self.config.loss.context.weight > 0: loss["anime_context"] = self.config.loss.context.weight * self.context_loss( generated["fake_anime"], batch["anime_img"], ) return loss def criterion_discriminators(self, batch, generated) -> dict: loss = dict() real = {"anime": "anime_img", "face": "face_1"} for phase in self.discriminators.keys(): pred_real = self.discriminators[phase](batch[real[phase]]) pred_fake = self.discriminators[phase](generated[f"fake_{phase}"].detach()) loss[f"gan_{phase}"] = 0 for i in range(len(pred_fake)): loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True) + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2 return loss def intermediate_images(self, batch, generated) -> dict: """ returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]} :param batch: :param generated: dict of images :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]} """ images = [batch["face_0"], batch["face_1"], batch["anime_img"], generated["fake_anime"].detach(), generated["fake_face"].detach()] return dict( b=[img for img in images] ) def run(task, config, _): kernel = TAEngineKernel(config) run_kernel(task, config, kernel)