import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf

from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss


def mse_loss(x, target_flag):
    """MSE loss against an all-ones (real) or all-zeros (fake) target tensor."""
    return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))


def bce_loss(x, target_flag):
    """Binary cross-entropy (with logits) against an all-ones or all-zeros target tensor."""
    return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))


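# Config fields read by the kernels below. This is a sketch reconstructed from the
# code, not the authoritative schema; actual defaults and extra keys live in the
# project's OmegaConf/YAML configs:
#
#   model.generator, model.discriminator        # passed to build_model()
#   loss.gan:        {weight, ...}              # kwargs for GANLoss ("weight" is popped first)
#   loss.perceptual: {weight, ...}              # kwargs for PerceptualLoss ("weight" is popped first)
#   loss.recon:      level (1 -> L1, else L2),
#                    image.weight, content.weight, style.weight, cycle.weight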
class MUNITEngineKernel(EngineKernel):
    def __init__(self, config):
        super().__init__(config)

        # The "weight" entry is applied in criterion_generators, so it is removed
        # before constructing the loss module itself.
        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
        perceptual_loss_cfg.pop("weight")
        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())

        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())

        # L1 or L2 reconstruction loss, selected by config.loss.recon.level.
        self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()

        self.train_generator_first = False

    def build_models(self) -> (dict, dict):
        generators = dict(
            a=build_model(self.config.model.generator),
            b=build_model(self.config.model.generator)
        )
        discriminators = dict(
            a=build_model(self.config.model.discriminator),
            b=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["a"])
        self.logger.debug(generators["a"])

        return generators, discriminators

    def setup_after_g(self):
        # Re-enable discriminator gradients once the generator update is done.
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        # Freeze the discriminators while the generators are being updated.
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Encode both domains, reconstruct within-domain images and translate across
        domains; returns dict(styles=..., contents=..., images=...)."""
        styles = dict()
        contents = dict()
        images = dict()
        with torch.set_grad_enabled(not inference):
            for phase in "ab":
                # Within-domain auto-encoding: x_a -> (c_a, s_a) -> x_a2a
                contents[phase], styles[phase] = self.generators[phase].encode(batch[phase])
                images["{0}2{0}".format(phase)] = self.generators[phase].decode(contents[phase], styles[phase])
                styles[f"random_{phase}"] = torch.randn_like(styles[phase]).to(idist.device())

            for phase in ("a2b", "b2a"):
                # Cross-domain translation, e.g. images["a2b"] = Gb.decode(content_a, random_style_b)
                images[phase] = self.generators[phase[-1]].decode(contents[phase[0]], styles[f"random_{phase[-1]}"])
                # Re-encode the translation, e.g. contents["a2b"], styles["a2b"] = Gb.encode(images["a2b"])
                contents[phase], styles[phase] = self.generators[phase[-1]].encode(images[phase])
                if self.config.loss.recon.cycle.weight > 0:
                    # Cycle reconstruction, e.g. images["a2b2a"] = Ga.decode(content_a2b, style_a)
                    images[f"{phase}2{phase[0]}"] = self.generators[phase[0]].decode(contents[phase], styles[phase[0]])
        return dict(styles=styles, contents=contents, images=images)

    def criterion_generators(self, batch, generated) -> dict:
        loss = dict()
        for phase in "ab":
            # Within-domain image reconstruction plus latent content/style reconstruction.
            loss[f"recon_image_{phase}"] = self.config.loss.recon.image.weight * self.recon_loss(
                batch[phase], generated["images"]["{0}2{0}".format(phase)])
            loss[f"recon_content_{phase}"] = self.config.loss.recon.content.weight * self.recon_loss(
                generated["contents"][phase], generated["contents"]["a2b" if phase == "a" else "b2a"])
            loss[f"recon_style_{phase}"] = self.config.loss.recon.style.weight * self.recon_loss(
                generated["styles"][f"random_{phase}"], generated["styles"]["b2a" if phase == "a" else "a2b"])

            # Adversarial loss on the images translated into this domain.
            pred_fake = self.discriminators[phase](generated["images"]["b2a" if phase == "a" else "a2b"])
            loss[f"gan_{phase}"] = 0
            for sub_pred_fake in pred_fake:
                # the last output is the actual prediction
                loss[f"gan_{phase}"] += self.gan_loss(sub_pred_fake[-1], True)

            if self.config.loss.recon.cycle.weight > 0:
                loss[f"recon_cycle_{phase}"] = self.config.loss.recon.cycle.weight * self.recon_loss(
                    batch[phase], generated["images"]["a2b2a" if phase == "a" else "b2a2b"])
            if self.config.loss.perceptual.weight > 0:
                loss[f"perceptual_{phase}"] = self.config.loss.perceptual.weight * self.perceptual_loss(
                    batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        loss = dict()
        for phase in ("a2b", "b2a"):
            pred_real = self.discriminators[phase[-1]](batch[phase[-1]])
            pred_fake = self.discriminators[phase[-1]](generated["images"][phase].detach())
            loss[f"gan_{phase[-1]}"] = 0
            for i in range(len(pred_fake)):
                loss[f"gan_{phase[-1]}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
                                             + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        Returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}

        :param batch: input batch with keys "a" and "b"
        :param generated: output of forward(), containing the "images" dict
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        generated = {img: generated["images"][img].detach() for img in generated["images"]}
        images = dict()
        for phase in "ab":
            # input, within-domain reconstruction, cross-domain translation (+ cycle if enabled)
            images[phase] = [batch[phase].detach(), generated["{0}2{0}".format(phase)],
                             generated["a2b" if phase == "a" else "b2a"]]
            if self.config.loss.recon.cycle.weight > 0:
                images[phase].append(generated["a2b2a" if phase == "a" else "b2a2b"])
        return images


class MUNITTestEngineKernel(TestEngineKernel):
    def __init__(self, config):
        super().__init__(config)

    def build_generators(self) -> dict:
        generators = dict(
            a=build_model(self.config.model.generator),
            b=build_model(self.config.model.generator)
        )
        return generators

    def to_load(self):
        return {f"generator_{k}": self.generators[k] for k in self.generators}

    def inference(self, batch):
        with torch.no_grad():
            fake, _, _ = self.generators["a2b"](batch[0])
        return fake.detach()


def run(task, config, _):
    if task == "train":
        kernel = MUNITEngineKernel(config)
        run_kernel(task, config, kernel)
    elif task == "test":
        kernel = MUNITTestEngineKernel(config)
        run_kernel(task, config, kernel)
    else:
        raise NotImplementedError(f"Unsupported task: {task}")

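# Usage sketch (assumes the project's entry point loads an OmegaConf config and
# dispatches to this module's run(); the actual CLI wiring and config path below
# are hypothetical):
#
#   from omegaconf import OmegaConf
#   config = OmegaConf.load("configs/munit.yaml")
#   run("train", config, None)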