From f7b7b78669c0d8b8a3842111b7262f0feace44ce Mon Sep 17 00:00:00 2001 From: budui Date: Thu, 22 Oct 2020 23:19:03 +0800 Subject: [PATCH] imporved gan loss --- engine/util/loss.py | 18 +++++++++++++++++- loss/gan.py | 14 +------------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/engine/util/loss.py b/engine/util/loss.py index 70f3c84..5559ce4 100644 --- a/engine/util/loss.py +++ b/engine/util/loss.py @@ -10,7 +10,23 @@ from loss.gan import GANLoss def gan_loss(config): gan_loss_cfg = OmegaConf.to_container(config) gan_loss_cfg.pop("weight") - return GANLoss(**gan_loss_cfg).to(idist.device()) + gl = GANLoss(**gan_loss_cfg).to(idist.device()) + def gan_loss_fn(prediction, target_is_real: bool, is_discriminator=False): + if isinstance(prediction, torch.Tensor): + # origin + return gl(prediction, target_is_real, is_discriminator) + elif isinstance(prediction, list) and isinstance(prediction[0], list): + # for multi scale discriminator, e.g. MultiScaleDiscriminator + loss = 0 + for p in prediction: + loss += gl(p[-1], target_is_real, is_discriminator) + return loss + elif isinstance(prediction, list) and isinstance(prediction[0], torch.Tensor): + # for discriminator set `need_intermediate_feature` true + return gl(prediction[-1], target_is_real, is_discriminator) + else: + raise NotImplementedError("not support discriminator output") + return gan_loss_fn def pixel_loss(level): diff --git a/loss/gan.py b/loss/gan.py index e8f05c0..5e30bc4 100644 --- a/loss/gan.py +++ b/loss/gan.py @@ -1,6 +1,5 @@ import torch.nn as nn import torch.nn.functional as F -import torch class GANLoss(nn.Module): @@ -11,7 +10,7 @@ class GANLoss(nn.Module): self.fake_label_val = fake_label_val self.loss_type = loss_type - def single_forward(self, prediction, target_is_real: bool, is_discriminator=False): + def forward(self, prediction, target_is_real: bool, is_discriminator=False): """ gan loss forward :param prediction: network prediction @@ -38,14 +37,3 @@ class GANLoss(nn.Module): return loss else: raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.') - - def forward(self, prediction, target_is_real: bool, is_discriminator=False): - if isinstance(prediction, torch.Tensor): - # origin - return self.single_forward(prediction, target_is_real, is_discriminator) - elif isinstance(prediction, list): - # for multi scale discriminator, e.g. MultiScaleDiscriminator - loss = 0 - for p in prediction: - loss += self.single_forward(p[-1], target_is_real, is_discriminator) - return loss