26 lines
661 B
Python
26 lines
661 B
Python
import ignite.distributed as idist
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from omegaconf import OmegaConf
|
|
|
|
from loss.gan import GANLoss
|
|
|
|
|
|
def gan_loss(config):
    """Build a ``GANLoss`` from an OmegaConf config and move it to the current device.

    The config is converted to a plain Python container. The ``weight`` entry,
    if present, is removed. All remaining entries are passed to ``GANLoss`` as
    keyword arguments. Presumably ``weight`` is consumed by the caller when
    combining losses — TODO confirm against the training loop.

    Args:
        config: OmegaConf node whose entries are ``GANLoss`` keyword arguments,
            plus an optional ``weight`` key.

    Returns:
        The constructed ``GANLoss`` module, moved to ``idist.device()``.
    """
    # resolve=True expands ``${...}`` interpolations so GANLoss receives real
    # values rather than unresolved interpolation strings.
    gan_loss_cfg = OmegaConf.to_container(config, resolve=True)
    # ``weight`` is not forwarded to GANLoss. Tolerate its absence instead of
    # raising KeyError when a config omits it.
    gan_loss_cfg.pop("weight", None)
    return GANLoss(**gan_loss_cfg).to(idist.device())
|
|
|
|
|
|
def pixel_loss(level):
    """Return a fresh pixel-wise reconstruction criterion.

    Args:
        level: Norm selector. ``1`` yields ``nn.L1Loss``; every other value
            yields ``nn.MSELoss``.

    Returns:
        A newly constructed ``nn.Module`` loss instance.
    """
    if level == 1:
        return nn.L1Loss()
    # Any level other than 1 (not just 2) falls back to squared error.
    return nn.MSELoss()
|
|
|
|
|
|
def mse_loss(x, target_flag):
    """Mean squared error between ``x`` and a constant target of the same shape.

    Args:
        x: Prediction tensor.
        target_flag: Truthy to compare against an all-ones tensor, falsy to
            compare against an all-zeros tensor.

    Returns:
        Tensor produced by ``F.mse_loss`` with its default ``"mean"`` reduction.
    """
    # The target inherits x's shape, dtype and device via *_like.
    make_target = torch.ones_like if target_flag else torch.zeros_like
    target = make_target(x)
    return F.mse_loss(x, target)
|
|
|
|
|
|
def bce_loss(x, target_flag):
    """Binary cross-entropy on raw logits against a constant target.

    Args:
        x: Tensor of unnormalised logits. The sigmoid is applied inside
            ``F.binary_cross_entropy_with_logits``.
        target_flag: Truthy for an all-ones target, falsy for an all-zeros
            target.

    Returns:
        Tensor produced by ``F.binary_cross_entropy_with_logits`` with its
        default ``"mean"`` reduction.
    """
    if target_flag:
        target = torch.ones_like(x)
    else:
        target = torch.zeros_like(x)
    return F.binary_cross_entropy_with_logits(x, target)
|