from itertools import chain
from typing import Tuple

import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf

from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.gan import GANLoss
from model.GAN.base import GANImageBuffer
from model.weight_init import generation_init_weights


class TAFGEngineKernel(EngineKernel):
    def __init__(self, config):
        super().__init__(config)

        # The GAN loss weight is applied in the criterion methods below, so it is
        # not passed on to the GANLoss module itself.
        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
        self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
        self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
        # One history buffer of previously generated images per discriminator.
        self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50)
                              for k in self.discriminators.keys()}

    def build_models(self) -> Tuple[dict, dict]:
        generators = dict(
            a2b=build_model(self.config.model.generator),
            b2a=build_model(self.config.model.generator)
        )
        discriminators = dict(
            a=build_model(self.config.model.discriminator),
            b=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["a"])
        self.logger.debug(generators["a2b"])

        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)

        return generators, discriminators

    def setup_after_g(self):
        # Unfreeze the discriminators again for the discriminator update step.
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        # Freeze the discriminators while the generators are being optimized.
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        images = dict()
        with torch.set_grad_enabled(not inference):
            # Translations and cycle reconstructions for both directions.
            images["a2b"] = self.generators["a2b"](batch["a"])
            images["b2a"] = self.generators["b2a"](batch["b"])
            images["a2b2a"] = self.generators["b2a"](images["a2b"])
            images["b2a2b"] = self.generators["a2b"](images["b2a"])
            # Identity mappings are only computed when the identity loss is enabled.
            if self.config.loss.id.weight > 0:
                images["a2a"] = self.generators["b2a"](batch["a"])
                images["b2b"] = self.generators["a2b"](batch["b"])
        return images

    def criterion_generators(self, batch, generated) -> dict:
        loss = dict()
        # For phase "a2b": the cycle loss compares a2b2a with a, the GAN loss is taken
        # from discriminator "b" on the translated images, and the identity loss
        # compares a2a with a (symmetrically for phase "b2a").
        for phase in ["a2b", "b2a"]:
            loss[f"cycle_{phase[0]}"] = self.config.loss.cycle.weight * self.cycle_loss(
                generated[f"{phase}2{phase[0]}"], batch[phase[0]])
            loss[f"gan_{phase}"] = self.config.loss.gan.weight * self.gan_loss(
                self.discriminators[phase[-1]](generated[phase]), True)
            if self.config.loss.id.weight > 0:
                loss[f"id_{phase[0]}"] = self.config.loss.id.weight * self.id_loss(
                    generated[f"{phase[0]}2{phase[0]}"], batch[phase[0]])
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        loss = dict()
        for phase in "ab":
            # Each discriminator is trained on buffered (possibly historical) fakes
            # against real samples from its own domain.
            generated_image = self.image_buffers[phase].query(
                generated["b2a" if phase == "a" else "a2b"].detach())
            loss[f"gan_{phase}"] = (self.gan_loss(self.discriminators[phase](generated_image), False,
                                                  is_discriminator=True) +
                                    self.gan_loss(self.discriminators[phase](batch[phase]), True,
                                                  is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        :param batch: dict of input images
        :param generated: dict of generated images
        :return: dict like {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        return dict(
            a=[batch["a"].detach(), generated["a2b"].detach(), generated["a2b2a"].detach()],
            b=[batch["b"].detach(), generated["b2a"].detach(), generated["b2a2b"].detach()],
        )


def run(task, config, _):
    kernel = TAFGEngineKernel(config)
    run_kernel(task, config, kernel)
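
# Note: a rough sketch of the config fields this kernel reads (names taken from the
# attribute accesses above; values are illustrative, not the project's defaults):
#
#   model:
#     generator: ...        # passed to build_model, once per direction (a2b, b2a)
#     discriminator: ...    # passed to build_model, once per domain (a, b)
#   loss:
#     gan:   {weight: ..., ...}        # "weight" is popped; the rest goes to GANLoss
#     cycle: {weight: ..., level: 1}   # level 1 -> L1 loss, otherwise MSE
#     id:    {weight: ..., level: 1}   # weight 0 disables the identity terms
#   data:
#     train: {buffer_size: 50}         # GANImageBuffer size, falls back to 50 when falsy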