raycv/engine/MUNIT.py
2020-09-17 09:34:53 +08:00

155 lines
6.8 KiB
Python

import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
def mse_loss(x, target_flag):
    """Mean squared error of ``x`` against a constant target of the same shape.

    :param x: tensor of predictions.
    :param target_flag: truthy -> compare against all ones, falsy -> all zeros.
    :return: scalar tensor, ``F.mse_loss`` with its default mean reduction.
    """
    target = torch.ones_like(x) if target_flag else torch.zeros_like(x)
    return F.mse_loss(x, target)
def bce_loss(x, target_flag):
    """Binary cross-entropy on logits ``x`` against a constant target.

    :param x: tensor of raw logits (no sigmoid applied).
    :param target_flag: truthy -> all-ones target, falsy -> all-zeros target.
    :return: scalar tensor from ``F.binary_cross_entropy_with_logits`` (mean reduction).
    """
    if target_flag:
        target = torch.ones_like(x)
    else:
        target = torch.zeros_like(x)
    return F.binary_cross_entropy_with_logits(x, target)
class MUNITEngineKernel(EngineKernel):
    """Training kernel for MUNIT-style translation between domains "a" and "b".

    Builds one generator and one discriminator per domain. Each generator exposes
    ``encode(image) -> (content, style)`` and ``decode(content, style) -> image``;
    each discriminator scores images of its own domain. Loss terms and their
    weights are read from ``config.loss``.
    """

    def __init__(self, config):
        super().__init__(config)
        # "weight" is a scaling factor applied inside this kernel, not a
        # constructor argument of the loss modules, so strip it first.
        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
        perceptual_loss_cfg.pop("weight")
        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
        # recon.level == 1 selects L1; any other value selects MSE.
        self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
        # NOTE(review): consumed by the base EngineKernel (not visible here);
        # presumably selects whether G or D is updated first — confirm.
        self.train_generator_first = False

    def build_models(self) -> (dict, dict):
        """Build the per-domain models.

        :return: ``(generators, discriminators)``, each a dict keyed by "a" and "b".
        """
        generators = dict(
            a=build_model(self.config.model.generator),
            b=build_model(self.config.model.generator)
        )
        discriminators = dict(
            a=build_model(self.config.model.discriminator),
            b=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["a"])
        self.logger.debug(generators["a"])
        return generators, discriminators

    def setup_after_g(self):
        """Re-enable gradients for all discriminator parameters."""
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        """Disable gradients for all discriminator parameters.

        Presumably invoked by the base kernel before the generator update so that
        the generator loss does not accumulate gradients into the discriminators.
        """
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run every encode/decode path the losses need.

        :param batch: dict holding the input images under keys "a" and "b".
        :param inference: when True, run with autograd disabled.
        :return: ``dict(styles=..., contents=..., images=...)`` where
            - contents/styles "a", "b": encodings of the inputs;
            - contents/styles "a2b", "b2a": re-encodings of the translated images;
            - styles "random_a", "random_b": style codes drawn with ``randn_like``;
            - images "a2a", "b2b": within-domain reconstructions;
            - images "a2b", "b2a": cross-domain translations;
            - images "a2b2a", "b2a2b": cycle reconstructions (only when
              ``config.loss.recon.cycle.weight > 0``).
        """
        styles = dict()
        contents = dict()
        images = dict()
        with torch.set_grad_enabled(not inference):
            for phase in "ab":
                contents[phase], styles[phase] = self.generators[phase].encode(batch[phase])
                images["{0}2{0}".format(phase)] = self.generators[phase].decode(contents[phase], styles[phase])
                styles[f"random_{phase}"] = torch.randn_like(styles[phase]).to(idist.device())
            for phase in ("a2b", "b2a"):
                source, target = phase[0], phase[-1]
                # e.g. images["a2b"] = G_b.decode(content_a, random_style_b)
                images[phase] = self.generators[target].decode(contents[source], styles[f"random_{target}"])
                # e.g. contents["a2b"], styles["a2b"] = G_b.encode(images["a2b"])
                contents[phase], styles[phase] = self.generators[target].encode(images[phase])
                if self.config.loss.recon.cycle.weight > 0:
                    # e.g. images["a2b2a"] = G_a.decode(contents["a2b"], styles["a"])
                    images[f"{phase}2{source}"] = self.generators[source].decode(contents[phase], styles[source])
        return dict(styles=styles, contents=contents, images=images)

    def criterion_generators(self, batch, generated) -> dict:
        """Compute the weighted generator losses for both domains.

        :param batch: dict holding the real images under "a" and "b".
        :param generated: the dict returned by ``forward``.
        :return: dict of weighted loss terms; for each domain ``x``:
            ``recon_image_x``, ``recon_content_x``, ``recon_style_x``, ``gan_x``,
            plus ``recon_cycle_x`` / ``perceptual_x`` when their weights are > 0.
        """
        loss = dict()
        for phase in "ab":
            other = "b" if phase == "a" else "a"
            to_other = f"{phase}2{other}"  # translated from this domain into the other
            from_other = f"{other}2{phase}"  # translated from the other domain into this one
            loss[f"recon_image_{phase}"] = self.config.loss.recon.image.weight * self.recon_loss(
                batch[phase], generated["images"]["{0}2{0}".format(phase)])
            # Content code should survive translation into the other domain.
            loss[f"recon_content_{phase}"] = self.config.loss.recon.content.weight * self.recon_loss(
                generated["contents"][phase], generated["contents"][to_other])
            # The sampled style should be recoverable from the image decoded with it.
            loss[f"recon_style_{phase}"] = self.config.loss.recon.style.weight * self.recon_loss(
                generated["styles"][f"random_{phase}"], generated["styles"][from_other])
            # pred_fake is iterated as a sequence of discriminator outputs
            # (presumably one per scale); the last output of each is the prediction.
            pred_fake = self.discriminators[phase](generated["images"][from_other])
            adversarial = 0
            for sub_pred_fake in pred_fake:
                adversarial += self.gan_loss(sub_pred_fake[-1], True)
            # Scale by loss.gan.weight like every other term; the weight was
            # stripped from the GANLoss kwargs in __init__ for this purpose.
            loss[f"gan_{phase}"] = self.config.loss.gan.weight * adversarial
            if self.config.loss.recon.cycle.weight > 0:
                loss[f"recon_cycle_{phase}"] = self.config.loss.recon.cycle.weight * self.recon_loss(
                    batch[phase], generated["images"][f"{to_other}2{phase}"])
            if self.config.loss.perceptual.weight > 0:
                loss[f"perceptual_{phase}"] = self.config.loss.perceptual.weight * self.perceptual_loss(
                    batch[phase], generated["images"][to_other])
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Compute discriminator losses ``gan_a`` and ``gan_b``.

        Each term averages the fake-image loss (translated images, detached so no
        gradient reaches the generators) and the real-image loss, summed over the
        discriminator's outputs.

        :param batch: dict holding the real images under "a" and "b".
        :param generated: the dict returned by ``forward``.
        :return: dict with keys "gan_a" and "gan_b".
        """
        loss = dict()
        for phase in ("a2b", "b2a"):
            target = phase[-1]
            discriminator = self.discriminators[target]
            pred_real = discriminator(batch[target])
            pred_fake = discriminator(generated["images"][phase].detach())
            loss[f"gan_{target}"] = 0
            for sub_pred_fake, sub_pred_real in zip(pred_fake, pred_real):
                # last output is actual prediction
                loss[f"gan_{target}"] += (self.gan_loss(sub_pred_fake[-1], False, is_discriminator=True)
                                          + self.gan_loss(sub_pred_real[-1], True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch: dict holding the real images under "a" and "b"
        :param generated: the dict returned by ``forward``
        :return: per domain: [real, within-domain reconstruction, translation,
            (cycle reconstruction when the cycle loss is enabled)], all detached
        """
        generated = {img: generated["images"][img].detach() for img in generated["images"]}
        images = dict()
        for phase in "ab":
            other = "b" if phase == "a" else "a"
            images[phase] = [batch[phase].detach(), generated["{0}2{0}".format(phase)],
                             generated[f"{phase}2{other}"]]
            if self.config.loss.recon.cycle.weight > 0:
                images[phase].append(generated[f"{phase}2{other}2{phase}"])
        return images
class MUNITTestEngineKernel(TestEngineKernel):
    """Inference kernel for MUNIT: translates domain-"a" images into domain "b"."""

    def __init__(self, config):
        super().__init__(config)

    def build_generators(self) -> dict:
        """Build the two generators, keyed by domain ("a" and "b")."""
        generators = dict(
            a=build_model(self.config.model.generator),
            b=build_model(self.config.model.generator)
        )
        return generators

    def to_load(self):
        """Map checkpoint names ("generator_a", "generator_b") to the modules to restore."""
        return {f"generator_{k}": self.generators[k] for k in self.generators}

    def inference(self, batch):
        """Translate ``batch[0]`` from domain "a" to domain "b".

        The generator dict only has keys "a" and "b", and each generator is an
        encode/decode model built from the same config as in training. The a->b
        translation is therefore composed explicitly:
        G_a.encode -> style sampled with ``randn_like`` -> G_b.decode.
        The previous lookup of a non-existent "a2b" key always raised KeyError.

        :param batch: indexable batch whose first element holds domain-"a" images.
        :return: detached tensor of translated images.
        """
        with torch.no_grad():
            content_a, style_a = self.generators["a"].encode(batch[0])
            # Random target style, sampled the same way as the training forward pass.
            style_b = torch.randn_like(style_a)
            fake = self.generators["b"].decode(content_a, style_b)
        return fake.detach()
def run(task, config, _):
    """Build the MUNIT kernel for ``task`` and hand it to ``run_kernel``.

    :param task: "train" or "test".
    :param config: experiment config, passed to the kernel and to ``run_kernel``.
    :param _: unused by this engine.
    :raises NotImplementedError: for any other ``task``.
    """
    if task == "train":
        kernel = MUNITEngineKernel(config)
    elif task == "test":
        kernel = MUNITTestEngineKernel(config)
    else:
        # `raise NotImplemented` raises the NotImplemented constant, which is not
        # an exception and produces a confusing TypeError; raise the real exception.
        raise NotImplementedError(f"unsupported task: {task!r}")
    run_kernel(task, config, kernel)