import torch
import torch.nn as nn
import torch.nn.functional as F


class GANLoss(nn.Module):
    """Adversarial loss supporting several GAN objectives.

    Supported ``loss_type`` values:

    * ``"vanilla"`` -- binary cross-entropy on raw logits against a constant label.
    * ``"lsgan"``   -- mean squared error against a constant label.
    * ``"hinge"``   -- ``mean(relu(1 - D))`` / ``mean(relu(1 + D))`` for the
      discriminator on real / fake targets; ``-mean(D)`` for the generator.
    * ``"wgan"``    -- ``-mean(D)`` for real targets, ``mean(D)`` for fake targets.

    :meth:`forward` also accepts the list / tuple outputs produced by
    multi-scale discriminators and by discriminators that return
    intermediate features.
    """

    _SUPPORTED_TYPES = ("vanilla", "lsgan", "hinge", "wgan")

    def __init__(self, loss_type, real_label_val=1.0, fake_label_val=0.0):
        """
        :param loss_type: one of "vanilla", "lsgan", "hinge", "wgan"
        :param real_label_val: label value for real targets (used by vanilla/lsgan only)
        :param fake_label_val: label value for fake targets (used by vanilla/lsgan only)
        :raises ValueError: if ``loss_type`` is not supported
        """
        super().__init__()
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if loss_type not in self._SUPPORTED_TYPES:
            raise ValueError(
                f"loss_type must be one of {self._SUPPORTED_TYPES}, got {loss_type!r}"
            )
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val
        self.loss_type = loss_type

    def _target_like(self, prediction, target_is_real: bool):
        """Return a tensor shaped/typed like ``prediction`` filled with the real or fake label."""
        target_val = self.real_label_val if target_is_real else self.fake_label_val
        return torch.full_like(prediction, target_val)

    def single_forward(self, prediction, target_is_real: bool, is_discriminator=False):
        """
        gan loss forward for a single prediction tensor
        :param prediction: network prediction (raw scores / logits)
        :param target_is_real: whether the target is real or fake
        :param is_discriminator: whether the loss is computed for the discriminator
            (only affects the "hinge" loss). default False
        :return: Tensor, GAN loss value
        :raises NotImplementedError: if ``self.loss_type`` is not a supported type
        """
        if self.loss_type == "vanilla":
            return F.binary_cross_entropy_with_logits(
                prediction, self._target_like(prediction, target_is_real)
            )
        if self.loss_type == "lsgan":
            return F.mse_loss(prediction, self._target_like(prediction, target_is_real))
        if self.loss_type == "hinge":
            if is_discriminator:
                # real: relu(1 - D), fake: relu(1 + D)
                signed = -prediction if target_is_real else prediction
                return F.relu(1 + signed).mean()
            # Generator side ignores target_is_real and maximises D's score.
            return -prediction.mean()
        if self.loss_type == "wgan":
            return -prediction.mean() if target_is_real else prediction.mean()
        # Only reachable if ``loss_type`` was changed after construction.
        raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')

    def forward(self, prediction, target_is_real: bool, is_discriminator=False):
        """
        Compute the GAN loss for any supported discriminator output format.

        :param prediction: one of
            * ``torch.Tensor`` -- a single discriminator output;
            * ``list`` -- multi-scale discriminator output, one entry per scale.
              Each entry is either that scale's output tensor or a sequence of
              intermediate features whose last element is the final output.
              The per-scale losses are summed (an empty list yields ``0``);
            * ``tuple`` -- a single discriminator with intermediate features
              (e.g. ``need_intermediate_feature`` set); the last element is used.
        :param target_is_real: whether the target is real or fake
        :param is_discriminator: whether the loss is computed for the discriminator
        :return: Tensor, GAN loss value
        :raises NotImplementedError: for any other output type
        """
        if isinstance(prediction, torch.Tensor):
            return self.single_forward(prediction, target_is_real, is_discriminator)
        if isinstance(prediction, list):
            # A bare tensor scale must be used whole; indexing it with [-1]
            # would silently select only the last sample of the batch.
            return sum(
                self.single_forward(
                    scale if isinstance(scale, torch.Tensor) else scale[-1],
                    target_is_real,
                    is_discriminator,
                )
                for scale in prediction
            )
        if isinstance(prediction, tuple):
            return self.single_forward(prediction[-1], target_is_real, is_discriminator)
        raise NotImplementedError(f"not support discriminator output: {prediction}")