132 lines
5.0 KiB
Python
132 lines
5.0 KiB
Python
from itertools import chain
|
|
|
|
import ignite.distributed as idist
|
|
import torch
|
|
import torch.nn as nn
|
|
from omegaconf import OmegaConf
|
|
|
|
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
|
|
from engine.util.build import build_model
|
|
from loss.I2I.perceptual_loss import PerceptualLoss
|
|
from loss.gan import GANLoss
|
|
from model.weight_init import generation_init_weights
|
|
|
|
|
|
class TSITEngineKernel(EngineKernel):
    """Training kernel pairing one generator ("main") with one discriminator ("b").

    The generator objective is the sum of a perceptual term, a GAN term and an
    optional feature-matching term; the discriminator objective is the
    averaged real/fake GAN term. Model definitions and loss weights are read
    from ``config``.
    """

    def __init__(self, config):
        """Create the loss modules described under ``config.loss``.

        :param config: configuration object; ``loss.perceptual``, ``loss.gan``
            and ``loss.fm.level`` are read here
        """
        super().__init__(config)

        # "weight" is applied by the criteria below, so it is removed before
        # the remaining keys are forwarded to the loss constructor.
        perceptual_kwargs = OmegaConf.to_container(config.loss.perceptual)
        del perceptual_kwargs["weight"]
        self.perceptual_loss = PerceptualLoss(**perceptual_kwargs).to(idist.device())

        gan_kwargs = OmegaConf.to_container(config.loss.gan)
        del gan_kwargs["weight"]
        self.gan_loss = GANLoss(**gan_kwargs).to(idist.device())

        # Feature matching uses L1 when level == 1 and MSE for any other level.
        if config.loss.fm.level == 1:
            self.fm_loss = nn.L1Loss()
        else:
            self.fm_loss = nn.MSELoss()

    def build_models(self) -> (dict, dict):
        """Instantiate and weight-initialise the generator and discriminator.

        :return: ``(generators, discriminators)`` where generators is keyed by
            ``"main"`` and discriminators by ``"b"``
        """
        generators = {"main": build_model(self.config.model.generator)}
        discriminators = {"b": build_model(self.config.model.discriminator)}

        self.logger.debug(discriminators["b"])
        self.logger.debug(generators["main"])

        # Generators are initialised first, then discriminators.
        for group in (generators, discriminators):
            for network in group.values():
                generation_init_weights(network)

        return generators, discriminators

    def setup_after_g(self):
        """Re-enable gradients on every discriminator."""
        for critic in self.discriminators.values():
            critic.requires_grad_(True)

    def setup_before_g(self):
        """Disable gradients on every discriminator."""
        for critic in self.discriminators.values():
            critic.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run the generator on a batch.

        :param batch: mapping with entries ``"a"`` (passed as ``content_img``)
            and ``"b"`` (passed as ``style_img``)
        :param inference: when true, gradient tracking is disabled for the call
        :return: dict with the generator output under ``"b"``
        """
        generator = self.generators["main"]
        with torch.set_grad_enabled(not inference):
            translated = generator(content_img=batch["a"], style_img=batch["b"])
        return {"b": translated}

    def criterion_generators(self, batch, generated) -> dict:
        """Compute the generator loss terms.

        :param batch: mapping with the input images under ``"a"`` and ``"b"``
        :param generated: mapping with the generator output under ``"b"``
        :return: dict with ``"perceptual"``, ``"gan_b"`` and, when the
            feature-matching weight is positive, ``"fm_b"``
        """
        loss_cfg = self.config.loss

        # Perceptual term compares the output against the "a" input.
        perceptual_value, _ = self.perceptual_loss(generated["b"], batch["a"])
        loss = {"perceptual": perceptual_value * loss_cfg.perceptual.weight}

        for phase in ("b",):
            discriminator = self.discriminators[phase]
            pred_fake = discriminator(generated[phase])

            gan_total = 0
            for scale_outputs in pred_fake:
                # last output is actual prediction
                gan_total += loss_cfg.gan.weight * self.gan_loss(scale_outputs[-1], True)
            loss[f"gan_{phase}"] = gan_total

            if loss_cfg.fm.weight > 0 and phase == "b":
                pred_real = discriminator(batch[phase])
                scale_count = len(pred_fake)
                fm_total = 0
                for fake_outputs, real_outputs in zip(pred_fake, pred_real):
                    # last output is the final prediction, so we exclude it
                    for layer in range(len(fake_outputs) - 1):
                        fm_total += (
                            self.fm_loss(fake_outputs[layer], real_outputs[layer].detach())
                            / scale_count
                        )
                loss[f"fm_{phase}"] = loss_cfg.fm.weight * fm_total

        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Compute the discriminator loss for every registered discriminator.

        :param batch: mapping with real images keyed by phase
        :param generated: mapping with generator outputs keyed by phase; the
            outputs are detached before being scored
        :return: dict with one ``"gan_<phase>"`` entry per discriminator
        """
        loss = {}
        for phase in self.discriminators.keys():
            discriminator = self.discriminators[phase]
            pred_real = discriminator(batch[phase])
            pred_fake = discriminator(generated[phase].detach())

            total = 0
            for fake_outputs, real_outputs in zip(pred_fake, pred_real):
                # Only the last output of each scale is scored; fake and real
                # terms are averaged.
                fake_term = self.gan_loss(fake_outputs[-1], False, is_discriminator=True)
                real_term = self.gan_loss(real_outputs[-1], True, is_discriminator=True)
                total += (fake_term + real_term) / 2
            loss[f"gan_{phase}"] = total

        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """Collect the images to visualise for this step.

        The returned dict must look like
        ``{"a": [img1, img2, ...], "b": [img3, img4, ...]}``.

        :param batch: mapping with the input images under ``"a"`` and ``"b"``
        :param generated: dict of generated images
        :return: dict whose ``"b"`` entry lists the detached ``"a"`` input,
            ``"b"`` input and generated ``"b"`` image, in that order
        """
        panels = [batch["a"], batch["b"], generated["b"]]
        return {"b": [image.detach() for image in panels]}
