134 lines
5.7 KiB
Python
134 lines
5.7 KiB
Python
from itertools import chain
|
|
from math import ceil
|
|
|
|
from omegaconf import read_write, OmegaConf
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import ignite.distributed as idist
|
|
|
|
import data
|
|
from engine.base.i2i import get_trainer, EngineKernel, build_model
|
|
from model.weight_init import generation_init_weights
|
|
|
|
from loss.I2I.perceptual_loss import PerceptualLoss
|
|
from loss.gan import GANLoss
|
|
|
|
|
|
class TAFGEngineKernel(EngineKernel):
    """Engine kernel that wires up the models and loss terms for the TAFG task.

    It builds one generator (``"main"``), which is called with
    ``which_decoder="a"`` or ``which_decoder="b"``, and two discriminators
    (``"a"`` and ``"b"``) created from the same discriminator config.
    """

    def __init__(self, config, logger):
        """Create the loss modules described by ``config.loss``.

        :param config: OmegaConf configuration; ``config.loss`` must provide the
            ``perceptual``, ``gan``, ``fm`` and ``recon`` sections
        :param logger: logger forwarded to ``EngineKernel``
        """
        super().__init__(config, logger)

        # "weight" is applied when the loss is combined, it is not a constructor
        # argument of the loss module, so drop it before unpacking the config.
        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
        perceptual_loss_cfg.pop("weight")
        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())

        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())

        # level == 1 selects L1, any other value falls back to MSE
        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
        self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()

    def build_models(self) -> (dict, dict):
        """Build and initialise the generator and both discriminators.

        :return: ``(generators, discriminators)`` where ``generators`` has the
            key ``"main"`` and ``discriminators`` has the keys ``"a"`` and ``"b"``
        """
        generators = dict(
            main=build_model(self.config.model.generator)
        )
        discriminators = dict(
            a=build_model(self.config.model.discriminator),
            b=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["a"])
        self.logger.debug(generators["main"])

        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)

        return generators, discriminators

    def setup_before_d(self):
        """Enable gradients on every discriminator."""
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        """Disable gradients on every discriminator."""
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run the generator once per decoder.

        Both outputs use ``batch["edge_a"]`` as content image; the style image
        is ``batch["a"]`` for decoder ``"a"`` and ``batch["b"]`` for decoder ``"b"``.

        :param batch: mapping that provides ``"edge_a"``, ``"a"`` and ``"b"``
        :param inference: when True, gradient tracking is disabled
        :return: dict like: {"a": generated_a, "b": generated_b}
        """
        generator = self.generators["main"]
        with torch.set_grad_enabled(not inference):
            fake = dict(
                a=generator(content_img=batch["edge_a"], style_img=batch["a"], which_decoder="a"),
                b=generator(content_img=batch["edge_a"], style_img=batch["b"], which_decoder="b"),
            )
        return fake

    def criterion_generators(self, batch, generated) -> dict:
        """Compute the generator loss terms.

        :param batch: mapping that provides ``"a"`` and ``"b"``
        :param generated: output of :meth:`forward`
        :return: dict with ``"perceptual"``, one ``"gan_{phase}_sub_{i}"`` entry
            per discriminator output, ``"fm_b"`` (only when
            ``config.loss.fm.weight > 0``) and ``"recon"``
        """
        loss = dict()

        # Unpack first, then scale: multiplying the returned pair by the weight
        # before unpacking is a TypeError when the pair is a plain tuple.
        perceptual, _ = self.perceptual_loss(generated["b"], batch["b"])
        loss["perceptual"] = perceptual * self.config.loss.perceptual.weight

        for phase in "ab":
            pred_fake = self.discriminators[phase](generated[phase])
            # NOTE(review): config.loss.gan.weight is popped in __init__ but is not
            # applied to these terms - confirm that this is intended.
            for i, sub_pred_fake in enumerate(pred_fake):
                # last output is actual prediction
                loss[f"gan_{phase}_sub_{i}"] = self.gan_loss(sub_pred_fake[-1], True)

            if self.config.loss.fm.weight > 0 and phase == "b":
                pred_real = self.discriminators[phase](batch[phase])
                loss_fm = 0
                num_scale_discriminator = len(pred_fake)
                for sub_pred_fake, sub_pred_real in zip(pred_fake, pred_real):
                    # last output is the final prediction, so we exclude it
                    num_intermediate_outputs = len(sub_pred_fake) - 1
                    for j in range(num_intermediate_outputs):
                        # real features are detached so they act as a fixed target
                        loss_fm += self.fm_loss(sub_pred_fake[j], sub_pred_real[j].detach()) / num_scale_discriminator
                loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm

        loss["recon"] = self.recon_loss(generated["a"], batch["a"]) * self.config.loss.recon.weight
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Compute one GAN loss per discriminator.

        For every discriminator output the fake and the real term are averaged,
        and the averages are summed over all outputs.

        :param batch: mapping that provides a real image for every discriminator key
        :param generated: output of :meth:`forward`; it is detached before use
        :return: dict like: {"gan_a": loss_a, "gan_b": loss_b}
        """
        loss = dict()
        for phase, discriminator in self.discriminators.items():
            pred_real = discriminator(batch[phase])
            pred_fake = discriminator(generated[phase].detach())
            loss[f"gan_{phase}"] = 0
            for sub_pred_fake, sub_pred_real in zip(pred_fake, pred_real):
                # last output of each entry is the actual prediction
                loss[f"gan_{phase}"] += (self.gan_loss(sub_pred_fake[-1], False, is_discriminator=True)
                                         + self.gan_loss(sub_pred_real[-1], True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}

        :param batch: mapping that provides ``"edge_a"``, ``"a"`` and ``"b"``
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        return dict(
            # the edge map is expanded to 3 channels along dim 1
            # (presumably so it can be shown next to the images - TODO confirm)
            a=[batch["edge_a"].expand(-1, 3, -1, -1).detach(), batch["a"].detach(), generated["a"].detach()],
            b=[batch["b"].detach(), generated["b"].detach()]
        )
|
|
|
|
|
|
def run(task, config, logger):
    """Entry point of the TAFG engine.

    Only the ``"train"`` task is supported. Training derives
    ``config.max_iteration`` from ``config.max_pairs`` and the train batch
    size, builds the train data loader and the trainer, and runs the trainer
    for enough epochs to cover ``config.max_iteration`` iterations. On rank 0
    the test dataset is also built and stored on ``trainer.state``.

    :param task: name of the task to run; must be ``"train"``
    :param config: OmegaConf configuration of the experiment
    :param logger: logger used for progress messages
    :raises NotImplementedError: if ``task`` is not a supported task
    """
    # Reject unknown tasks up front, before any global flag or the config is touched.
    if task != "train":
        raise NotImplementedError(f"invalid task: {task}")

    assert torch.backends.cudnn.enabled, "cudnn must be enabled"
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")

    # read_write lifts the read-only flag of the config inside the block
    with read_write(config):
        config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)

    train_dataset = data.DATASET.build_with(config.data.train.dataset)
    logger.info(f"train with dataset:\n{train_dataset}")
    train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
    trainer = get_trainer(config, TAFGEngineKernel(config, logger), len(train_data_loader))

    # the test dataset is only attached on the rank-0 process
    if idist.get_rank() == 0:
        test_dataset = data.DATASET.build_with(config.data.test.dataset)
        trainer.state.test_dataset = test_dataset

    try:
        trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
    except Exception:
        # best effort: report the failure and return instead of propagating it
        import traceback
        print(traceback.format_exc())
|