add loss container

This commit is contained in:
Ray Wong 2020-10-11 23:09:04 +08:00
parent 6070f08835
commit 436bca88b4
3 changed files with 26 additions and 11 deletions

View File

@ -6,6 +6,7 @@ from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from engine.util.container import LossContainer
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
from loss.gan import GANLoss
from model.image_translation.UGATIT import RhoClipper
@ -27,9 +28,10 @@ class UGATITEngineKernel(EngineKernel):
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
self.mgc_loss = MyLoss()
self.cycle_loss = LossContainer(config.loss.cycle.weight,
nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss())
self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
self.id_loss = LossContainer(config.loss.id.weight, nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss())
self.rho_clipper = RhoClipper(0, 1)
self.train_generator_first = False
@ -77,12 +79,9 @@ class UGATITEngineKernel(EngineKernel):
loss = dict()
for phase in "ab":
cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"]
loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase])
loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(batch[phase],
generated["images"][f"{phase}2{phase}"])
if self.config.loss.mgc.weight > 0:
loss[f"mgc_{phase}"] = self.config.loss.mgc.weight * self.mgc_loss(
batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
loss[f"cycle_{phase}"] = self.cycle_loss(cycle_image, batch[phase])
loss[f"id_{phase}"] = self.id_loss(batch[phase], generated["images"][f"{phase}2{phase}"])
loss[f"mgc_{phase}"] = self.mgc_loss(batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
for dk in "lg":
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)

View File

@ -100,6 +100,13 @@ class EngineKernel(object):
pass
def _remove_no_grad_loss(loss_dict):
for k in loss_dict:
if not isinstance(loss_dict[k], torch.Tensor):
loss_dict.pop(k)
return loss_dict
def get_trainer(config, kernel: EngineKernel):
logger = logging.getLogger(config.name)
generators, discriminators = kernel.generators, kernel.discriminators
@ -147,10 +154,10 @@ def get_trainer(config, kernel: EngineKernel):
if engine.state.iteration % iteration_per_image == 0:
return {
"loss": dict(g=loss_g, d=loss_d),
"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d)),
"img": kernel.intermediate_images(batch, generated)
}
return {"loss": dict(g=loss_g, d=loss_d)}
return {"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d))}
trainer = Engine(_step)
trainer.logger = logger

9
engine/util/container.py Normal file
View File

@ -0,0 +1,9 @@
class LossContainer:
def __init__(self, weight, loss):
self.weight = weight
self.loss = loss
def __call__(self, *args, **kwargs):
if self.weight > 0:
return self.weight * self.loss(*args, **kwargs)
return 0.0