Compare commits

...

2 Commits

Author SHA1 Message Date
2de00d0245 use loss container 2020-10-11 23:36:37 +08:00
74a7cfb2d8 move sn to engine 2020-10-11 23:35:29 +08:00
3 changed files with 21 additions and 17 deletions

View File

@ -13,6 +13,10 @@ from model.image_translation.UGATIT import RhoClipper
from util.image import attention_colored_map from util.image import attention_colored_map
def pixel_loss(level):
return nn.L1Loss() if level == 1 else nn.MSELoss()
def mse_loss(x, target_flag): def mse_loss(x, target_flag):
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x)) return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
@ -28,10 +32,13 @@ class UGATITEngineKernel(EngineKernel):
gan_loss_cfg = OmegaConf.to_container(config.loss.gan) gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight") gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.cycle_loss = LossContainer(config.loss.cycle.weight,
nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()) self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss()) self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
self.id_loss = LossContainer(config.loss.id.weight, nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()) self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
self.bce_loss = LossContainer(self.config.loss.cam.weight, bce_loss)
self.mse_loss = LossContainer(self.config.loss.gan.weight, mse_loss)
self.rho_clipper = RhoClipper(0, 1) self.rho_clipper = RhoClipper(0, 1)
self.train_generator_first = False self.train_generator_first = False
@ -86,10 +93,10 @@ class UGATITEngineKernel(EngineKernel):
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"] generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image) pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)
loss[f"gan_{phase}_{dk}"] = self.config.loss.gan.weight * self.gan_loss(pred_fake, True) loss[f"gan_{phase}_{dk}"] = self.config.loss.gan.weight * self.gan_loss(pred_fake, True)
loss[f"gan_cam_{phase}_{dk}"] = self.config.loss.gan.weight * mse_loss(cam_pred, True) loss[f"gan_cam_{phase}_{dk}"] = self.mse_loss(cam_pred, True)
for t, f in [("a2b", "b2b"), ("b2a", "a2a")]: for t, f in [("a2b", "b2b"), ("b2a", "a2a")]:
loss[f"cam_{t[-1]}"] = self.config.loss.cam.weight * ( loss[f"cam_{t[-1]}"] = self.bce_loss(generated["cam_pred"][t], True) + \
bce_loss(generated["cam_pred"][t], True) + bce_loss(generated["cam_pred"][f], False)) self.bce_loss(generated["cam_pred"][f], False)
return loss return loss
def criterion_discriminators(self, batch, generated) -> dict: def criterion_discriminators(self, batch, generated) -> dict:

View File

@ -1,10 +1,17 @@
import ignite.distributed as idist import ignite.distributed as idist
import torch import torch
import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from omegaconf import OmegaConf from omegaconf import OmegaConf
from model import MODEL from model import MODEL
from util.misc import add_spectral_norm
def add_spectral_norm(module):
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
return nn.utils.spectral_norm(module)
else:
return module
def build_model(cfg): def build_model(cfg):

View File

@ -4,16 +4,6 @@ import pkgutil
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import torch.nn as nn
def add_spectral_norm(module):
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
return nn.utils.spectral_norm(module)
else:
return module
def import_submodules(package, recursive=True): def import_submodules(package, recursive=True):
""" Import all submodules of a module, recursively, including subpackages """ Import all submodules of a module, recursively, including subpackages