import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf

from loss.gan import GANLoss


def gan_loss(config):
    """Build a ``GANLoss`` from an OmegaConf config and move it to the current device.

    The config is converted to a plain container and its ``weight`` entry (if
    any) is removed. The remaining entries are forwarded as keyword arguments
    to ``GANLoss``. Presumably ``weight`` is consumed by the caller when
    combining losses; TODO confirm against call sites.

    Args:
        config: OmegaConf config node holding ``GANLoss`` constructor
            arguments plus an optional ``weight`` key.

    Returns:
        A ``GANLoss`` module placed on ``idist.device()``.
    """
    # resolve=True turns "${...}" interpolations into their values instead of
    # forwarding them to GANLoss as literal strings.
    gan_loss_cfg = OmegaConf.to_container(config, resolve=True)
    # Remove "weight" so it is not forwarded to GANLoss; tolerate configs
    # that omit it instead of raising KeyError.
    gan_loss_cfg.pop("weight", None)
    return GANLoss(**gan_loss_cfg).to(idist.device())


def pixel_loss(level):
    """Return a pixel-wise reconstruction loss module.

    Args:
        level: ``1`` selects ``nn.L1Loss``; any other value selects
            ``nn.MSELoss``.

    Returns:
        A freshly constructed loss module.
    """
    return nn.L1Loss() if level == 1 else nn.MSELoss()


def _constant_target_like(x, target_flag):
    """Return a tensor like ``x`` filled with ones if ``target_flag`` is truthy, else zeros."""
    return torch.ones_like(x) if target_flag else torch.zeros_like(x)


def mse_loss(x, target_flag):
    """Mean squared error between ``x`` and a constant all-ones / all-zeros target.

    Args:
        x: Input tensor.
        target_flag: Truthy to compare against ones, falsy to compare against zeros.

    Returns:
        Scalar tensor (``F.mse_loss`` default ``'mean'`` reduction).
    """
    return F.mse_loss(x, _constant_target_like(x, target_flag))


def bce_loss(x, target_flag):
    """Binary cross-entropy of logits ``x`` against a constant all-ones / all-zeros target.

    ``x`` is treated as raw logits: ``F.binary_cross_entropy_with_logits``
    applies the sigmoid internally.

    Args:
        x: Logit tensor.
        target_flag: Truthy for a target of ones, falsy for a target of zeros.

    Returns:
        Scalar tensor (default ``'mean'`` reduction).
    """
    return F.binary_cross_entropy_with_logits(x, _constant_target_like(x, target_flag))