raycv/engine/U-GAT-IT.py
2020-09-05 10:33:35 +08:00

154 lines
6.7 KiB
Python

from itertools import chain
from omegaconf import OmegaConf
import torch
import torch.nn as nn
import torch.nn.functional as F
import ignite.distributed as idist
from model.weight_init import generation_init_weights
from loss.gan import GANLoss
from model.GAN.UGATIT import RhoClipper
from model.GAN.residual_generator import GANImageBuffer
from util.image import attention_colored_map
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
def mse_loss(x, target_flag):
    """Return the mean squared error between ``x`` and a constant target.

    :param x: tensor of predictions.
    :param target_flag: truthy to compare against all ones, falsy to compare
        against all zeros.
    :return: scalar tensor produced by ``F.mse_loss`` (default mean reduction).
    """
    # Choose the target factory once; the target always matches x's shape,
    # dtype and device because it is built with *_like.
    make_target = torch.ones_like if target_flag else torch.zeros_like
    return F.mse_loss(x, make_target(x))
def bce_loss(x, target_flag):
    """Return the binary cross-entropy of logits ``x`` against a constant label.

    :param x: tensor of raw logits (the sigmoid is applied inside the loss).
    :param target_flag: truthy for an all-ones target, falsy for all zeros.
    :return: scalar tensor produced by ``F.binary_cross_entropy_with_logits``
        (default mean reduction).
    """
    if target_flag:
        target = torch.ones_like(x)
    else:
        target = torch.zeros_like(x)
    return F.binary_cross_entropy_with_logits(x, target)
class UGATITEngineKernel(EngineKernel):
    """Training kernel for U-GAT-IT style translation between domains ``a`` and ``b``.

    Owns two generators (``a2b``, ``b2a``) and four discriminators, keyed by a
    level letter plus a domain letter: ``la``/``lb`` (local) and ``ga``/``gb``
    (global). The ``self.config``, ``self.logger``, ``self.generators`` and
    ``self.discriminators`` attributes are used here but not assigned in this
    class; presumably the ``EngineKernel`` base class provides them (confirm
    against ``engine.base.i2i``).
    """

    def __init__(self, config):
        """Create the loss functions, the rho clipper and the image buffers.

        :param config: experiment configuration. This class reads
            ``loss.gan``, ``loss.cycle``, ``loss.id``, ``loss.cam``,
            ``data.train.buffer_size`` and ``model.*`` from it.
        """
        super().__init__(config)
        # "weight" is removed before the remaining keys are forwarded to
        # GANLoss; the weight itself is applied in criterion_generators.
        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
        # level == 1 selects L1, any other value selects MSE.
        self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
        self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
        self.rho_clipper = RhoClipper(0, 1)
        # Falls back to 50 whenever buffer_size is unset or otherwise falsy
        # (which includes an explicit 0).
        buffer_size = config.data.train.buffer_size or 50
        # One independent history buffer per discriminator.
        self.image_buffers = {name: GANImageBuffer(buffer_size) for name in self.discriminators}

    def build_models(self) -> tuple:
        """Build and initialise every network used for training.

        :return: ``(generators, discriminators)``; two dicts keyed by
            ``a2b``/``b2a`` and ``la``/``lb``/``ga``/``gb`` respectively.
        """
        generators = dict(
            a2b=build_model(self.config.model.generator),
            b2a=build_model(self.config.model.generator),
        )
        discriminators = dict(
            la=build_model(self.config.model.local_discriminator),
            lb=build_model(self.config.model.local_discriminator),
            ga=build_model(self.config.model.global_discriminator),
            gb=build_model(self.config.model.global_discriminator),
        )
        self.logger.debug(discriminators["ga"])
        self.logger.debug(generators["a2b"])

        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)

        return generators, discriminators

    def setup_before_d(self):
        """Apply the rho clipper to the generators and unfreeze the discriminators."""
        for generator in self.generators.values():
            generator.apply(self.rho_clipper)
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        """Freeze the discriminators so the generator step does not update them."""
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run all translation, cycle and identity passes for one batch.

        :param batch: mapping with the input images under ``"a"`` and ``"b"``.
        :param inference: when true, gradient tracking is disabled.
        :return: ``dict(images=..., heatmap=..., cam_pred=...)``. Each generator
            call returns three values, stored under ``images``, ``cam_pred``
            and ``heatmap`` in that order. ``cam_pred`` is not kept for the
            cycle passes (``a2b2a``, ``b2a2b``).
        """
        images = dict()
        heatmap = dict()
        cam_pred = dict()
        with torch.set_grad_enabled(not inference):
            # Cross-domain translation.
            images["a2b"], cam_pred["a2b"], heatmap["a2b"] = self.generators["a2b"](batch["a"])
            images["b2a"], cam_pred["b2a"], heatmap["b2a"] = self.generators["b2a"](batch["b"])
            # Cycle reconstruction: translate the fake back to its source domain.
            images["a2b2a"], _, heatmap["a2b2a"] = self.generators["b2a"](images["a2b"])
            images["b2a2b"], _, heatmap["b2a2b"] = self.generators["a2b"](images["b2a"])
            # Identity passes: feed each generator an image that is already
            # in its target domain.
            images["a2a"], cam_pred["a2a"], heatmap["a2a"] = self.generators["b2a"](batch["a"])
            images["b2b"], cam_pred["b2b"], heatmap["b2b"] = self.generators["a2b"](batch["b"])
        return dict(images=images, heatmap=heatmap, cam_pred=cam_pred)

    def criterion_generators(self, batch, generated) -> dict:
        """Compute the weighted generator losses.

        :param batch: mapping with the real images under ``"a"`` and ``"b"``.
        :param generated: the dict returned by :meth:`forward`.
        :return: dict of loss tensors keyed ``cycle_*``, ``id_*``, ``gan_*_*``,
            ``gan_cam_*_*`` and ``cam_*``.
        """
        loss = dict()
        for phase in "ab":
            cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"]
            loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase])
            loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(
                batch[phase], generated["images"][f"{phase}2{phase}"])
            # The fake for domain `phase` is the same for both discriminator
            # levels, so look it up once per phase.
            fake_image = generated["images"]["a2b" if phase == "b" else "b2a"]
            for dk in "lg":
                pred_fake, cam_pred = self.discriminators[dk + phase](fake_image)
                loss[f"gan_{phase}_{dk}"] = self.config.loss.gan.weight * self.gan_loss(pred_fake, True)
                loss[f"gan_cam_{phase}_{dk}"] = self.config.loss.gan.weight * mse_loss(cam_pred, True)
        # Generator CAM loss: the translated pass is labelled 1 and the
        # identity pass of the same generator is labelled 0.
        for translated, identity in (("a2b", "b2b"), ("b2a", "a2a")):
            loss[f"cam_{translated[-1]}"] = self.config.loss.cam.weight * (
                bce_loss(generated["cam_pred"][translated], True)
                + bce_loss(generated["cam_pred"][identity], False))
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Compute the discriminator losses on real and buffered fake images.

        :param batch: mapping with the real images under ``"a"`` and ``"b"``.
        :param generated: the dict returned by :meth:`forward`.
        :return: dict of loss tensors keyed ``gan_{phase}_{level}`` and
            ``cam_{phase}_{level}``.
        """
        loss = dict()
        for phase in "ab":
            # Detach so the discriminator loss cannot back-propagate into the
            # generators and the history buffers never hold an autograd
            # graph. This is a no-op if the tensor is already detached.
            fake_image = generated["images"]["a2b" if phase == "b" else "b2a"].detach()
            for level in "gl":
                name = level + phase
                buffered_fake = self.image_buffers[name].query(fake_image)
                pred_fake, cam_fake_pred = self.discriminators[name](buffered_fake)
                pred_real, cam_real_pred = self.discriminators[name](batch[phase])
                loss[f"gan_{phase}_{level}"] = (
                    self.gan_loss(pred_real, True, is_discriminator=True)
                    + self.gan_loss(pred_fake, False, is_discriminator=True))
                loss[f"cam_{phase}_{level}"] = mse_loss(cam_fake_pred, False) + mse_loss(cam_real_pred, True)
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """Collect the images to visualise for each domain.

        The returned dict must look like
        ``{"a": [img1, img2, ...], "b": [img3, img4, ...]}``.

        :param batch: mapping with the real images under ``"a"`` and ``"b"``.
        :param generated: the dict returned by :meth:`forward`.
        :return: per domain, a list of: the input, the coloured attention map,
            the translation, the identity output and the cycle reconstruction.
        """
        # The last two dimensions of each input are passed to
        # attention_colored_map as the target size.
        attention_a = attention_colored_map(generated["heatmap"]["a2b"].detach(), batch["a"].size()[-2:])
        attention_b = attention_colored_map(generated["heatmap"]["b2a"].detach(), batch["b"].size()[-2:])
        # Use a separate local instead of rebinding the `generated` argument.
        images = {name: image.detach() for name, image in generated["images"].items()}
        return {
            "a": [batch["a"], attention_a, images["a2b"], images["a2a"], images["a2b2a"]],
            "b": [batch["b"], attention_b, images["b2a"], images["b2b"], images["b2a2b"]],
        }
