raycv/engine/CycleGAN.py

from itertools import chain

import torch

from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from engine.util.container import GANImageBuffer, LossContainer
from engine.util.loss import pixel_loss, gan_loss
from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
from model.weight_init import generation_init_weights


class CycleGANEngineKernel(EngineKernel):
    """CycleGAN training kernel: two generators (a2b, b2a) and one discriminator per domain (a, b)."""

    def __init__(self, config):
        super().__init__(config)
        # Loss terms: adversarial, cycle-consistency, identity and MGC (geometry) losses.
        self.gan_loss = gan_loss(config.loss.gan)
        self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
        self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
        self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss())
        # One history buffer of generated images per discriminator (CycleGAN image pool).
        self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50)
                              for k in self.discriminators.keys()}

    def build_models(self) -> (dict, dict):
        # Two generators (a -> b and b -> a) and one discriminator per domain.
        generators = dict(
            a2b=build_model(self.config.model.generator),
            b2a=build_model(self.config.model.generator),
        )
        discriminators = dict(
            a=build_model(self.config.model.discriminator),
            b=build_model(self.config.model.discriminator),
        )
        self.logger.debug(discriminators["a"])
        self.logger.debug(generators["a2b"])
        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)
        return generators, discriminators

    def setup_after_g(self):
        # Re-enable discriminator gradients for the discriminator update step.
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        # Freeze discriminators while optimizing the generators.
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        images = dict()
        with torch.set_grad_enabled(not inference):
            # Translations and cycle reconstructions.
            images["a2b"] = self.generators["a2b"](batch["a"])
            images["b2a"] = self.generators["b2a"](batch["b"])
            images["a2b2a"] = self.generators["b2a"](images["a2b"])
            images["b2a2b"] = self.generators["a2b"](images["b2a"])
            # Identity mappings are only needed when the identity loss is active.
            if self.id_loss.weight > 0:
                images["a2a"] = self.generators["b2a"](batch["a"])
                images["b2b"] = self.generators["a2b"](batch["b"])
        return images

    def criterion_generators(self, batch, generated) -> dict:
        loss = dict()
        for ph in "ab":
            loss[f"cycle_{ph}"] = self.cycle_loss(generated["a2b2a" if ph == "a" else "b2a2b"], batch[ph])
            # Identity images only exist in `generated` when the identity loss is enabled (see forward).
            if self.id_loss.weight > 0:
                loss[f"id_{ph}"] = self.id_loss(generated[f"{ph}2{ph}"], batch[ph])
            loss[f"mgc_{ph}"] = self.mgc_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph])
            # Adversarial loss: each generator tries to fool the discriminator of its target domain.
            loss[f"gan_{ph}"] = self.config.loss.gan.weight * self.gan_loss(
                self.discriminators[ph](generated["a2b" if ph == "b" else "b2a"]), True)
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        loss = dict()
        for phase in "ab":
            # Sample a (possibly historical) fake image from the buffer to stabilize training.
            generated_image = self.image_buffers[phase].query(generated["b2a" if phase == "a" else "a2b"].detach())
            # Average of the fake and real terms, as in the original CycleGAN discriminator objective.
            loss[f"gan_{phase}"] = (self.gan_loss(self.discriminators[phase](generated_image), False,
                                                  is_discriminator=True) +
                                    self.gan_loss(self.discriminators[phase](batch[phase]), True,
                                                  is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        Collect the input, translated and cycle-reconstructed images for each domain.

        :param batch: input batch with keys "a" and "b"
        :param generated: dict of generated images from ``forward``
        :return: dict like {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        return dict(
            a=[batch["a"].detach(), generated["a2b"].detach(), generated["a2b2a"].detach()],
            b=[batch["b"].detach(), generated["b2a"].detach(), generated["b2a2b"].detach()],
        )


def run(task, config, _):
    kernel = CycleGANEngineKernel(config)
    run_kernel(task, config, kernel)
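

# ---------------------------------------------------------------------------
# Illustrative note (a sketch, not part of the original file): the config
# fields this kernel reads, inferred from the attribute accesses above.
# Concrete values depend on the project's config system and model registry,
# so this only describes the expected shape, not actual contents.
#
#   config.model.generator           spec handed to build_model() for both generators
#   config.model.discriminator       spec handed to build_model() for both discriminators
#   config.loss.gan                  spec handed to gan_loss(); also read as .weight
#   config.loss.cycle.weight/.level  cycle-consistency loss weight and pixel-loss level
#   config.loss.id.weight/.level     identity loss weight and pixel-loss level
#   config.loss.mgc.weight           minimal-geometry-distortion-constraint loss weight
#   config.data.train.buffer_size    GANImageBuffer size (falls back to 50 when falsy)
# ---------------------------------------------------------------------------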