|
|
|
|
|
|
class TSITTestEngineKernel(TestEngineKernel):
    """Test-time kernel that builds the "main" generator and runs inference."""

    def __init__(self, config):
        """Forward ``config`` to the base test kernel unchanged."""
        super().__init__(config)

    def build_generators(self) -> dict:
        """Build the generator described by ``config.model.generator``.

        :return: dict holding the generator under ``"main"``
        """
        return {"main": build_model(self.config.model.generator)}

    def to_load(self):
        """Map checkpoint keys to the generator objects that receive them.

        :return: dict of ``"generator_<name>"`` to the generator of that name
        """
        return {f"generator_{name}": model for name, model in self.generators.items()}

    def inference(self, batch):
        """Run the generator without gradient tracking.

        :param batch: mapping with entries ``"a"`` and ``"b"``; element ``0``
            of each entry is fed to the generator
        :return: dict with the detached generator output under ``"a"``
        """
        generator = self.generators["main"]
        with torch.no_grad():
            # NOTE(review): index 0 presumably selects the image from an
            # (image, ...) pair; confirm against the test dataset.
            fake = generator(content_img=batch["a"][0], style_img=batch["b"][0])
        return {"a": fake.detach()}
|
|
|
|
|
|
def run(task, config, _):
    """Entry point: build the kernel for ``task`` and hand it to ``run_kernel``.

    :param task: ``"train"`` or ``"test"``
    :param config: configuration object passed to the kernel and to
        ``run_kernel``
    :param _: unused third positional argument, kept for interface
        compatibility
    :raises NotImplementedError: if ``task`` is neither ``"train"`` nor
        ``"test"``
    """
    if task == "train":
        kernel = TSITEngineKernel(config)
        run_kernel(task, config, kernel)
    elif task == "test":
        kernel = TSITTestEngineKernel(config)
        run_kernel(task, config, kernel)
    else:
        # ``NotImplemented`` is a comparison sentinel, not an exception;
        # raising it produces a TypeError. Raise the exception class instead.
        raise NotImplementedError(f"unsupported task: {task!r}")
|