# raycv/engine/TAFG.py

from itertools import chain
from typing import Tuple

import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf

from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from model.weight_init import generation_init_weights


class TAFGEngineKernel(EngineKernel):
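    """Engine kernel for the TAFG image-to-image translation task.

    One shared generator handles two domains, "a" and "b", each with its own
    discriminator. Until ``config.misc.add_new_loss_epoch`` only within-domain
    reconstruction (a2a, b2b) is trained; after that epoch the style modules
    are frozen and the cross-domain a2b path is trained with GAN, content/style
    reconstruction, perceptual, cycle, and HED edge losses.
    """
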
    def __init__(self, config):
        super().__init__(config)
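        # Each loss block in the config carries a "weight" entry that is
        # applied in the criterion methods below rather than inside the loss
        # module, so it is popped before the module is instantiated.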
        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
        perceptual_loss_cfg.pop("weight")
        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())

        style_loss_cfg = OmegaConf.to_container(config.loss.style)
        style_loss_cfg.pop("weight")
        self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())

        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())

        self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
        self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
        self.content_recon_loss = nn.L1Loss() if config.loss.content_recon.level == 1 else nn.MSELoss()
        self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()

        self.edge_loss = EdgeLoss(
            "HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path
        ).to(idist.device())

    def _process_batch(self, batch, inference=False):
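        # Passthrough hook; the commented line below preserves an earlier
        # variant that broadcast a single "b" sample across the whole batch.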
# batch["b"] = batch["b"] if inference else batch["b"][0].expand(batch["a"].size())
return batch

    def build_models(self) -> Tuple[dict, dict]:
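        """Build the shared generator and one discriminator per domain, then
        apply generation_init_weights to all of them."""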
        generators = dict(
            main=build_model(self.config.model.generator)
        )
        discriminators = dict(
            a=build_model(self.config.model.discriminator),
            b=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["a"])
        self.logger.debug(generators["main"])
        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)
        return generators, discriminators
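
    # These hooks are assumed to be invoked by the base EngineKernel around
    # the generator step: setup_before_g() switches discriminator gradients
    # off for the generator update and setup_after_g() restores them, avoiding
    # wasted gradient computation during GAN training.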
    def setup_after_g(self):
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
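        """Encode both domains and decode the training images.

        Always produces the within-domain reconstructions a2a and b2b; after
        ``misc.add_new_loss_epoch`` it also decodes a2b from a randomly
        sampled style code and builds the cycle reconstructions from the
        re-encoded content/style codes.

        :return: dict with "styles", "contents", and "images" sub-dicts.
        """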
generator = self.generators["main"]
batch = self._process_batch(batch, inference)
styles = dict()
contents = dict()
images = dict()
with torch.set_grad_enabled(not inference):
contents["a"], styles["a"] = generator.encode(batch["a"]["edge"], batch["a"]["img"], "a", "a")
contents["b"], styles["b"] = generator.encode(batch["b"]["edge"], batch["b"]["img"], "b", "b")
for ph in "ab":
images[f"{ph}2{ph}"] = generator.decode(contents[ph], styles[ph], ph)
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
styles[f"random_b"] = torch.randn_like(styles["b"]).to(idist.device())
images["a2b"] = generator.decode(contents["a"], styles["random_b"], "b")
contents["recon_a"], styles["recon_b"] = generator.encode(self.edge_loss.edge_extractor(images["a2b"]),
images["a2b"], "b", "b")
images["cycle_b"] = generator.decode(contents["b"], styles["recon_b"], "b")
images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
return dict(styles=styles, contents=contents, images=images)

    def criterion_generators(self, batch, generated) -> dict:
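        """Generator-side losses, each already scaled by its configured weight.

        The discriminators are assumed to be multi-scale: each returns a list
        of per-scale outputs whose last element is the actual prediction map.
        """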
        batch = self._process_batch(batch)
        loss = dict()
        for ph in "ab":
            loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
                generated["images"][f"{ph}2{ph}"], batch[ph]["img"])
            pred_fake = self.discriminators[ph](generated["images"][f"{ph}2{ph}"])
            loss[f"gan_{ph}"] = 0
            for sub_pred_fake in pred_fake:
                # The last output of each scale is the actual prediction.
                loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
        if self.engine.state.epoch == self.config.misc.add_new_loss_epoch:
            # At the transition epoch, freeze the style modules so the new
            # cross-domain losses train the rest of the generator against a
            # fixed style space.
            self.generators["main"].style_converters.requires_grad_(False)
            self.generators["main"].style_encoders.requires_grad_(False)
        if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
            # The b-domain discriminator judges the translated a2b images.
            pred_fake = self.discriminators["b"](generated["images"]["a2b"])
            loss["gan_a2b"] = 0
            for sub_pred_fake in pred_fake:
                loss["gan_a2b"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
            loss["recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
                generated["contents"]["a"], generated["contents"]["recon_a"]
            )
            loss["recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
                generated["styles"]["random_b"], generated["styles"]["recon_b"]
            )
            if self.config.loss.perceptual.weight > 0:
                loss["perceptual_a"] = self.config.loss.perceptual.weight * self.perceptual_loss(
                    batch["a"]["img"], generated["images"]["a2b"]
                )
            if self.config.loss.cycle.weight > 0:
                loss["cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
                    batch["a"]["img"], generated["images"]["cycle_a"]
                )
            # for ph in "ab":
            #     if self.config.loss.style.weight > 0:
            #         loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss(
            #             batch[ph]["img"], generated["images"][f"a2{ph}"]
            #         )
            if self.config.loss.edge.weight > 0:
                loss["edge_a"] = self.config.loss.edge.weight * self.edge_loss(
                    generated["images"]["a2b"], batch["a"]["edge"][:, 0:1, :, :]
                )
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
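        """Discriminator losses on real images vs. detached fakes.

        Fakes are detached so only the discriminators receive gradients; after
        the transition epoch the translated a2b images are judged as well.
        """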
        loss = dict()
        if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
            # After the transition epoch each discriminator also sees the
            # detached a2b fakes, so the mean runs over three terms.
            for phase in self.discriminators.keys():
                pred_real = self.discriminators[phase](batch[phase]["img"])
                pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach())
                pred_fake_2 = self.discriminators[phase](generated["images"]["a2b"].detach())
                loss[f"gan_{phase}"] = 0
                for i in range(len(pred_fake)):
                    loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
                                             + self.gan_loss(pred_fake_2[i][-1], False, is_discriminator=True)
                                             + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 3
        else:
            for phase in self.discriminators.keys():
                pred_real = self.discriminators[phase](batch[phase]["img"])
                pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach())
                loss[f"gan_{phase}"] = 0
                for i in range(len(pred_fake)):
                    loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
                                             + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """Collect rows of images for visualization.

        :param batch: input batch with "img" and "edge" tensors per domain
        :param generated: dict returned by :meth:`forward`
        :return: dict like {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        batch = self._process_batch(batch)
        if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
            return dict(
                a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
                   batch["a"]["img"].detach(),
                   generated["images"]["a2a"].detach(),
                   generated["images"]["a2b"].detach(),
                   generated["images"]["cycle_a"].detach(),
                   ],
                b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
                   batch["b"]["img"].detach(),
                   generated["images"]["b2b"].detach(),
                   generated["images"]["cycle_b"].detach()]
            )
        else:
            return dict(
                a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
                   batch["a"]["img"].detach(),
                   generated["images"]["a2a"].detach(),
                   ],
                b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
                   batch["b"]["img"].detach(),
                   generated["images"]["b2b"].detach(),
                   ]
            )

    def change_engine(self, config, trainer):
        pass
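        # Hook for mutating the config/engine mid-training; the commented
        # example below would raise the perceptual weight a third of the way
        # through training (it assumes ignite's Events and omegaconf's
        # read_write are imported).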
        # @trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
        # def change_config(engine):
        #     with read_write(config):
        #         config.loss.perceptual.weight = 5


def run(task, config, _):
    kernel = TAFGEngineKernel(config)
    run_kernel(task, config, kernel)
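

# A rough sketch of the config shape this kernel consumes (illustrative only;
# the exact fields under model.* and the extra kwargs forwarded to
# PerceptualLoss/GANLoss depend on the rest of the repo):
#
#   model:
#     generator: {...}
#     discriminator: {...}
#   loss:
#     recon: {level: 1 or 2, weight: ...}
#     style_recon: {level: 1 or 2, weight: ...}
#     content_recon: {level: 1 or 2, weight: ...}
#     cycle: {level: 1 or 2, weight: ...}
#     perceptual: {weight: ..., ...}
#     style: {weight: ..., ...}
#     gan: {weight: ..., ...}
#     edge: {weight: ..., hed_pretrained_model_path: ...}
#   misc:
#     add_new_loss_epoch: ...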