add loss container
This commit is contained in:
parent
6070f08835
commit
436bca88b4
@ -6,6 +6,7 @@ from omegaconf import OmegaConf
|
|||||||
|
|
||||||
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
|
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
|
||||||
from engine.util.build import build_model
|
from engine.util.build import build_model
|
||||||
|
from engine.util.container import LossContainer
|
||||||
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
|
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
|
||||||
from loss.gan import GANLoss
|
from loss.gan import GANLoss
|
||||||
from model.image_translation.UGATIT import RhoClipper
|
from model.image_translation.UGATIT import RhoClipper
|
||||||
@ -27,9 +28,10 @@ class UGATITEngineKernel(EngineKernel):
|
|||||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||||
gan_loss_cfg.pop("weight")
|
gan_loss_cfg.pop("weight")
|
||||||
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||||
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
|
self.cycle_loss = LossContainer(config.loss.cycle.weight,
|
||||||
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
|
nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss())
|
||||||
self.mgc_loss = MyLoss()
|
self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
|
||||||
|
self.id_loss = LossContainer(config.loss.id.weight, nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss())
|
||||||
self.rho_clipper = RhoClipper(0, 1)
|
self.rho_clipper = RhoClipper(0, 1)
|
||||||
self.train_generator_first = False
|
self.train_generator_first = False
|
||||||
|
|
||||||
@ -77,12 +79,9 @@ class UGATITEngineKernel(EngineKernel):
|
|||||||
loss = dict()
|
loss = dict()
|
||||||
for phase in "ab":
|
for phase in "ab":
|
||||||
cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"]
|
cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"]
|
||||||
loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase])
|
loss[f"cycle_{phase}"] = self.cycle_loss(cycle_image, batch[phase])
|
||||||
loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(batch[phase],
|
loss[f"id_{phase}"] = self.id_loss(batch[phase], generated["images"][f"{phase}2{phase}"])
|
||||||
generated["images"][f"{phase}2{phase}"])
|
loss[f"mgc_{phase}"] = self.mgc_loss(batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
|
||||||
if self.config.loss.mgc.weight > 0:
|
|
||||||
loss[f"mgc_{phase}"] = self.config.loss.mgc.weight * self.mgc_loss(
|
|
||||||
batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
|
|
||||||
for dk in "lg":
|
for dk in "lg":
|
||||||
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
|
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
|
||||||
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)
|
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)
|
||||||
|
|||||||
@ -100,6 +100,13 @@ class EngineKernel(object):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_no_grad_loss(loss_dict):
|
||||||
|
for k in loss_dict:
|
||||||
|
if not isinstance(loss_dict[k], torch.Tensor):
|
||||||
|
loss_dict.pop(k)
|
||||||
|
return loss_dict
|
||||||
|
|
||||||
|
|
||||||
def get_trainer(config, kernel: EngineKernel):
|
def get_trainer(config, kernel: EngineKernel):
|
||||||
logger = logging.getLogger(config.name)
|
logger = logging.getLogger(config.name)
|
||||||
generators, discriminators = kernel.generators, kernel.discriminators
|
generators, discriminators = kernel.generators, kernel.discriminators
|
||||||
@ -147,10 +154,10 @@ def get_trainer(config, kernel: EngineKernel):
|
|||||||
|
|
||||||
if engine.state.iteration % iteration_per_image == 0:
|
if engine.state.iteration % iteration_per_image == 0:
|
||||||
return {
|
return {
|
||||||
"loss": dict(g=loss_g, d=loss_d),
|
"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d)),
|
||||||
"img": kernel.intermediate_images(batch, generated)
|
"img": kernel.intermediate_images(batch, generated)
|
||||||
}
|
}
|
||||||
return {"loss": dict(g=loss_g, d=loss_d)}
|
return {"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d))}
|
||||||
|
|
||||||
trainer = Engine(_step)
|
trainer = Engine(_step)
|
||||||
trainer.logger = logger
|
trainer.logger = logger
|
||||||
|
|||||||
9
engine/util/container.py
Normal file
9
engine/util/container.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
class LossContainer:
|
||||||
|
def __init__(self, weight, loss):
|
||||||
|
self.weight = weight
|
||||||
|
self.loss = loss
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
if self.weight > 0:
|
||||||
|
return self.weight * self.loss(*args, **kwargs)
|
||||||
|
return 0.0
|
||||||
Loading…
Reference in New Issue
Block a user