52 lines
2.1 KiB
Python
52 lines
2.1 KiB
Python
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch
|
|
|
|
|
|
class GANLoss(nn.Module):
    """GAN objective supporting the vanilla, LSGAN, hinge and WGAN formulations.

    :param loss_type: one of ``"vanilla"``, ``"lsgan"``, ``"hinge"``, ``"wgan"``.
    :param real_label_val: target value for real samples (used by vanilla / lsgan only).
    :param fake_label_val: target value for fake samples (used by vanilla / lsgan only).
    :raises ValueError: if ``loss_type`` is not one of the supported names.
    """

    _SUPPORTED = ("vanilla", "lsgan", "hinge", "wgan")

    def __init__(self, loss_type, real_label_val=1.0, fake_label_val=0.0):
        super().__init__()
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if loss_type not in self._SUPPORTED:
            raise ValueError(
                f"Unsupported loss_type {loss_type!r}; expected one of {self._SUPPORTED}."
            )
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val
        self.loss_type = loss_type

    def single_forward(self, prediction, target_is_real: bool, is_discriminator=False):
        """Compute the GAN loss for a single discriminator output tensor.

        :param prediction: raw network prediction (treated as logits by ``vanilla``).
        :param target_is_real: whether the target is real (True) or fake (False).
        :param is_discriminator: whether the loss is for the discriminator. Only the
            ``hinge`` loss distinguishes discriminator from generator. Default False.
        :return: Tensor, scalar GAN loss value.
        """
        if self.loss_type in ("vanilla", "lsgan"):
            target_val = self.real_label_val if target_is_real else self.fake_label_val
            # Constant target with the prediction's shape, dtype and device.
            target = torch.full_like(prediction, target_val)
            if self.loss_type == "vanilla":
                return F.binary_cross_entropy_with_logits(prediction, target)
            return F.mse_loss(prediction, target)

        if self.loss_type == "hinge":
            if is_discriminator:
                # D: mean(relu(1 - pred)) for real, mean(relu(1 + pred)) for fake.
                signed = -prediction if target_is_real else prediction
                return F.relu(1 + signed).mean()
            # G: -mean(pred); target_is_real is not used on the generator side.
            return -prediction.mean()

        if self.loss_type == "wgan":
            return -prediction.mean() if target_is_real else prediction.mean()

        raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')

    def forward(self, prediction, target_is_real: bool, is_discriminator=False):
        """Compute the GAN loss for a single- or multi-scale discriminator output.

        :param prediction: either a Tensor, or a list/tuple with one entry per
            discriminator scale (e.g. MultiScaleDiscriminator). Each per-scale entry
            is either a Tensor, or a list/tuple whose last element is the prediction.
        :param target_is_real: whether the target is real (True) or fake (False).
        :param is_discriminator: whether the loss is for the discriminator. Default False.
        :return: Tensor loss. For multi-scale input this is the sum of the per-scale
            losses; an empty list/tuple yields ``0``.
        :raises TypeError: if ``prediction`` is neither a Tensor nor a list/tuple.
        """
        if isinstance(prediction, torch.Tensor):
            return self.single_forward(prediction, target_is_real, is_discriminator)

        if isinstance(prediction, (list, tuple)):
            loss = 0
            for scale_out in prediction:
                # Feature-returning discriminators give [feat_1, ..., final_pred]; a bare
                # Tensor must NOT be indexed with [-1] (that would select one batch item).
                if isinstance(scale_out, (list, tuple)):
                    pred = scale_out[-1]
                else:
                    pred = scale_out
                loss = loss + self.single_forward(pred, target_is_real, is_discriminator)
            return loss

        raise TypeError(
            f"prediction must be a Tensor or a list/tuple of per-scale outputs, "
            f"got {type(prediction).__name__}"
        )