# raycv/engine/talking_anime.py
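"""Engine kernel for the talking-anime image-to-image task.

Builds the anime/face generator and discriminator pairs and defines the
losses used to train them (GAN, feature matching, reconstruction, style,
edge, and optional perceptual/context terms).
"""
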
from itertools import chain

import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf

from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.I2I.context_loss import ContextLoss
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from model.weight_init import generation_init_weights


class TAEngineKernel(EngineKernel):
    def __init__(self, config):
        super().__init__(config)
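        # Build the loss modules from the config. Each loss entry's "weight" is
        # popped here; the weights are applied later, when the individual losses
        # are computed in the criterion/helper methods.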
        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
        perceptual_loss_cfg.pop("weight")
        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
        style_loss_cfg = OmegaConf.to_container(config.loss.style)
        style_loss_cfg.pop("weight")
        self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
        context_loss_cfg = OmegaConf.to_container(config.loss.context)
        context_loss_cfg.pop("weight")
        self.context_loss = ContextLoss(**context_loss_cfg).to(idist.device())
        self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
        self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
            idist.device())

    def build_models(self) -> (dict, dict):
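        """Instantiate the anime/face generators and their discriminators and apply weight init."""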
        generators = dict(
            anime=build_model(self.config.model.anime_generator),
            face=build_model(self.config.model.face_generator)
        )
        discriminators = dict(
            anime=build_model(self.config.model.discriminator),
            face=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["face"])
        self.logger.debug(generators["face"])
        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)
        return generators, discriminators

    def setup_after_g(self):
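        """Re-enable discriminator gradients once the generator step is done."""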
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
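        """Freeze the discriminators while the generators are being optimized."""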
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
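        """Generate the target-pose anime image from the target face and the
        (flipped) reference anime image, then generate the face image from the
        channel-averaged anime output and the source face.
        """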
        with torch.set_grad_enabled(not inference):
            target_pose_anime = self.generators["anime"](
                torch.cat([batch["face_1"], torch.flip(batch["anime_img"], dims=[3])], dim=1))
            target_pose_face = self.generators["face"](target_pose_anime.mean(dim=1, keepdim=True), batch["face_0"])
        return dict(fake_anime=target_pose_anime, fake_face=target_pose_face)

    def cal_gan_and_fm_loss(self, discriminator, generated_img, match_img=None):
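        """Compute the generator-side GAN loss and, when ``match_img`` is given,
        the multi-scale feature-matching loss against it.
        """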
        pred_fake = discriminator(generated_img)
        loss_gan = 0
        for sub_pred_fake in pred_fake:
            # the last output is the actual prediction
            loss_gan += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)
        if match_img is None:
            # skip the feature-matching loss
            return loss_gan, 0
        pred_real = discriminator(match_img)
        loss_fm = 0
        num_scale_discriminator = len(pred_fake)
        for i in range(num_scale_discriminator):
            # the last output is the final prediction, so we exclude it
            num_intermediate_outputs = len(pred_fake[i]) - 1
            for j in range(num_intermediate_outputs):
                loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
        loss_fm = self.config.loss.fm.weight * loss_fm
        return loss_gan, loss_fm

    def criterion_generators(self, batch, generated) -> dict:
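        """Collect all generator losses: style/recon/GAN/FM for the face branch,
        GAN/FM/edge (plus optional perceptual and context) for the anime branch.
        """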
        loss = dict()
        loss["face_style"] = self.config.loss.style.weight * self.style_loss(
            generated["fake_face"], batch["face_1"]
        )
        loss["face_recon"] = self.config.loss.recon.weight * self.recon_loss(
            generated["fake_face"], batch["face_1"]
        )
        loss["face_gan"], loss["face_fm"] = self.cal_gan_and_fm_loss(
            self.discriminators["face"], generated["fake_face"], batch["face_1"])
        loss["anime_gan"], loss["anime_fm"] = self.cal_gan_and_fm_loss(
            self.discriminators["anime"], generated["fake_anime"], batch["anime_img"])
        loss["anime_edge"] = self.config.loss.edge.weight * self.edge_loss(
            generated["fake_anime"], batch["face_1"], gt_is_edge=False,
        )
        if self.config.loss.perceptual.weight > 0:
            loss["anime_perceptual"] = self.config.loss.perceptual.weight * self.perceptual_loss(
                generated["fake_anime"], batch["anime_img"]
            )
        if self.config.loss.context.weight > 0:
            loss["anime_context"] = self.config.loss.context.weight * self.context_loss(
                generated["fake_anime"], batch["anime_img"],
            )
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
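        """Compute the discriminator loss for each branch as the mean of the
        real and fake GAN terms over every discriminator scale.
        """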
        loss = dict()
        real = {"anime": "anime_img", "face": "face_1"}
        for phase in self.discriminators.keys():
            pred_real = self.discriminators[phase](batch[real[phase]])
            pred_fake = self.discriminators[phase](generated[f"fake_{phase}"].detach())
            loss[f"gan_{phase}"] = 0
            for i in range(len(pred_fake)):
                loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
                                         + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        Collect intermediate images for display.

        The returned dict must have the form {"a": [img1, img2, ...], "b": [img3, img4, ...]}.

        :param batch: input batch of real images
        :param generated: dict of generated images
        :return: dict of image lists in the form described above
        """
        images = [batch["face_0"], batch["face_1"], batch["anime_img"], generated["fake_anime"].detach(),
                  generated["fake_face"].detach()]
        return dict(
            b=list(images)
        )


def run(task, config, _):
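    """Build the talking-anime engine kernel and hand it to the shared kernel runner."""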
    kernel = TAEngineKernel(config)
    run_kernel(task, config, kernel)