imporved gan loss
This commit is contained in:
parent
376f5caeb7
commit
f7b7b78669
@ -10,7 +10,23 @@ from loss.gan import GANLoss
|
||||
def gan_loss(config):
|
||||
gan_loss_cfg = OmegaConf.to_container(config)
|
||||
gan_loss_cfg.pop("weight")
|
||||
return GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
gl = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
def gan_loss_fn(prediction, target_is_real: bool, is_discriminator=False):
|
||||
if isinstance(prediction, torch.Tensor):
|
||||
# origin
|
||||
return gl(prediction, target_is_real, is_discriminator)
|
||||
elif isinstance(prediction, list) and isinstance(prediction[0], list):
|
||||
# for multi scale discriminator, e.g. MultiScaleDiscriminator
|
||||
loss = 0
|
||||
for p in prediction:
|
||||
loss += gl(p[-1], target_is_real, is_discriminator)
|
||||
return loss
|
||||
elif isinstance(prediction, list) and isinstance(prediction[0], torch.Tensor):
|
||||
# for discriminator set `need_intermediate_feature` true
|
||||
return gl(prediction[-1], target_is_real, is_discriminator)
|
||||
else:
|
||||
raise NotImplementedError("not support discriminator output")
|
||||
return gan_loss_fn
|
||||
|
||||
|
||||
def pixel_loss(level):
|
||||
|
||||
14
loss/gan.py
14
loss/gan.py
@ -1,6 +1,5 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
|
||||
|
||||
class GANLoss(nn.Module):
|
||||
@ -11,7 +10,7 @@ class GANLoss(nn.Module):
|
||||
self.fake_label_val = fake_label_val
|
||||
self.loss_type = loss_type
|
||||
|
||||
def single_forward(self, prediction, target_is_real: bool, is_discriminator=False):
|
||||
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
|
||||
"""
|
||||
gan loss forward
|
||||
:param prediction: network prediction
|
||||
@ -38,14 +37,3 @@ class GANLoss(nn.Module):
|
||||
return loss
|
||||
else:
|
||||
raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')
|
||||
|
||||
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
|
||||
if isinstance(prediction, torch.Tensor):
|
||||
# origin
|
||||
return self.single_forward(prediction, target_is_real, is_discriminator)
|
||||
elif isinstance(prediction, list):
|
||||
# for multi scale discriminator, e.g. MultiScaleDiscriminator
|
||||
loss = 0
|
||||
for p in prediction:
|
||||
loss += self.single_forward(p[-1], target_is_real, is_discriminator)
|
||||
return loss
|
||||
|
||||
Loading…
Reference in New Issue
Block a user