Compare commits
2 Commits
436bca88b4
...
2de00d0245
| Author | SHA1 | Date | |
|---|---|---|---|
| 2de00d0245 | |||
| 74a7cfb2d8 |
@ -13,6 +13,10 @@ from model.image_translation.UGATIT import RhoClipper
|
|||||||
from util.image import attention_colored_map
|
from util.image import attention_colored_map
|
||||||
|
|
||||||
|
|
||||||
|
def pixel_loss(level):
|
||||||
|
return nn.L1Loss() if level == 1 else nn.MSELoss()
|
||||||
|
|
||||||
|
|
||||||
def mse_loss(x, target_flag):
|
def mse_loss(x, target_flag):
|
||||||
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
|
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
|
||||||
|
|
||||||
@ -28,10 +32,13 @@ class UGATITEngineKernel(EngineKernel):
|
|||||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||||
gan_loss_cfg.pop("weight")
|
gan_loss_cfg.pop("weight")
|
||||||
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||||
self.cycle_loss = LossContainer(config.loss.cycle.weight,
|
|
||||||
nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss())
|
self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
|
||||||
self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
|
self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
|
||||||
self.id_loss = LossContainer(config.loss.id.weight, nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss())
|
self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
|
||||||
|
self.bce_loss = LossContainer(self.config.loss.cam.weight, bce_loss)
|
||||||
|
self.mse_loss = LossContainer(self.config.loss.gan.weight, mse_loss)
|
||||||
|
|
||||||
self.rho_clipper = RhoClipper(0, 1)
|
self.rho_clipper = RhoClipper(0, 1)
|
||||||
self.train_generator_first = False
|
self.train_generator_first = False
|
||||||
|
|
||||||
@ -86,10 +93,10 @@ class UGATITEngineKernel(EngineKernel):
|
|||||||
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
|
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
|
||||||
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)
|
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)
|
||||||
loss[f"gan_{phase}_{dk}"] = self.config.loss.gan.weight * self.gan_loss(pred_fake, True)
|
loss[f"gan_{phase}_{dk}"] = self.config.loss.gan.weight * self.gan_loss(pred_fake, True)
|
||||||
loss[f"gan_cam_{phase}_{dk}"] = self.config.loss.gan.weight * mse_loss(cam_pred, True)
|
loss[f"gan_cam_{phase}_{dk}"] = self.mse_loss(cam_pred, True)
|
||||||
for t, f in [("a2b", "b2b"), ("b2a", "a2a")]:
|
for t, f in [("a2b", "b2b"), ("b2a", "a2a")]:
|
||||||
loss[f"cam_{t[-1]}"] = self.config.loss.cam.weight * (
|
loss[f"cam_{t[-1]}"] = self.bce_loss(generated["cam_pred"][t], True) + \
|
||||||
bce_loss(generated["cam_pred"][t], True) + bce_loss(generated["cam_pred"][f], False))
|
self.bce_loss(generated["cam_pred"][f], False)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def criterion_discriminators(self, batch, generated) -> dict:
|
def criterion_discriminators(self, batch, generated) -> dict:
|
||||||
|
|||||||
@ -1,10 +1,17 @@
|
|||||||
import ignite.distributed as idist
|
import ignite.distributed as idist
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
from model import MODEL
|
from model import MODEL
|
||||||
from util.misc import add_spectral_norm
|
|
||||||
|
|
||||||
|
def add_spectral_norm(module):
|
||||||
|
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
|
||||||
|
return nn.utils.spectral_norm(module)
|
||||||
|
else:
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
def build_model(cfg):
|
def build_model(cfg):
|
||||||
|
|||||||
10
util/misc.py
10
util/misc.py
@ -4,16 +4,6 @@ import pkgutil
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
def add_spectral_norm(module):
|
|
||||||
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
|
|
||||||
return nn.utils.spectral_norm(module)
|
|
||||||
else:
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
def import_submodules(package, recursive=True):
|
def import_submodules(package, recursive=True):
|
||||||
""" Import all submodules of a module, recursively, including subpackages
|
""" Import all submodules of a module, recursively, including subpackages
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user