# raycv/engine/TAFG.py
from itertools import chain
from typing import Tuple

import torch
import torch.nn as nn
import ignite.distributed as idist
from ignite.engine import Events
from omegaconf import OmegaConf, read_write

from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from model.weight_init import generation_init_weights


class TAFGEngineKernel(EngineKernel):
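    """Engine kernel for TAFG edge-guided image-to-image translation.

    One generator with two decoders ("a" and "b") is trained against two
    per-domain discriminators, using perceptual, GAN, feature-matching, and
    reconstruction losses.
    """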
def __init__(self, config):
super().__init__(config)
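        # "weight" is stripped from each loss config before constructing the
        # loss module; the weights are applied manually in the criterion
        # methods below.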
perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
perceptual_loss_cfg.pop("weight")
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()

    def _process_batch(self, batch, inference=False):
# batch["b"] = batch["b"] if inference else batch["b"][0].expand(batch["a"].size())
return batch

    def build_models(self) -> Tuple[dict, dict]:
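        # A single shared generator serves both domains; each domain gets its
        # own discriminator.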
generators = dict(
main=build_model(self.config.model.generator)
)
discriminators = dict(
a=build_model(self.config.model.discriminator),
b=build_model(self.config.model.discriminator)
)
self.logger.debug(discriminators["a"])
self.logger.debug(generators["main"])
for m in chain(generators.values(), discriminators.values()):
generation_init_weights(m)
return generators, discriminators

    def setup_after_g(self):
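        # Unfreeze the discriminators for their own update step.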
for discriminator in self.discriminators.values():
discriminator.requires_grad_(True)

    def setup_before_g(self):
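        # Freeze the discriminators while the generator is optimized.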
for discriminator in self.discriminators.values():
discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
generator = self.generators["main"]
batch = self._process_batch(batch, inference)
with torch.set_grad_enabled(not inference):
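            # Decode the same edge map twice: decoder "a" reconstructs the
            # source domain, decoder "b" translates into the target domain.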
fake = dict(
a=generator(content_img=batch["edge_a"], style_img=batch["a"], which_decoder="a"),
b=generator(content_img=batch["edge_a"], style_img=batch["b"], which_decoder="b"),
)
return fake

    def criterion_generators(self, batch, generated) -> dict:
batch = self._process_batch(batch)
loss = dict()
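        # Perceptual loss compares the translated image against the source
        # image, encouraging content preservation across domains.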
loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
for phase in "ab":
pred_fake = self.discriminators[phase](generated[phase])
loss[f"gan_{phase}"] = 0
for sub_pred_fake in pred_fake:
# last output is actual prediction
loss[f"gan_{phase}"] += self.gan_loss(sub_pred_fake[-1], True)
if self.config.loss.fm.weight > 0 and phase == "b":
pred_real = self.discriminators[phase](batch[phase])
loss_fm = 0
num_scale_discriminator = len(pred_fake)
for i in range(num_scale_discriminator):
# last output is the final prediction, so we exclude it
num_intermediate_outputs = len(pred_fake[i]) - 1
for j in range(num_intermediate_outputs):
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"])
# loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
# self.generators["main"].module.style_encoders["b"](batch["b"]),
# self.generators["main"].module.style_encoders["b"](generated["b"])
# )
return loss

    def criterion_discriminators(self, batch, generated) -> dict:
loss = dict()
# batch = self._process_batch(batch)
for phase in self.discriminators.keys():
pred_real = self.discriminators[phase](batch[phase])
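            # Detach generated images so discriminator gradients do not flow
            # back into the generator.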
pred_fake = self.discriminators[phase](generated[phase].detach())
loss[f"gan_{phase}"] = 0
for i in range(len(pred_fake)):
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
return loss

    def intermediate_images(self, batch, generated) -> dict:
        """Collect images to log during training.

        :param batch: input batch of real images and edge maps
        :param generated: dict of generated images
        :return: dict like {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
batch = self._process_batch(batch)
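        # Take a single channel of the edge map; it is tiled back to 3
        # channels below purely for visualization.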
edge = batch["edge_a"][:, 0:1, :, :]
return dict(
a=[edge.expand(-1, 3, -1, -1).detach(), batch["a"].detach(), batch["b"].detach(), generated["a"].detach(),
generated["b"].detach()]
)

    def change_engine(self, config, trainer):
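        # One-shot handler: once training reaches a third of max_iteration,
        # set the perceptual-loss weight to 5.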
@trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
def change_config(engine):
with read_write(config):
config.loss.perceptual.weight = 5


def run(task, config, _):
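    # Build the kernel and hand it to the shared image-to-image training loop.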
kernel = TAFGEngineKernel(config)
run_kernel(task, config, kernel)