class UGATITTestEngineKernel(TestEngineKernel):
    """Test-time kernel that only builds and runs the ``a2b`` generator."""

    def __init__(self, config):
        """Delegate all setup to ``TestEngineKernel``.

        :param config: experiment configuration, passed through unchanged.
        """
        super().__init__(config)

    def build_generators(self) -> dict:
        """Build the single generator needed for testing.

        :return: dict with one entry, ``"a2b"``.
        """
        return {"a2b": build_model(self.config.model.generator)}

    def to_load(self):
        """Map checkpoint entry names to the generators they restore.

        :return: dict of ``"generator_<key>"`` to generator, one per entry in
            ``self.generators``.
        """
        return {f"generator_{name}": net for name, net in self.generators.items()}

    def inference(self, batch):
        """Translate ``batch["a"]`` with the ``a2b`` generator, gradients off.

        :param batch: mapping with the input images under ``"a"``.
        :return: ``{"a": translated}``; only the first of the generator's
            three outputs is kept.
        """
        a2b = self.generators["a2b"]
        with torch.no_grad():
            translated, _, _ = a2b(batch["a"])
        return {"a": translated.detach()}
def run(task, config, _):
    """Build the kernel for ``task`` and pass it to ``run_kernel``.

    :param task: ``"train"`` or ``"test"``.
    :param config: experiment configuration passed to the kernel and to
        ``run_kernel``.
    :param _: unused third positional argument, kept for the call signature.
    :raises ValueError: if ``task`` is neither ``"train"`` nor ``"test"``.
    """
    if task == "train":
        kernel = UGATITEngineKernel(config)
    elif task == "test":
        kernel = UGATITTestEngineKernel(config)
    else:
        # Without this branch an unknown task reached run_kernel with `kernel`
        # unbound and raised an opaque UnboundLocalError.
        raise ValueError(f"unsupported task {task!r}; expected 'train' or 'test'")
    run_kernel(task, config, kernel)