154 lines
6.4 KiB
Python
154 lines
6.4 KiB
Python
from itertools import chain
|
|
|
|
import ignite.distributed as idist
|
|
import torch
|
|
import torch.nn as nn
|
|
from omegaconf import OmegaConf
|
|
|
|
from engine.base.i2i import EngineKernel, run_kernel
|
|
from engine.util.build import build_model
|
|
from loss.I2I.context_loss import ContextLoss
|
|
from loss.I2I.edge_loss import EdgeLoss
|
|
from loss.I2I.perceptual_loss import PerceptualLoss
|
|
from loss.gan import GANLoss
|
|
from model.weight_init import generation_init_weights
|
|
|
|
|
|
class TAEngineKernel(EngineKernel):
    """Engine kernel pairing an ``anime`` generator and a ``face`` generator,
    each with its own discriminator.

    The anime generator consumes ``face_1`` concatenated with a flipped
    ``anime_img``. The face generator consumes a channel-averaged copy of the
    anime output together with ``face_0``. See ``forward`` for details.
    """

    def __init__(self, config):
        super().__init__(config)

        device = idist.device()
        self.perceptual_loss = PerceptualLoss(**self._loss_kwargs(config.loss.perceptual)).to(device)
        self.style_loss = PerceptualLoss(**self._loss_kwargs(config.loss.style)).to(device)
        self.gan_loss = GANLoss(**self._loss_kwargs(config.loss.gan)).to(device)
        self.context_loss = ContextLoss(**self._loss_kwargs(config.loss.context)).to(device)

        # level == 1 selects L1; any other value falls back to MSE.
        self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()

        self.edge_loss = EdgeLoss(
            "HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path
        ).to(device)

    @staticmethod
    def _loss_kwargs(loss_cfg) -> dict:
        """Convert one ``config.loss.<name>`` node into loss-constructor kwargs.

        Interpolations are resolved, so ``${...}`` references arrive as real
        values rather than literal strings. The ``weight`` entry is removed
        because this kernel applies the weight itself when it combines losses.
        It is not a constructor argument.

        :param loss_cfg: an OmegaConf node that must contain a ``weight`` key.
        :return: a plain ``dict`` of the remaining entries.
        :raises KeyError: if ``weight`` is missing.
        """
        kwargs = OmegaConf.to_container(loss_cfg, resolve=True)
        kwargs.pop("weight")
        return kwargs

    def build_models(self) -> "tuple[dict, dict]":
        """Build and weight-initialise the generators and discriminators.

        :return: ``(generators, discriminators)``, each a dict keyed by
            ``"anime"`` and ``"face"``. Both discriminators come from the same
            ``config.model.discriminator`` node, built by two separate
            ``build_model`` calls.
        """
        generators = dict(
            anime=build_model(self.config.model.anime_generator),
            face=build_model(self.config.model.face_generator),
        )
        discriminators = dict(
            anime=build_model(self.config.model.discriminator),
            face=build_model(self.config.model.discriminator),
        )
        self.logger.debug(discriminators["face"])
        self.logger.debug(generators["face"])

        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)

        return generators, discriminators

    def setup_after_g(self):
        """Re-enable gradients on every discriminator's parameters."""
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        """Freeze every discriminator's parameters (``requires_grad_(False)``)."""
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run the anime generator, then the face generator on its output.

        :param batch: dict providing ``face_0``, ``face_1`` and ``anime_img``.
        :param inference: when True, autograd is disabled for the whole pass.
        :return: dict with ``fake_anime`` and ``fake_face``.
        """
        with torch.set_grad_enabled(not inference):
            # dims=[3] is presumably the width axis of an NCHW batch (a
            # horizontal flip). TODO confirm the layout.
            anime_input = torch.cat([batch["face_1"], torch.flip(batch["anime_img"], dims=[3])], dim=1)
            target_pose_anime = self.generators["anime"](anime_input)
            # Average over dim 1 (the channel axis if NCHW) to get a single-channel guide.
            target_pose_face = self.generators["face"](target_pose_anime.mean(dim=1, keepdim=True), batch["face_0"])

        return dict(fake_anime=target_pose_anime, fake_face=target_pose_face)

    def cal_gan_and_fm_loss(self, discriminator, generated_img, match_img=None):
        """Compute the generator-side GAN loss and an optional feature-matching loss.

        ``discriminator`` must return one output sequence per scale. In each
        sequence, the last element is the final prediction and the earlier
        elements are intermediate features; the indexing below relies on this.

        :param discriminator: discriminator to score the images with.
        :param generated_img: generator output to score as "real".
        :param match_img: image whose intermediate features are the
            feature-matching targets. If None, no feature-matching loss is
            computed.
        :return: ``(loss_gan, loss_fm)``. ``loss_fm`` is ``0`` when
            ``match_img`` is None.
        """
        pred_fake = discriminator(generated_img)

        loss_gan = 0
        for sub_pred_fake in pred_fake:
            # last output is the actual prediction
            loss_gan += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)

        if match_img is None:
            # do not compute the feature-matching loss
            return loss_gan, 0

        # Real-image features are only used as detached targets, so skip
        # building an autograd graph for this pass.
        with torch.no_grad():
            pred_real = discriminator(match_img)

        loss_fm = 0
        num_scale_discriminator = len(pred_fake)
        for sub_pred_fake, sub_pred_real in zip(pred_fake, pred_real):
            # The last output is the final prediction, so only intermediate features are matched.
            for feat_fake, feat_real in zip(sub_pred_fake[:-1], sub_pred_real[:-1]):
                loss_fm += self.fm_loss(feat_fake, feat_real.detach()) / num_scale_discriminator
        loss_fm = self.config.loss.fm.weight * loss_fm
        return loss_gan, loss_fm

    def criterion_generators(self, batch, generated) -> dict:
        """Collect the weighted generator losses.

        Always present: ``face_style``, ``face_recon``, ``face_gan``,
        ``face_fm``, ``anime_gan``, ``anime_fm``, ``anime_edge``.
        ``anime_perceptual`` and ``anime_context`` are added only when their
        configured weight is > 0.
        """
        loss_cfg = self.config.loss
        loss = dict()

        loss["face_style"] = loss_cfg.style.weight * self.style_loss(generated["fake_face"], batch["face_1"])
        loss["face_recon"] = loss_cfg.recon.weight * self.recon_loss(generated["fake_face"], batch["face_1"])
        loss["face_gan"], loss["face_fm"] = self.cal_gan_and_fm_loss(
            self.discriminators["face"], generated["fake_face"], batch["face_1"])
        loss["anime_gan"], loss["anime_fm"] = self.cal_gan_and_fm_loss(
            self.discriminators["anime"], generated["fake_anime"], batch["anime_img"])

        # Edge loss between the generated anime image and face_1; face_1 is
        # passed as an image, not a precomputed edge map (gt_is_edge=False).
        loss["anime_edge"] = loss_cfg.edge.weight * self.edge_loss(
            generated["fake_anime"], batch["face_1"], gt_is_edge=False,
        )
        if loss_cfg.perceptual.weight > 0:
            loss["anime_perceptual"] = loss_cfg.perceptual.weight * self.perceptual_loss(
                generated["fake_anime"], batch["anime_img"]
            )
        if loss_cfg.context.weight > 0:
            loss["anime_context"] = loss_cfg.context.weight * self.context_loss(
                generated["fake_anime"], batch["anime_img"],
            )
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Compute the discriminator loss for each phase (``gan_anime``, ``gan_face``).

        Generated images are detached. For each scale, the fake and real
        prediction losses are averaged, and the per-scale results are summed.
        """
        loss = dict()
        real = {"anime": "anime_img", "face": "face_1"}
        for phase, discriminator in self.discriminators.items():
            pred_real = discriminator(batch[real[phase]])
            pred_fake = discriminator(generated[f"fake_{phase}"].detach())
            loss[f"gan_{phase}"] = 0
            for sub_pred_fake, sub_pred_real in zip(pred_fake, pred_real):
                loss[f"gan_{phase}"] += (self.gan_loss(sub_pred_fake[-1], False, is_discriminator=True)
                                         + self.gan_loss(sub_pred_real[-1], True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        images = [batch["face_0"], batch["face_1"], batch["anime_img"], generated["fake_anime"].detach(),
                  generated["fake_face"].detach()]
        return dict(b=list(images))
|
|
|
|
|
|
def run(task, config, _):
    """Construct a ``TAEngineKernel`` from ``config`` and hand it to ``run_kernel``.

    The third positional argument is accepted for signature compatibility and
    is ignored.
    """
    run_kernel(task, config, TAEngineKernel(config))
|