58 lines
2.1 KiB
Python
58 lines
2.1 KiB
Python
import ignite.distributed as idist
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from omegaconf import OmegaConf
|
|
|
|
from loss.gan import GANLoss
|
|
|
|
|
|
def gan_loss(config):
    """Build a GAN loss callable from a loss configuration.

    The configuration is converted to a plain container, its ``"weight"``
    entry (if any) is removed, and the remaining entries are forwarded as
    keyword arguments to ``GANLoss``. The resulting module is moved to the
    device reported by ``idist.device()``.

    Args:
        config: OmegaConf configuration holding the ``GANLoss`` keyword
            arguments, optionally with an extra ``"weight"`` entry.

    Returns:
        A function ``gan_loss_fn(prediction, target_is_real,
        is_discriminator=False)`` that evaluates the loss on a discriminator
        output given as a tensor, a list of tensors, or a list of lists.
    """
    gan_loss_cfg = OmegaConf.to_container(config)
    # "weight" is dropped before the rest is handed to GANLoss; a default is
    # supplied so configs without that key do not fail with KeyError.
    gan_loss_cfg.pop("weight", None)
    gl = GANLoss(**gan_loss_cfg).to(idist.device())

    def gan_loss_fn(prediction, target_is_real: bool, is_discriminator=False):
        """Evaluate the GAN loss on a discriminator output.

        Raises:
            NotImplementedError: if ``prediction`` is not a tensor, a
                non-empty list of tensors, or a non-empty list of lists.
        """
        if isinstance(prediction, torch.Tensor):
            # origin
            return gl(prediction, target_is_real, is_discriminator)

        # The emptiness check keeps `prediction[0]` from raising IndexError;
        # an empty list falls through to the unsupported-output error below.
        if isinstance(prediction, list) and prediction:
            if isinstance(prediction[0], list):
                # for multi scale discriminator, e.g. MultiScaleDiscriminator:
                # sum the loss over the last entry of every inner list
                loss = 0
                for p in prediction:
                    loss += gl(p[-1], target_is_real, is_discriminator)
                return loss
            if isinstance(prediction[0], torch.Tensor):
                # for discriminator set `need_intermediate_feature` true:
                # only the last tensor of the list is scored
                return gl(prediction[-1], target_is_real, is_discriminator)

        raise NotImplementedError("not support discriminator output")

    return gan_loss_fn
|
|
|
|
|
|
def pixel_loss(level):
    """Return a fresh pixel-wise loss module for the given norm level.

    Args:
        level: ``1`` selects ``nn.L1Loss``; every other value selects
            ``nn.MSELoss``.

    Returns:
        A newly constructed ``nn.L1Loss`` or ``nn.MSELoss`` instance.
    """
    if level == 1:
        return nn.L1Loss()
    return nn.MSELoss()
|
|
|
|
|
|
def mse_loss(x, target_flag):
    """Mean squared error between ``x`` and a constant target.

    Args:
        x: Input tensor.
        target_flag: When truthy the target is all ones, otherwise all
            zeros; the target has the same shape as ``x``.

    Returns:
        The result of ``F.mse_loss`` of ``x`` against that target.
    """
    if target_flag:
        target = torch.ones_like(x)
    else:
        target = torch.zeros_like(x)
    return F.mse_loss(x, target)
|
|
|
|
|
|
def bce_loss(x, target_flag):
    """Binary cross entropy (on logits) between ``x`` and a constant target.

    Args:
        x: Input tensor, passed to
            ``F.binary_cross_entropy_with_logits`` as the input.
        target_flag: When truthy the target is all ones, otherwise all
            zeros; the target has the same shape as ``x``.

    Returns:
        The result of ``F.binary_cross_entropy_with_logits`` against that
        target.
    """
    if target_flag:
        target = torch.ones_like(x)
    else:
        target = torch.zeros_like(x)
    return F.binary_cross_entropy_with_logits(x, target)
|
|
|
|
|
|
def feature_match_loss(level, weight_policy):
    """Build a feature-matching loss over per-scale feature lists.

    Args:
        level: Passed to ``pixel_loss`` to choose the comparison criterion.
        weight_policy: ``"same"`` weights every compared layer by 1;
            ``"exponential_decline"`` weights layer ``i`` by ``2 ** i``.

    Returns:
        A function ``fm_loss(generated_features, target_features)`` that
        returns the accumulated loss.

    Raises:
        AssertionError: if ``weight_policy`` is not one of the two supported
            names.
    """
    criterion = pixel_loss(level)
    assert weight_policy in ["same", "exponential_decline"]

    def fm_loss(generated_features, target_features):
        """Sum weighted feature differences over all scales.

        For each scale, every entry except the last one is compared with the
        entry at the same position in ``target_features`` (detached), and
        each term is divided by the number of scales.
        """
        num_scale = len(generated_features)
        loss = 0
        for scale_idx, scale_feats in enumerate(generated_features):
            # The final entry of each scale is left out of the comparison.
            for layer_idx, gen_feat in enumerate(scale_feats[:-1]):
                # NOTE(review): despite the policy name, 2 ** i grows with the
                # layer index -- confirm this is the intended weighting.
                if weight_policy == "same":
                    weight = 1
                else:
                    weight = 2 ** layer_idx
                target_feat = target_features[scale_idx][layer_idx].detach()
                loss += weight * criterion(gen_feat, target_feat) / num_scale
        return loss

    return fm_loss
|