imporved gan loss
This commit is contained in:
parent
376f5caeb7
commit
f7b7b78669
@ -10,7 +10,23 @@ from loss.gan import GANLoss
|
|||||||
def gan_loss(config):
|
def gan_loss(config):
|
||||||
gan_loss_cfg = OmegaConf.to_container(config)
|
gan_loss_cfg = OmegaConf.to_container(config)
|
||||||
gan_loss_cfg.pop("weight")
|
gan_loss_cfg.pop("weight")
|
||||||
return GANLoss(**gan_loss_cfg).to(idist.device())
|
gl = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||||
|
def gan_loss_fn(prediction, target_is_real: bool, is_discriminator=False):
|
||||||
|
if isinstance(prediction, torch.Tensor):
|
||||||
|
# origin
|
||||||
|
return gl(prediction, target_is_real, is_discriminator)
|
||||||
|
elif isinstance(prediction, list) and isinstance(prediction[0], list):
|
||||||
|
# for multi scale discriminator, e.g. MultiScaleDiscriminator
|
||||||
|
loss = 0
|
||||||
|
for p in prediction:
|
||||||
|
loss += gl(p[-1], target_is_real, is_discriminator)
|
||||||
|
return loss
|
||||||
|
elif isinstance(prediction, list) and isinstance(prediction[0], torch.Tensor):
|
||||||
|
# for discriminator set `need_intermediate_feature` true
|
||||||
|
return gl(prediction[-1], target_is_real, is_discriminator)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("not support discriminator output")
|
||||||
|
return gan_loss_fn
|
||||||
|
|
||||||
|
|
||||||
def pixel_loss(level):
|
def pixel_loss(level):
|
||||||
|
|||||||
14
loss/gan.py
14
loss/gan.py
@ -1,6 +1,5 @@
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class GANLoss(nn.Module):
|
class GANLoss(nn.Module):
|
||||||
@ -11,7 +10,7 @@ class GANLoss(nn.Module):
|
|||||||
self.fake_label_val = fake_label_val
|
self.fake_label_val = fake_label_val
|
||||||
self.loss_type = loss_type
|
self.loss_type = loss_type
|
||||||
|
|
||||||
def single_forward(self, prediction, target_is_real: bool, is_discriminator=False):
|
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
|
||||||
"""
|
"""
|
||||||
gan loss forward
|
gan loss forward
|
||||||
:param prediction: network prediction
|
:param prediction: network prediction
|
||||||
@ -38,14 +37,3 @@ class GANLoss(nn.Module):
|
|||||||
return loss
|
return loss
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')
|
raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')
|
||||||
|
|
||||||
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
|
|
||||||
if isinstance(prediction, torch.Tensor):
|
|
||||||
# origin
|
|
||||||
return self.single_forward(prediction, target_is_real, is_discriminator)
|
|
||||||
elif isinstance(prediction, list):
|
|
||||||
# for multi scale discriminator, e.g. MultiScaleDiscriminator
|
|
||||||
loss = 0
|
|
||||||
for p in prediction:
|
|
||||||
loss += self.single_forward(p[-1], target_is_real, is_discriminator)
|
|
||||||
return loss
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user