import torch

from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from engine.util.container import LossContainer
from engine.util.loss import bce_loss, mse_loss, pixel_loss, gan_loss
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
from util.image import attention_colored_map


class RhoClipper:
    """Callable applied via ``nn.Module.apply`` that clamps any ``rho``
    parameter (the ILN/AdaILN mixing coefficient in U-GAT-IT) into
    ``[clip_min, clip_max]`` in-place."""

    def __init__(self, clip_min, clip_max):
        assert clip_min < clip_max
        self.clip_min = clip_min
        self.clip_max = clip_max

    def __call__(self, module):
        # Only touch submodules that actually expose a ``rho`` tensor.
        if hasattr(module, 'rho'):
            module.rho.data = module.rho.data.clamp(self.clip_min, self.clip_max)


class UGATITEngineKernel(EngineKernel):
    """U-GAT-IT training kernel.

    Two generators (``a2b``, ``b2a``) and four discriminators (local and
    global per domain), trained with adversarial, cycle, identity, CAM and
    MGC (minimal geometry distortion constraint) losses.
    """

    def __init__(self, config):
        super().__init__(config)
        self.gan_loss = gan_loss(config.loss.gan)
        self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
        self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
        self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
        # Normalized to ``config`` (was a mix of ``self.config`` and ``config``).
        self.bce_loss = LossContainer(config.loss.cam.weight, bce_loss)
        self.mse_loss = LossContainer(config.loss.gan.weight, mse_loss)
        # Keep rho within [0, 1] after every generator update (see setup_after_g).
        self.rho_clipper = RhoClipper(0, 1)
        self.train_generator_first = False

    def build_models(self) -> (dict, dict):
        """Build the two generators and the four (local/global x a/b) discriminators."""
        generators = dict(
            a2b=build_model(self.config.model.generator),
            b2a=build_model(self.config.model.generator),
        )
        discriminators = dict(
            la=build_model(self.config.model.local_discriminator),
            lb=build_model(self.config.model.local_discriminator),
            ga=build_model(self.config.model.global_discriminator),
            gb=build_model(self.config.model.global_discriminator),
        )
        self.logger.debug(discriminators["ga"])
        self.logger.debug(generators["a2b"])
        return generators, discriminators

    def setup_after_g(self):
        """After the generator step: clip rho and unfreeze the discriminators."""
        for generator in self.generators.values():
            generator.apply(self.rho_clipper)
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        """Before the generator step: freeze the discriminators."""
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run both generators for every translation direction.

        :param batch: dict with domain images under keys ``"a"`` and ``"b"``
        :param inference: when True, run without building the autograd graph
        :return: dict of ``images`` / ``heatmap`` (CAM attention) / ``cam_pred``
                 (CAM logits), each keyed by direction (``a2b``, ``a2b2a``, ...)
        """
        images = dict()
        heatmap = dict()
        cam_pred = dict()
        with torch.set_grad_enabled(not inference):
            # Cross-domain translations.
            images["a2b"], cam_pred["a2b"], heatmap["a2b"] = self.generators["a2b"](batch["a"])
            images["b2a"], cam_pred["b2a"], heatmap["b2a"] = self.generators["b2a"](batch["b"])
            # Cycle reconstructions (translate back to the source domain).
            images["a2b2a"], _, heatmap["a2b2a"] = self.generators["b2a"](images["a2b"])
            images["b2a2b"], _, heatmap["b2a2b"] = self.generators["a2b"](images["b2a"])
            # Identity mappings (feed an image to the generator targeting its own domain).
            images["a2a"], cam_pred["a2a"], heatmap["a2a"] = self.generators["b2a"](batch["a"])
            images["b2b"], cam_pred["b2b"], heatmap["b2b"] = self.generators["a2b"](batch["b"])
        return dict(images=images, heatmap=heatmap, cam_pred=cam_pred)

    def criterion_generators(self, batch, generated) -> dict:
        """Compute all generator-side losses (cycle, identity, MGC, GAN, CAM)."""
        loss = dict()
        for phase in "ab":
            cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"]
            loss[f"cycle_{phase}"] = self.cycle_loss(cycle_image, batch[phase])
            loss[f"id_{phase}"] = self.id_loss(batch[phase], generated["images"][f"{phase}2{phase}"])
            loss[f"mgc_{phase}"] = self.mgc_loss(batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
            for dk in "lg":
                # Discriminator ``{dk}{phase}`` judges domain-``phase`` images,
                # so the fake it sees is the translation INTO that domain.
                generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
                pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)
                loss[f"gan_{phase}_{dk}"] = self.config.loss.gan.weight * self.gan_loss(pred_fake, True)
                loss[f"gan_cam_{phase}_{dk}"] = self.mse_loss(cam_pred, True)
        for t, f in [("a2b", "b2b"), ("b2a", "a2a")]:
            # CAM classifier: translated image is "positive", identity image "negative".
            loss[f"cam_{t[-1]}"] = self.bce_loss(generated["cam_pred"][t], True) + \
                                   self.bce_loss(generated["cam_pred"][f], False)
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Compute the discriminator-side GAN and CAM losses (fakes detached)."""
        loss = dict()
        for phase in "ab":
            for level in "gl":
                generated_image = generated["images"]["b2a" if phase == "a" else "a2b"].detach()
                pred_fake, cam_fake_pred = self.discriminators[level + phase](generated_image)
                pred_real, cam_real_pred = self.discriminators[level + phase](batch[phase])
                loss[f"gan_{phase}_{level}"] = self.gan_loss(pred_real, True, is_discriminator=True) + self.gan_loss(
                    pred_fake, False, is_discriminator=True)
                # NOTE(review): raw ``mse_loss`` here (unweighted), unlike the
                # ``self.mse_loss`` container used on the generator side — confirm intended.
                loss[f"cam_{phase}_{level}"] = mse_loss(cam_fake_pred, False) + mse_loss(cam_real_pred, True)
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        attention_a = attention_colored_map(generated["heatmap"]["a2b"].detach(), batch["a"].size()[-2:])
        attention_b = attention_colored_map(generated["heatmap"]["b2a"].detach(), batch["b"].size()[-2:])
        generated = {img: generated["images"][img].detach() for img in generated["images"]}
        return {
            "a": [batch["a"], attention_a, generated["a2b"], generated["a2a"], generated["a2b2a"]],
            "b": [batch["b"], attention_b, generated["b2a"], generated["b2b"], generated["b2a2b"]],
        }


class UGATITTestEngineKernel(TestEngineKernel):
    """Inference-only kernel: loads the ``a2b`` generator and translates batches."""

    def __init__(self, config):
        super().__init__(config)

    def build_generators(self) -> dict:
        generators = dict(
            a2b=build_model(self.config.model.generator),
        )
        return generators

    def to_load(self):
        # Map checkpoint keys ("generator_a2b", ...) to the generator modules.
        return {f"generator_{k}": self.generators[k] for k in self.generators}

    def inference(self, batch):
        with torch.no_grad():
            fake, _, _ = self.generators["a2b"](batch[0])
        return fake.detach()


def run(task, config, _):
    """Entry point: dispatch to the train or test kernel.

    :raises NotImplementedError: for any task other than "train" / "test"
    """
    if task == "train":
        kernel = UGATITEngineKernel(config)
    elif task == "test":
        kernel = UGATITTestEngineKernel(config)
    else:
        # Fixed: ``raise NotImplemented`` raised a TypeError because
        # NotImplemented is a singleton, not an exception class.
        raise NotImplementedError(task)
    run_kernel(task, config, kernel)