add loss container

This commit is contained in:
Ray Wong 2020-10-11 23:09:04 +08:00
parent 6070f08835
commit 436bca88b4
3 changed files with 26 additions and 11 deletions

View File

@ -6,6 +6,7 @@ from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model from engine.util.build import build_model
from engine.util.container import LossContainer
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
from loss.gan import GANLoss from loss.gan import GANLoss
from model.image_translation.UGATIT import RhoClipper from model.image_translation.UGATIT import RhoClipper
@ -27,9 +28,10 @@ class UGATITEngineKernel(EngineKernel):
gan_loss_cfg = OmegaConf.to_container(config.loss.gan) gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight") gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss() self.cycle_loss = LossContainer(config.loss.cycle.weight,
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss() nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss())
self.mgc_loss = MyLoss() self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
self.id_loss = LossContainer(config.loss.id.weight, nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss())
self.rho_clipper = RhoClipper(0, 1) self.rho_clipper = RhoClipper(0, 1)
self.train_generator_first = False self.train_generator_first = False
@ -77,12 +79,9 @@ class UGATITEngineKernel(EngineKernel):
loss = dict() loss = dict()
for phase in "ab": for phase in "ab":
cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"] cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"]
loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase]) loss[f"cycle_{phase}"] = self.cycle_loss(cycle_image, batch[phase])
loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(batch[phase], loss[f"id_{phase}"] = self.id_loss(batch[phase], generated["images"][f"{phase}2{phase}"])
generated["images"][f"{phase}2{phase}"]) loss[f"mgc_{phase}"] = self.mgc_loss(batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
if self.config.loss.mgc.weight > 0:
loss[f"mgc_{phase}"] = self.config.loss.mgc.weight * self.mgc_loss(
batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
for dk in "lg": for dk in "lg":
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"] generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image) pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)

View File

@ -100,6 +100,13 @@ class EngineKernel(object):
pass pass
def _remove_no_grad_loss(loss_dict):
for k in loss_dict:
if not isinstance(loss_dict[k], torch.Tensor):
loss_dict.pop(k)
return loss_dict
def get_trainer(config, kernel: EngineKernel): def get_trainer(config, kernel: EngineKernel):
logger = logging.getLogger(config.name) logger = logging.getLogger(config.name)
generators, discriminators = kernel.generators, kernel.discriminators generators, discriminators = kernel.generators, kernel.discriminators
@ -147,10 +154,10 @@ def get_trainer(config, kernel: EngineKernel):
if engine.state.iteration % iteration_per_image == 0: if engine.state.iteration % iteration_per_image == 0:
return { return {
"loss": dict(g=loss_g, d=loss_d), "loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d)),
"img": kernel.intermediate_images(batch, generated) "img": kernel.intermediate_images(batch, generated)
} }
return {"loss": dict(g=loss_g, d=loss_d)} return {"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d))}
trainer = Engine(_step) trainer = Engine(_step)
trainer.logger = logger trainer.logger = logger

9
engine/util/container.py Normal file
View File

@ -0,0 +1,9 @@
class LossContainer:
def __init__(self, weight, loss):
self.weight = weight
self.loss = loss
def __call__(self, *args, **kwargs):
if self.weight > 0:
return self.weight * self.loss(*args, **kwargs)
return 0.0