diff --git a/engine/U-GAT-IT.py b/engine/U-GAT-IT.py index 860905d..85cd672 100644 --- a/engine/U-GAT-IT.py +++ b/engine/U-GAT-IT.py @@ -6,6 +6,7 @@ from omegaconf import OmegaConf from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel from engine.util.build import build_model +from engine.util.container import LossContainer from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss from loss.gan import GANLoss from model.image_translation.UGATIT import RhoClipper @@ -27,9 +28,10 @@ class UGATITEngineKernel(EngineKernel): gan_loss_cfg = OmegaConf.to_container(config.loss.gan) gan_loss_cfg.pop("weight") self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) - self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss() - self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss() - self.mgc_loss = MyLoss() + self.cycle_loss = LossContainer(config.loss.cycle.weight, + nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()) + self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss()) + self.id_loss = LossContainer(config.loss.id.weight, nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()) self.rho_clipper = RhoClipper(0, 1) self.train_generator_first = False @@ -77,12 +79,9 @@ class UGATITEngineKernel(EngineKernel): loss = dict() for phase in "ab": cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"] - loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase]) - loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(batch[phase], - generated["images"][f"{phase}2{phase}"]) - if self.config.loss.mgc.weight > 0: - loss[f"mgc_{phase}"] = self.config.loss.mgc.weight * self.mgc_loss( - batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"]) + loss[f"cycle_{phase}"] = self.cycle_loss(cycle_image, batch[phase]) + loss[f"id_{phase}"] = self.id_loss(batch[phase], generated["images"][f"{phase}2{phase}"]) + loss[f"mgc_{phase}"] = self.mgc_loss(batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"]) for dk in "lg": generated_image = generated["images"]["a2b" if phase == "b" else "b2a"] pred_fake, cam_pred = self.discriminators[dk + phase](generated_image) diff --git a/engine/base/i2i.py b/engine/base/i2i.py index 74822e6..d9af31f 100644 --- a/engine/base/i2i.py +++ b/engine/base/i2i.py @@ -100,6 +100,13 @@ class EngineKernel(object): pass +def _remove_no_grad_loss(loss_dict): + for k in loss_dict: + if not isinstance(loss_dict[k], torch.Tensor): + loss_dict.pop(k) + return loss_dict + + def get_trainer(config, kernel: EngineKernel): logger = logging.getLogger(config.name) generators, discriminators = kernel.generators, kernel.discriminators @@ -147,10 +154,10 @@ def get_trainer(config, kernel: EngineKernel): if engine.state.iteration % iteration_per_image == 0: return { - "loss": dict(g=loss_g, d=loss_d), + "loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d)), "img": kernel.intermediate_images(batch, generated) } - return {"loss": dict(g=loss_g, d=loss_d)} + return {"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d))} trainer = Engine(_step) trainer.logger = logger diff --git a/engine/util/container.py b/engine/util/container.py new file mode 100644 index 0000000..c1690fd --- /dev/null +++ b/engine/util/container.py @@ -0,0 +1,9 @@ +class LossContainer: + def __init__(self, weight, loss): + self.weight = weight + self.loss = loss + + def __call__(self, *args, **kwargs): + if self.weight > 0: + return self.weight * self.loss(*args, **kwargs) + return 0.0