import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf

from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from engine.util.container import LossContainer
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
from loss.gan import GANLoss
from model.image_translation.UGATIT import RhoClipper
from util.image import attention_colored_map

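# Real/fake targets used for the CAM auxiliary outputs: compare a prediction
# against an all-ones (real) or all-zeros (fake) tensor of the same shape,
# with MSE or BCE-with-logits respectively.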
def mse_loss(x, target_flag):
    return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))


def bce_loss(x, target_flag):
    return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))

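# Training kernel for U-GAT-IT: two generators (a2b, b2a) with CAM attention
# and, per domain, a local and a global discriminator.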
class UGATITEngineKernel(EngineKernel):
    def __init__(self, config):
        super().__init__(config)

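        # Build the loss terms from the config: everything in `loss.gan` except
        # `weight` is forwarded to GANLoss; cycle and identity losses use L1 or
        # MSE depending on their `level`; MGC is the minimal geometry
        # distortion constraint.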
        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
        self.cycle_loss = LossContainer(config.loss.cycle.weight,
                                        nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss())
        self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
        self.id_loss = LossContainer(config.loss.id.weight,
                                     nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss())
        self.rho_clipper = RhoClipper(0, 1)
        self.train_generator_first = False

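    # One generator per translation direction and two discriminators per
    # target domain (local `l*` and global `g*`).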
    def build_models(self) -> (dict, dict):
        generators = dict(
            a2b=build_model(self.config.model.generator),
            b2a=build_model(self.config.model.generator),
        )
        discriminators = dict(
            la=build_model(self.config.model.local_discriminator),
            lb=build_model(self.config.model.local_discriminator),
            ga=build_model(self.config.model.global_discriminator),
            gb=build_model(self.config.model.global_discriminator),
        )
        self.logger.debug(discriminators["ga"])
        self.logger.debug(generators["a2b"])

        return generators, discriminators

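    # Hooks presumably invoked by the base EngineKernel around the generator
    # update: discriminators are frozen before the generator step and
    # re-enabled afterwards, and the UGATIT rho parameters are clipped to
    # [0, 1] after each generator update.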
    def setup_after_g(self):
        for generator in self.generators.values():
            generator.apply(self.rho_clipper)
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

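    # Run all six translations: cross-domain (a2b, b2a), cycle reconstructions
    # (a2b2a, b2a2b) and identity mappings (a2a, b2b), keeping the CAM
    # predictions and attention heatmaps each generator returns.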
    def forward(self, batch, inference=False) -> dict:
        images = dict()
        heatmap = dict()
        cam_pred = dict()

        with torch.set_grad_enabled(not inference):
            images["a2b"], cam_pred["a2b"], heatmap["a2b"] = self.generators["a2b"](batch["a"])
            images["b2a"], cam_pred["b2a"], heatmap["b2a"] = self.generators["b2a"](batch["b"])
            images["a2b2a"], _, heatmap["a2b2a"] = self.generators["b2a"](images["a2b"])
            images["b2a2b"], _, heatmap["b2a2b"] = self.generators["a2b"](images["b2a"])
            images["a2a"], cam_pred["a2a"], heatmap["a2a"] = self.generators["b2a"](batch["a"])
            images["b2b"], cam_pred["b2b"], heatmap["b2b"] = self.generators["a2b"](batch["b"])
        return dict(images=images, heatmap=heatmap, cam_pred=cam_pred)

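    # Generator objective per domain: cycle-consistency, identity and MGC
    # terms plus, for both the local and global discriminator, an adversarial
    # loss on the discriminator output and on its CAM logits, and finally the
    # generators' own CAM classification loss.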
    def criterion_generators(self, batch, generated) -> dict:
        loss = dict()
        for phase in "ab":
            cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"]
            loss[f"cycle_{phase}"] = self.cycle_loss(cycle_image, batch[phase])
            loss[f"id_{phase}"] = self.id_loss(batch[phase], generated["images"][f"{phase}2{phase}"])
            loss[f"mgc_{phase}"] = self.mgc_loss(batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
            for dk in "lg":
                generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
                pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)
                loss[f"gan_{phase}_{dk}"] = self.config.loss.gan.weight * self.gan_loss(pred_fake, True)
                loss[f"gan_cam_{phase}_{dk}"] = self.config.loss.gan.weight * mse_loss(cam_pred, True)
        for t, f in [("a2b", "b2b"), ("b2a", "a2a")]:
            loss[f"cam_{t[-1]}"] = self.config.loss.cam.weight * (
                    bce_loss(generated["cam_pred"][t], True) + bce_loss(generated["cam_pred"][f], False))
        return loss

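    # Discriminator objective per domain and per scale: real/fake adversarial
    # loss on detached fakes plus an MSE loss on the discriminator CAM logits.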
    def criterion_discriminators(self, batch, generated) -> dict:
        loss = dict()
        for phase in "ab":
            for level in "gl":
                generated_image = generated["images"]["b2a" if phase == "a" else "a2b"].detach()
                pred_fake, cam_fake_pred = self.discriminators[level + phase](generated_image)
                pred_real, cam_real_pred = self.discriminators[level + phase](batch[phase])
                loss[f"gan_{phase}_{level}"] = self.gan_loss(pred_real, True, is_discriminator=True) + self.gan_loss(
                    pred_fake, False, is_discriminator=True)
                loss[f"cam_{phase}_{level}"] = mse_loss(cam_fake_pred, False) + mse_loss(cam_real_pred, True)
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        Collect images to visualize: for each domain, the input, the colored
        attention map, and the translated, identity and cycle-reconstructed
        images.

        :param batch: input batch with real images under "a" and "b"
        :param generated: output of ``forward`` (images, heatmaps, CAM predictions)
        :return: dict like {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        attention_a = attention_colored_map(generated["heatmap"]["a2b"].detach(), batch["a"].size()[-2:])
        attention_b = attention_colored_map(generated["heatmap"]["b2a"].detach(), batch["b"].size()[-2:])
        generated = {img: generated["images"][img].detach() for img in generated["images"]}
        return {
            "a": [batch["a"], attention_a, generated["a2b"], generated["a2a"], generated["a2b2a"]],
            "b": [batch["b"], attention_b, generated["b2a"], generated["b2b"], generated["b2a2b"]],
        }

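# Inference-only kernel: loads just the a2b generator and translates images
# from domain A to domain B.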
class UGATITTestEngineKernel(TestEngineKernel):
    def __init__(self, config):
        super().__init__(config)

    def build_generators(self) -> dict:
        generators = dict(
            a2b=build_model(self.config.model.generator),
        )
        return generators

    def to_load(self):
        return {f"generator_{k}": self.generators[k] for k in self.generators}

    def inference(self, batch):
        with torch.no_grad():
            fake, _, _ = self.generators["a2b"](batch[0])
        return fake.detach()

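# Entry point: dispatch to the training or test kernel depending on the
# requested task.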
def run(task, config, _):
    if task == "train":
        kernel = UGATITEngineKernel(config)
        run_kernel(task, config, kernel)
    elif task == "test":
        kernel = UGATITTestEngineKernel(config)
        run_kernel(task, config, kernel)
    else:
        raise NotImplementedError(f"unknown task: {task}")
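
# A rough sketch of the config keys this kernel reads (illustrative only; the
# concrete values, the extra GANLoss options and the build_model entries are
# assumptions, not taken from this project's actual configs):
#
#   loss:
#     gan:   {weight: 1.0, ...}    # remaining keys are passed to GANLoss(**...)
#     cycle: {weight: 10.0, level: 1}
#     id:    {weight: 10.0, level: 1}
#     mgc:   {weight: 1.0}
#     cam:   {weight: 1000.0}
#   model:
#     generator: ...
#     local_discriminator: ...
#     global_discriminator: ...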