from itertools import chain

import torch

from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from engine.util.container import GANImageBuffer, LossContainer
from engine.util.loss import gan_loss, feature_match_loss, perceptual_loss
from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
from model.weight_init import generation_init_weights


class GauGANEngineKernel(EngineKernel):
    """Engine kernel with one generator ("main") and one discriminator ("b").

    The generator maps ``batch["a"]`` to ``"a2b"``. Generator training combines
    a GAN term, an MGC term, a perceptual term and (when its weight is
    positive) a feature-matching term. The discriminator is trained on images
    drawn from a ``GANImageBuffer`` and on ``batch["b"]``.
    """

    def __init__(self, config):
        """Create the loss objects and one image buffer per discriminator.

        :param config: project configuration object; this method reads
            ``config.loss.{gan,mgc,fm,perceptual}`` and
            ``config.data.train.buffer_size``.
        """
        super().__init__(config)
        loss_cfg = config.loss

        self.gan_loss = gan_loss(loss_cfg.gan)
        self.mgc_loss = LossContainer(loss_cfg.mgc.weight, MGCLoss("opposite"))
        self.fm_loss = LossContainer(loss_cfg.fm.weight, feature_match_loss(1, "same"))
        self.perceptual_loss = LossContainer(
            loss_cfg.perceptual.weight,
            perceptual_loss(loss_cfg.perceptual),
        )

        # A falsy buffer_size (None, 0, ...) falls back to 50.
        self.image_buffers = {
            name: GANImageBuffer(config.data.train.buffer_size or 50)
            for name in self.discriminators
        }

    def build_models(self) -> (dict, dict):
        """Build, log and weight-initialise the generator and discriminator.

        :return: ``(generators, discriminators)``; the first dict has the key
            ``"main"``, the second has the key ``"b"``.
        """
        generators = {"main": build_model(self.config.model.generator)}
        discriminators = {"b": build_model(self.config.model.discriminator)}

        self.logger.debug(discriminators["b"])
        self.logger.debug(generators["main"])

        # Generators are initialised first, then discriminators.
        model_groups = (generators, discriminators)
        for model in chain.from_iterable(group.values() for group in model_groups):
            generation_init_weights(model)

        return generators, discriminators

    def setup_after_g(self):
        """Call ``requires_grad_(True)`` on every discriminator.

        NOTE(review): presumably invoked by the base kernel once the generator
        step is done -- confirm against ``EngineKernel``.
        """
        for net in self.discriminators.values():
            net.requires_grad_(True)

    def setup_before_g(self):
        """Call ``requires_grad_(False)`` on every discriminator.

        NOTE(review): presumably invoked by the base kernel right before the
        generator step -- confirm against ``EngineKernel``.
        """
        for net in self.discriminators.values():
            net.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run the "main" generator on ``batch["a"]``.

        :param batch: mapping that provides the key ``"a"``.
        :param inference: when true, autograd is disabled for the forward pass.
        :return: dict with the single key ``"a2b"``.
        """
        track_gradients = not inference
        with torch.set_grad_enabled(track_gradients):
            translated = self.generators["main"](batch["a"])
        return {"a2b": translated}

    def criterion_generators(self, batch, generated) -> dict:
        """Compute the generator loss terms.

        :param batch: mapping with ``"a"`` (and ``"b"`` when feature matching
            is active).
        :param generated: mapping with the key ``"a2b"``, as produced by
            ``forward``.
        :return: dict with the keys ``"gan"``, ``"mgc"``, ``"perceptual"`` and,
            only when ``self.fm_loss.weight > 0``, ``"feature_match"``.
        """
        fake_b = generated["a2b"]
        source_a = batch["a"]
        discriminator = self.discriminators["b"]

        prediction_fake = discriminator(fake_b)

        losses = {
            "gan": self.config.loss.gan.weight * self.gan_loss(prediction_fake, True),
            "mgc": self.mgc_loss(fake_b, source_a),
            "perceptual": self.perceptual_loss(fake_b, source_a),
        }

        # The extra discriminator pass on real data is skipped when the
        # feature-matching weight is not positive.
        if self.fm_loss.weight > 0:
            prediction_real = discriminator(batch["b"])
            losses["feature_match"] = self.fm_loss(prediction_fake, prediction_real)

        return losses

    def criterion_discriminators(self, batch, generated) -> dict:
        """Compute the loss for discriminator "b".

        The fake input is ``generated["a2b"]`` detached from the graph and
        passed through the image buffer's ``query``; the real input is
        ``batch["b"]``. The result is the mean of the two GAN loss terms.

        :param batch: mapping with the key ``"b"``.
        :param generated: mapping with the key ``"a2b"``.
        :return: dict with the single key ``"b"``.
        """
        discriminator = self.discriminators["b"]

        buffered_fake = self.image_buffers["b"].query(generated["a2b"].detach())

        fake_term = self.gan_loss(discriminator(buffered_fake), False, is_discriminator=True)
        real_term = self.gan_loss(discriminator(batch["b"]), True, is_discriminator=True)

        return {"b": (fake_term + real_term) / 2}

    def intermediate_images(self, batch, generated) -> dict:
        """Collect detached images for visualisation.

        The returned dict must have the form
        ``{"a": [img1, img2, ...], "b": [img3, img4, ...]}``; this kernel only
        fills the ``"a"`` entry with the input and its translation.

        :param batch: mapping with the key ``"a"``.
        :param generated: dict of images with the key ``"a2b"``.
        :return: ``{"a": [input image, generated image]}``
        """
        input_image = batch["a"].detach()
        output_image = generated["a2b"].detach()
        return {"a": [input_image, output_image]}


def run(task, config, _):
    """Create a ``GauGANEngineKernel`` from ``config`` and pass it to ``run_kernel``.

    :param task: forwarded unchanged to ``run_kernel``.
    :param config: project configuration object used to build the kernel.
    :param _: unused third positional argument.
    """
    run_kernel(task, config, GauGANEngineKernel(config))