import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf

from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss


def gan_loss(config):
    """Build a GANLoss from an OmegaConf node, dropping the 'weight' key handled elsewhere."""
    gan_loss_cfg = OmegaConf.to_container(config)
    gan_loss_cfg.pop("weight")
    return GANLoss(**gan_loss_cfg).to(idist.device())


def perceptual_loss(config):
    """Build a PerceptualLoss from an OmegaConf node, dropping the 'weight' key handled elsewhere."""
    perceptual_loss_cfg = OmegaConf.to_container(config)
    perceptual_loss_cfg.pop("weight")
    return PerceptualLoss(**perceptual_loss_cfg).to(idist.device())


def pixel_loss(level):
    """Return an L1 loss for level == 1, otherwise an MSE (L2) loss."""
    return nn.L1Loss() if level == 1 else nn.MSELoss()


def mse_loss(x, target_flag):
    """MSE against an all-ones target when target_flag is True, all-zeros otherwise."""
    target = torch.ones_like(x) if target_flag else torch.zeros_like(x)
    return F.mse_loss(x, target)


def bce_loss(x, target_flag):
    """BCE-with-logits against an all-ones target when target_flag is True, all-zeros otherwise."""
    target = torch.ones_like(x) if target_flag else torch.zeros_like(x)
    return F.binary_cross_entropy_with_logits(x, target)


def feature_match_loss(level, weight_policy):
    """Build a multi-scale feature-matching loss.

    The per-feature distance is L1 (level == 1) or MSE (otherwise). With the
    "same" policy every feature level has weight 1; with "exponential_decline"
    the i-th feature of each scale is weighted by 2 ** i.
    """
    assert weight_policy in ["same", "exponential_decline"]
    compare_loss = pixel_loss(level)

    def fm_loss(generated_features, target_features):
        # generated_features / target_features: one list of intermediate feature
        # maps per discriminator scale; the last entry of each scale (the
        # discriminator output itself) is skipped.
        num_scale = len(generated_features)
        loss = torch.zeros(1, device=idist.device())
        for s_i in range(num_scale):
            for i in range(len(generated_features[s_i]) - 1):
                weight = 1 if weight_policy == "same" else 2 ** i
                loss += (
                    weight
                    * compare_loss(generated_features[s_i][i], target_features[s_i][i].detach())
                    / num_scale
                )
        return loss

    return fm_loss
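

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the training pipeline): the single
    # discriminator scale and the feature shapes below are hypothetical, purely
    # to illustrate how feature_match_loss is called.
    fm = feature_match_loss(level=1, weight_policy="same")
    fake_feats = [[torch.randn(2, 64, 32, 32), torch.randn(2, 128, 16, 16), torch.randn(2, 1, 8, 8)]]
    real_feats = [[torch.randn(2, 64, 32, 32), torch.randn(2, 128, 16, 16), torch.randn(2, 1, 8, 8)]]
    # Prints a 1-element tensor; the last feature map of each scale is skipped.
    print(fm(fake_feats, real_feats))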