import torch.nn as nn
import torch.nn.functional as F


class GANLoss(nn.Module):
    """Adversarial loss module supporting several GAN objectives.

    Supported ``loss_type`` values:

    * ``"vanilla"`` -- binary cross-entropy on logits against a constant
      target filled with ``real_label_val`` or ``fake_label_val``.
    * ``"lsgan"`` -- mean squared error against the same constant target.
    * ``"hinge"`` -- hinge loss; the discriminator and generator terms differ
      (see :meth:`forward`).
    * ``"wgan"`` -- Wasserstein critic loss: ``-mean(prediction)`` for real
      targets and ``mean(prediction)`` for fake targets.

    :param loss_type: one of the names listed above.
    :param real_label_val: target value for "real" (vanilla/lsgan only).
    :param fake_label_val: target value for "fake" (vanilla/lsgan only).
    :raises ValueError: if ``loss_type`` is not supported.
    """

    # Loss types accepted by __init__.
    _LOSS_TYPES = ("vanilla", "lsgan", "hinge", "wgan")

    def __init__(self, loss_type: str, real_label_val: float = 1.0, fake_label_val: float = 0.0):
        super().__init__()
        # Explicit check rather than ``assert``: asserts are removed under
        # ``python -O``, which would let a typo through until forward().
        if loss_type not in self._LOSS_TYPES:
            raise ValueError(
                f"Unsupported loss_type {loss_type!r}; expected one of {self._LOSS_TYPES}."
            )
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val
        self.loss_type = loss_type

    def _make_target(self, prediction, target_is_real: bool):
        """Return a tensor shaped like ``prediction`` filled with the real/fake label value."""
        target_val = self.real_label_val if target_is_real else self.fake_label_val
        return prediction.new_ones(prediction.size()) * target_val

    def forward(self, prediction, target_is_real: bool, is_discriminator=False):
        """
        gan loss forward
        :param prediction: network prediction (treated as raw logits in
            ``"vanilla"`` mode, since it is passed to
            ``binary_cross_entropy_with_logits``)
        :param target_is_real: whether the target is real or fake
        :param is_discriminator: whether the loss is computed for the
            discriminator; only the ``"hinge"`` mode uses it. default False
        :return: Tensor, GAN loss value (mean-reduced)
        :raises NotImplementedError: if ``self.loss_type`` is not a supported type
        """
        if self.loss_type == "vanilla":
            # Only vanilla/lsgan compare against a label tensor, so it is
            # built lazily here instead of on every call.
            return F.binary_cross_entropy_with_logits(
                prediction, self._make_target(prediction, target_is_real)
            )
        elif self.loss_type == "lsgan":
            return F.mse_loss(prediction, self._make_target(prediction, target_is_real))
        elif self.loss_type == "hinge":
            if is_discriminator:
                # Negating real predictions turns relu(1 + x) into
                # relu(1 - D(real)); fake predictions give relu(1 + D(fake)).
                prediction = -prediction if target_is_real else prediction
                return F.relu(1 + prediction).mean()
            # Generator hinge term: -mean(prediction); target_is_real is
            # intentionally ignored here.
            return -prediction.mean()
        elif self.loss_type == "wgan":
            return -prediction.mean() if target_is_real else prediction.mean()
        else:
            raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')