add loss container
This commit is contained in:
parent
6070f08835
commit
436bca88b4
@ -6,6 +6,7 @@ from omegaconf import OmegaConf
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
|
||||
from engine.util.build import build_model
|
||||
from engine.util.container import LossContainer
|
||||
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
|
||||
from loss.gan import GANLoss
|
||||
from model.image_translation.UGATIT import RhoClipper
|
||||
@ -27,9 +28,10 @@ class UGATITEngineKernel(EngineKernel):
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
|
||||
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
|
||||
self.mgc_loss = MyLoss()
|
||||
self.cycle_loss = LossContainer(config.loss.cycle.weight,
|
||||
nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss())
|
||||
self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
|
||||
self.id_loss = LossContainer(config.loss.id.weight, nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss())
|
||||
self.rho_clipper = RhoClipper(0, 1)
|
||||
self.train_generator_first = False
|
||||
|
||||
@ -77,12 +79,9 @@ class UGATITEngineKernel(EngineKernel):
|
||||
loss = dict()
|
||||
for phase in "ab":
|
||||
cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"]
|
||||
loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase])
|
||||
loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(batch[phase],
|
||||
generated["images"][f"{phase}2{phase}"])
|
||||
if self.config.loss.mgc.weight > 0:
|
||||
loss[f"mgc_{phase}"] = self.config.loss.mgc.weight * self.mgc_loss(
|
||||
batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
|
||||
loss[f"cycle_{phase}"] = self.cycle_loss(cycle_image, batch[phase])
|
||||
loss[f"id_{phase}"] = self.id_loss(batch[phase], generated["images"][f"{phase}2{phase}"])
|
||||
loss[f"mgc_{phase}"] = self.mgc_loss(batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
|
||||
for dk in "lg":
|
||||
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
|
||||
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)
|
||||
|
||||
@ -100,6 +100,13 @@ class EngineKernel(object):
|
||||
pass
|
||||
|
||||
|
||||
def _remove_no_grad_loss(loss_dict):
|
||||
for k in loss_dict:
|
||||
if not isinstance(loss_dict[k], torch.Tensor):
|
||||
loss_dict.pop(k)
|
||||
return loss_dict
|
||||
|
||||
|
||||
def get_trainer(config, kernel: EngineKernel):
|
||||
logger = logging.getLogger(config.name)
|
||||
generators, discriminators = kernel.generators, kernel.discriminators
|
||||
@ -147,10 +154,10 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
|
||||
if engine.state.iteration % iteration_per_image == 0:
|
||||
return {
|
||||
"loss": dict(g=loss_g, d=loss_d),
|
||||
"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d)),
|
||||
"img": kernel.intermediate_images(batch, generated)
|
||||
}
|
||||
return {"loss": dict(g=loss_g, d=loss_d)}
|
||||
return {"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d))}
|
||||
|
||||
trainer = Engine(_step)
|
||||
trainer.logger = logger
|
||||
|
||||
9
engine/util/container.py
Normal file
9
engine/util/container.py
Normal file
@ -0,0 +1,9 @@
|
||||
class LossContainer:
|
||||
def __init__(self, weight, loss):
|
||||
self.weight = weight
|
||||
self.loss = loss
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if self.weight > 0:
|
||||
return self.weight * self.loss(*args, **kwargs)
|
||||
return 0.0
|
||||
Loading…
Reference in New Issue
Block a user