Compare commits

..

No commits in common. "2de00d0245a8d628fbddc06c379238905bedcd10" and "436bca88b40e08683e75b818f18386e0492c9342" have entirely different histories.

3 changed files with 17 additions and 21 deletions

View File

@ -13,10 +13,6 @@ from model.image_translation.UGATIT import RhoClipper
from util.image import attention_colored_map from util.image import attention_colored_map
def pixel_loss(level):
return nn.L1Loss() if level == 1 else nn.MSELoss()
def mse_loss(x, target_flag): def mse_loss(x, target_flag):
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x)) return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
@ -32,13 +28,10 @@ class UGATITEngineKernel(EngineKernel):
gan_loss_cfg = OmegaConf.to_container(config.loss.gan) gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight") gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.cycle_loss = LossContainer(config.loss.cycle.weight,
self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level)) nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss())
self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss()) self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level)) self.id_loss = LossContainer(config.loss.id.weight, nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss())
self.bce_loss = LossContainer(self.config.loss.cam.weight, bce_loss)
self.mse_loss = LossContainer(self.config.loss.gan.weight, mse_loss)
self.rho_clipper = RhoClipper(0, 1) self.rho_clipper = RhoClipper(0, 1)
self.train_generator_first = False self.train_generator_first = False
@ -93,10 +86,10 @@ class UGATITEngineKernel(EngineKernel):
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"] generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image) pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)
loss[f"gan_{phase}_{dk}"] = self.config.loss.gan.weight * self.gan_loss(pred_fake, True) loss[f"gan_{phase}_{dk}"] = self.config.loss.gan.weight * self.gan_loss(pred_fake, True)
loss[f"gan_cam_{phase}_{dk}"] = self.mse_loss(cam_pred, True) loss[f"gan_cam_{phase}_{dk}"] = self.config.loss.gan.weight * mse_loss(cam_pred, True)
for t, f in [("a2b", "b2b"), ("b2a", "a2a")]: for t, f in [("a2b", "b2b"), ("b2a", "a2a")]:
loss[f"cam_{t[-1]}"] = self.bce_loss(generated["cam_pred"][t], True) + \ loss[f"cam_{t[-1]}"] = self.config.loss.cam.weight * (
self.bce_loss(generated["cam_pred"][f], False) bce_loss(generated["cam_pred"][t], True) + bce_loss(generated["cam_pred"][f], False))
return loss return loss
def criterion_discriminators(self, batch, generated) -> dict: def criterion_discriminators(self, batch, generated) -> dict:

View File

@ -1,17 +1,10 @@
import ignite.distributed as idist import ignite.distributed as idist
import torch import torch
import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from omegaconf import OmegaConf from omegaconf import OmegaConf
from model import MODEL from model import MODEL
from util.misc import add_spectral_norm
def add_spectral_norm(module):
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
return nn.utils.spectral_norm(module)
else:
return module
def build_model(cfg): def build_model(cfg):

View File

@ -4,6 +4,16 @@ import pkgutil
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import torch.nn as nn
def add_spectral_norm(module):
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
return nn.utils.spectral_norm(module)
else:
return module
def import_submodules(package, recursive=True): def import_submodules(package, recursive=True):
""" Import all submodules of a module, recursively, including subpackages """ Import all submodules of a module, recursively, including subpackages