diff --git a/engine/U-GAT-IT.py b/engine/U-GAT-IT.py index 85cd672..3044c2c 100644 --- a/engine/U-GAT-IT.py +++ b/engine/U-GAT-IT.py @@ -13,6 +13,10 @@ from model.image_translation.UGATIT import RhoClipper from util.image import attention_colored_map +def pixel_loss(level): + return nn.L1Loss() if level == 1 else nn.MSELoss() + + def mse_loss(x, target_flag): return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x)) @@ -28,10 +32,13 @@ class UGATITEngineKernel(EngineKernel): gan_loss_cfg = OmegaConf.to_container(config.loss.gan) gan_loss_cfg.pop("weight") self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) - self.cycle_loss = LossContainer(config.loss.cycle.weight, - nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()) + + self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level)) self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss()) - self.id_loss = LossContainer(config.loss.id.weight, nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()) + self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level)) + self.bce_loss = LossContainer(self.config.loss.cam.weight, bce_loss) + self.mse_loss = LossContainer(self.config.loss.gan.weight, mse_loss) + self.rho_clipper = RhoClipper(0, 1) self.train_generator_first = False @@ -86,10 +93,10 @@ class UGATITEngineKernel(EngineKernel): generated_image = generated["images"]["a2b" if phase == "b" else "b2a"] pred_fake, cam_pred = self.discriminators[dk + phase](generated_image) loss[f"gan_{phase}_{dk}"] = self.config.loss.gan.weight * self.gan_loss(pred_fake, True) - loss[f"gan_cam_{phase}_{dk}"] = self.config.loss.gan.weight * mse_loss(cam_pred, True) + loss[f"gan_cam_{phase}_{dk}"] = self.mse_loss(cam_pred, True) for t, f in [("a2b", "b2b"), ("b2a", "a2a")]: - loss[f"cam_{t[-1]}"] = self.config.loss.cam.weight * ( - bce_loss(generated["cam_pred"][t], True) + bce_loss(generated["cam_pred"][f], False)) + loss[f"cam_{t[-1]}"] = self.bce_loss(generated["cam_pred"][t], True) + \ + self.bce_loss(generated["cam_pred"][f], False) return loss def criterion_discriminators(self, batch, generated) -> dict: