from itertools import chain
from math import ceil

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear

from omegaconf import OmegaConf, read_write

import data
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer
from model.GAN.UGATIT import RhoClipper
from util.image import make_2d_grid, fuse_attention_map
from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_model, build_optimizer


def get_trainer(config, logger):
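    """Build an Ignite Engine that runs one UGATIT-style unpaired translation
    step per batch: two generators (a2b, b2a) plus one local and one global
    discriminator per domain."""
    # Two generators and four discriminators, all initialised with the same
    # generation weight-init scheme.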
    generators = dict(
        a2b=build_model(config.model.generator, config.distributed.model),
        b2a=build_model(config.model.generator, config.distributed.model),
    )
    discriminators = dict(
        la=build_model(config.model.local_discriminator, config.distributed.model),
        lb=build_model(config.model.local_discriminator, config.distributed.model),
        ga=build_model(config.model.global_discriminator, config.distributed.model),
        gb=build_model(config.model.global_discriminator, config.distributed.model),
    )
    for m in chain(generators.values(), discriminators.values()):
        generation_init_weights(m)

    logger.debug(discriminators["ga"])
    logger.debug(generators["a2b"])

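    # One optimizer over all generator parameters and one over all discriminator parameters.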
    optimizer_g = build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator)
    optimizer_d = build_optimizer(chain(*[m.parameters() for m in discriminators.values()]),
                                  config.optimizers.discriminator)

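    # Piecewise-linear LR schedules: hold the base LR for the first start_proportion
    # of training, then decay linearly to target_lr at max_iteration.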
    milestones_values = [
        (0, config.optimizers.generator.lr),
        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)

    milestones_values = [
        (0, config.optimizers.discriminator.lr),
        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)

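    # Loss functions: adversarial GANLoss built from config.loss.gan (minus its weight),
    # plus L1 or MSE cycle-consistency and identity losses (the identity loss reuses the
    # cycle loss level).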
    gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
    gan_loss_cfg.pop("weight")
    gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
    cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
    id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()

    def mse_loss(x, target_flag):
        return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))

    def bce_loss(x, target_flag):
        return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))

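    # A history buffer of generated images for each discriminator, and the UGATIT
    # RhoClipper that keeps the learnable rho parameters of the (Ada)ILN layers in [0, 1].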
    image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()}
    rho_clipper = RhoClipper(0, 1)

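    # Generator objective for one domain: cycle-consistency, identity, CAM classification,
    # and adversarial terms against the (frozen) local and global discriminators.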
    def criterion_generator(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l,
                            discriminator_g):
        discriminator_g.requires_grad_(False)
        discriminator_l.requires_grad_(False)
        pred_fake_g, cam_gd_pred = discriminator_g(fake)
        pred_fake_l, cam_ld_pred = discriminator_l(fake)
        return {
            f"cycle_{name}": config.loss.cycle.weight * cycle_loss(real, rec),
            f"id_{name}": config.loss.id.weight * id_loss(real, identity),
            f"cam_{name}": config.loss.cam.weight * (bce_loss(cam_g_pred, True) + bce_loss(cam_g_id_pred, False)),
            f"gan_l_{name}": config.loss.gan.weight * gan_loss(pred_fake_l, True),
            f"gan_g_{name}": config.loss.gan.weight * gan_loss(pred_fake_g, True),
            f"gan_cam_g_{name}": config.loss.gan.weight * mse_loss(cam_gd_pred, True),
            f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True),
        }

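    # Discriminator objective: real/fake adversarial loss plus the CAM auxiliary loss
    # on the attention logits.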
    def criterion_discriminator(name, discriminator, real, fake):
        pred_real, cam_real = discriminator(real)
        pred_fake, cam_fake = discriminator(fake)
        # TODO: the original implementation does not divide by 2, but halving the loss would be preferable.
        loss_gan = gan_loss(pred_real, True, is_discriminator=True) + gan_loss(pred_fake, False, is_discriminator=True)
        loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False)
        return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam}

    def _step(engine, real):
        real = convert_tensor(real, idist.device())

        fake = dict()
        cam_generator_pred = dict()
        rec = dict()
        identity = dict()
        cam_identity_pred = dict()
        heatmap = dict()

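        # Translation, cycle-reconstruction, and identity passes for both domains;
        # each generator also returns its CAM logits and attention heatmap.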
        fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real["a"])
        fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real["b"])
        rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"])
        rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"])
        identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"])
        identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"])

        optimizer_g.zero_grad()
        loss_g = dict()
        for n in ["a", "b"]:
            loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
                                              cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
        sum(loss_g.values()).backward()
        optimizer_g.step()
        for generator in generators.values():
            generator.apply(rho_clipper)
        for discriminator in discriminators.values():
            discriminator.requires_grad_(True)

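        # Discriminator update: each discriminator sees the current real batch and a
        # detached fake sampled from its image history buffer.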
        optimizer_d.zero_grad()
        loss_d = dict()
        for k in discriminators.keys():
            n = k[-1]  # "a" or "b"
            loss_d.update(
                criterion_discriminator(k, discriminators[k], real[n], image_buffers[k].query(fake[n].detach())))
        sum(loss_d.values()).backward()
        optimizer_d.step()

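        # Detach everything exposed in the step output so the autograd graph is released.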
        for h in heatmap:
            heatmap[h] = heatmap[h].detach()
        generated_img = {f"real_{k}": real[k].detach() for k in real}
        generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
        generated_img.update({f"id_{k}": identity[k].detach() for k in identity})
        generated_img.update({f"rec_{k}": rec[k].detach() for k in rec})

        return {
            "loss": {
                "g": {ln: loss_g[ln].mean().item() for ln in loss_g},
                "d": {ln: loss_d[ln].mean().item() for ln in loss_d},
            },
            "img": {
                "heatmap": heatmap,
                "generated": generated_img
            }
        }

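    # Wire the step into an Ignite Engine, step both LR schedulers every iteration,
    # and track running averages of the summed generator/discriminator losses.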
    trainer = Engine(_step)
    trainer.logger = logger
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)

    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")

    to_save = dict(optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer, lr_scheduler_d=lr_scheduler_d,
                   lr_scheduler_g=lr_scheduler_g)
    to_save.update({f"generator_{k}": generators[k] for k in generators})
    to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})

    setup_common_handlers(trainer, config, to_save=to_save, metrics_to_print=["loss_g", "loss_d"],
                          clear_cuda_cache=False, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))

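    # Flatten the nested loss dictionary into scalar entries so the TensorBoard handler
    # can log each term individually.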
    def output_transform(output):
        loss = dict()
        for tl in output["loss"]:
            if isinstance(output["loss"][tl], dict):
                for l in output["loss"][tl]:
                    loss[f"{tl}_{l}"] = output["loss"][tl][l]
            else:
                loss[tl] = output["loss"][tl]
        return loss

    tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform)
    if tensorboard_handler is not None:
        tensorboard_handler.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
        )

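        # Periodically log training images (real inputs fused with their attention maps)
        # and a fixed random subset of test images to TensorBoard.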
        @trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
        def show_images(engine):
            output = engine.state.output
            image_a_order = ["real_a", "fake_b", "rec_a", "id_a"]
            image_b_order = ["real_b", "fake_a", "rec_b", "id_b"]

            output["img"]["generated"]["real_a"] = fuse_attention_map(
                output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"])
            output["img"]["generated"]["real_b"] = fuse_attention_map(
                output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"])

            tensorboard_handler.writer.add_image(
                "train/a",
                make_2d_grid([output["img"]["generated"][o] for o in image_a_order]),
                engine.state.iteration
            )
            tensorboard_handler.writer.add_image(
                "train/b",
                make_2d_grid([output["img"]["generated"][o] for o in image_b_order]),
                engine.state.iteration
            )

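            # Evaluate on a fixed, seeded random subset of the test set so successive
            # logging points show the same images.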
            with torch.no_grad():
                g = torch.Generator()
                g.manual_seed(config.misc.random_seed)
                indices = torch.randperm(len(engine.state.test_dataset), generator=g).tolist()[:10]

                empty_grid = torch.zeros(0, config.model.generator.in_channels, config.model.generator.img_size,
                                         config.model.generator.img_size)
                fake = dict(a=empty_grid.clone(), b=empty_grid.clone())
                rec = dict(a=empty_grid.clone(), b=empty_grid.clone())
                heatmap = dict(a2b=torch.zeros(0, 1, config.model.generator.img_size,
                                               config.model.generator.img_size),
                               b2a=torch.zeros(0, 1, config.model.generator.img_size,
                                               config.model.generator.img_size))
                real = dict(a=empty_grid.clone(), b=empty_grid.clone())
                for i in indices:
                    batch = convert_tensor(engine.state.test_dataset[i], idist.device())

                    real_a, real_b = batch["a"].view(1, *batch["a"].size()), batch["b"].view(1, *batch["b"].size())
                    fake_b, _, heatmap_a2b = generators["a2b"](real_a)
                    fake_a, _, heatmap_b2a = generators["b2a"](real_b)
                    rec_a = generators["b2a"](fake_b)[0]
                    rec_b = generators["a2b"](fake_a)[0]

                    fake["a"] = torch.cat([fake["a"], fake_a.cpu()])
                    fake["b"] = torch.cat([fake["b"], fake_b.cpu()])
                    real["a"] = torch.cat([real["a"], real_a.cpu()])
                    real["b"] = torch.cat([real["b"], real_b.cpu()])
                    rec["a"] = torch.cat([rec["a"], rec_a.cpu()])
                    rec["b"] = torch.cat([rec["b"], rec_b.cpu()])

                    heatmap["a2b"] = torch.cat(
                        [heatmap["a2b"], torch.nn.functional.interpolate(heatmap_a2b, real_a.size()[-2:]).cpu()])
                    heatmap["b2a"] = torch.cat(
                        [heatmap["b2a"], torch.nn.functional.interpolate(heatmap_b2a, real_b.size()[-2:]).cpu()])
                tensorboard_handler.writer.add_image(
                    "test/a",
                    make_2d_grid([heatmap["a2b"].expand_as(real["a"]), real["a"], fake["b"], rec["a"]]),
                    engine.state.iteration
                )
                tensorboard_handler.writer.add_image(
                    "test/b",
                    make_2d_grid([heatmap["b2a"].expand_as(real["b"]), real["b"], fake["a"], rec["b"]]),
                    engine.state.iteration
                )

    return trainer


def run(task, config, logger):
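    """Entry point: derive the iteration budget from max_pairs and the batch size,
    build datasets, dataloader and trainer, and run training when task == "train"."""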
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    with read_write(config):
        config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)

    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
        trainer = get_trainer(config, logger)
        if idist.get_rank() == 0:
            test_dataset = data.DATASET.build_with(config.data.test.dataset)
            trainer.state.test_dataset = test_dataset
        try:
            trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
        except Exception:
            import traceback
            print(traceback.format_exc())
    else:
        raise NotImplementedError(f"invalid task: {task}")