from itertools import chain
from math import ceil

import torch
import torch.nn as nn
import torch.nn.functional as F

import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from omegaconf import OmegaConf, read_write

import data
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer
from model.GAN.UGATIT import RhoClipper
from util.image import make_2d_grid, fuse_attention_map, attention_colored_map
from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_model, build_optimizer


def build_lr_schedulers(optimizers, config):
    g_milestones_values = [
        (0, config.optimizers.generator.lr),
        (int(config.data.train.scheduler.start_proportion * config.max_iteration),
         config.optimizers.generator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    d_milestones_values = [
        (0, config.optimizers.discriminator.lr),
        (int(config.data.train.scheduler.start_proportion * config.max_iteration),
         config.optimizers.discriminator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    return dict(
        g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
        d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
    )


def get_trainer(config, logger):
    generators = dict(
        a2b=build_model(config.model.generator, config.distributed.model),
        b2a=build_model(config.model.generator, config.distributed.model),
    )
    discriminators = dict(
        la=build_model(config.model.local_discriminator, config.distributed.model),
        lb=build_model(config.model.local_discriminator, config.distributed.model),
        ga=build_model(config.model.global_discriminator, config.distributed.model),
        gb=build_model(config.model.global_discriminator, config.distributed.model),
    )
    for m in chain(generators.values(), discriminators.values()):
        generation_init_weights(m)
    logger.debug(discriminators["ga"])
    logger.debug(generators["a2b"])

    optimizers = dict(
        g=build_optimizer(chain(*[m.parameters() for m in generators.values()]),
                          config.optimizers.generator),
        d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]),
                          config.optimizers.discriminator),
    )
    logger.info(f"build optimizers:\n{optimizers}")
    lr_schedulers = build_lr_schedulers(optimizers, config)
    logger.info(f"build lr_schedulers:\n{lr_schedulers}")

    gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
    gan_loss_cfg.pop("weight")
    gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
    cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
    id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()

    def mse_loss(x, target_flag):
        return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))

    def bce_loss(x, target_flag):
        return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))

    image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()}
    rho_clipper = RhoClipper(0, 1)

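    # The generator objective below follows the UGATIT-style recipe visible in this
    # file: a cycle-consistency term, an identity term, a CAM (class activation map)
    # term on the generator's own attention logits, and adversarial + CAM terms from
    # both the local and global discriminators, each scaled by its configured weight.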
    def criterion_generator(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred,
                            discriminator_l, discriminator_g):
        # Freeze both discriminators so the generator step does not accumulate
        # gradients into them.
        discriminator_g.requires_grad_(False)
        discriminator_l.requires_grad_(False)
        pred_fake_g, cam_gd_pred = discriminator_g(fake)
        pred_fake_l, cam_ld_pred = discriminator_l(fake)
        return {
            f"cycle_{name}": config.loss.cycle.weight * cycle_loss(real, rec),
            f"id_{name}": config.loss.id.weight * id_loss(real, identity),
            f"cam_{name}": config.loss.cam.weight * (bce_loss(cam_g_pred, True) + bce_loss(cam_g_id_pred, False)),
            f"gan_l_{name}": config.loss.gan.weight * gan_loss(pred_fake_l, True),
            f"gan_g_{name}": config.loss.gan.weight * gan_loss(pred_fake_g, True),
            f"gan_cam_g_{name}": config.loss.gan.weight * mse_loss(cam_gd_pred, True),
            f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True),
        }

    def criterion_discriminator(name, discriminator, real, fake):
        pred_real, cam_real = discriminator(real)
        pred_fake, cam_fake = discriminator(fake)
        # TODO: the original implementation does not divide by 2, but halving the
        # discriminator loss may be preferable.
        loss_gan = gan_loss(pred_real, True, is_discriminator=True) + gan_loss(pred_fake, False, is_discriminator=True)
        loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False)
        return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam}

    def _step(engine, real):
        real = convert_tensor(real, idist.device())
        fake = dict()
        cam_generator_pred = dict()
        rec = dict()
        identity = dict()
        cam_identity_pred = dict()
        heatmap = dict()
        fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real["a"])
        fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real["b"])
        rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"])
        rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"])
        identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"])
        identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"])

        optimizers["g"].zero_grad()
        loss_g = dict()
        for n in ["a", "b"]:
            loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n],
                                              cam_generator_pred[n], cam_identity_pred[n],
                                              discriminators["l" + n], discriminators["g" + n]))
        sum(loss_g.values()).backward()
        optimizers["g"].step()
        for generator in generators.values():
            generator.apply(rho_clipper)

        for discriminator in discriminators.values():
            discriminator.requires_grad_(True)
        optimizers["d"].zero_grad()
        loss_d = dict()
        for k in discriminators.keys():
            n = k[-1]  # "a" or "b"
            loss_d.update(
                criterion_discriminator(k, discriminators[k], real[n],
                                        image_buffers[k].query(fake[n].detach())))
        sum(loss_d.values()).backward()
        optimizers["d"].step()

        for h in heatmap:
            heatmap[h] = heatmap[h].detach()
        generated_img = {f"real_{k}": real[k].detach() for k in real}
        generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
        generated_img.update({f"id_{k}": identity[k].detach() for k in identity})
        generated_img.update({f"rec_{k}": rec[k].detach() for k in rec})
        return {
            "loss": {
                "g": {ln: loss_g[ln].mean().item() for ln in loss_g},
                "d": {ln: loss_d[ln].mean().item() for ln in loss_d},
            },
            "img": {
                "heatmap": heatmap,
                "generated": generated_img
            }
        }

    trainer = Engine(_step)
    trainer.logger = logger
    for lr_shd in lr_schedulers.values():
        trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")

    to_save = dict(trainer=trainer)
    to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
    to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
    to_save.update({f"generator_{k}": generators[k] for k in generators})
    to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
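    # `to_save` above bundles every stateful object (engine, schedulers, optimizers,
    # generators, discriminators); setup_common_handlers below presumably checkpoints
    # them together so a run can be resumed from a single snapshot.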
    setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True,
                          end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))

    def output_transform(output):
        loss = dict()
        for tl in output["loss"]:
            if isinstance(output["loss"][tl], dict):
                for l in output["loss"][tl]:
                    loss[f"{tl}_{l}"] = output["loss"][tl][l]
            else:
                loss[tl] = output["loss"][tl]
        return loss

    tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform)
    if tensorboard_handler is not None:
        tensorboard_handler.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
        )

    @trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
    def show_images(engine):
        # Image logging is only possible when a TensorBoard handler exists
        # (e.g. on rank 0, where engine.state.test_dataset is also set).
        if tensorboard_handler is None:
            return
        output = engine.state.output
        image_order = dict(
            a=["real_a", "fake_b", "rec_a", "id_a"],
            b=["real_b", "fake_a", "rec_b", "id_b"]
        )
        output["img"]["generated"]["real_a"] = fuse_attention_map(
            output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"])
        output["img"]["generated"]["real_b"] = fuse_attention_map(
            output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"])
        for k in "ab":
            tensorboard_handler.writer.add_image(
                f"train/{k}",
                make_2d_grid([output["img"]["generated"][o] for o in image_order[k]]),
                engine.state.iteration
            )
        with torch.no_grad():
            g = torch.Generator()
            g.manual_seed(config.misc.random_seed)
            indices = torch.randperm(len(engine.state.test_dataset), generator=g).tolist()[:10]
            test_images = dict(
                a=[[], [], [], []],
                b=[[], [], [], []]
            )
            for i in indices:
                batch = convert_tensor(engine.state.test_dataset[i], idist.device())
                # Add a batch dimension to each single test sample.
                real_a = batch["a"].view(1, *batch["a"].size())
                real_b = batch["b"].view(1, *batch["b"].size())
                fake_b, _, heatmap_a2b = generators["a2b"](real_a)
                fake_a, _, heatmap_b2a = generators["b2a"](real_b)
                rec_a = generators["b2a"](fake_b)[0]
                rec_b = generators["a2b"](fake_a)[0]
                for idx, im in enumerate(
                        [attention_colored_map(heatmap_a2b, real_a.size()[-2:]), real_a, fake_b, rec_a]):
                    test_images["a"][idx].append(im.cpu())
                for idx, im in enumerate(
                        [attention_colored_map(heatmap_b2a, real_b.size()[-2:]), real_b, fake_a, rec_b]):
                    test_images["b"][idx].append(im.cpu())
            for n in "ab":
                tensorboard_handler.writer.add_image(
                    f"test/{n}",
                    make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
                    engine.state.iteration
                )

    return trainer


def run(task, config, logger):
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    with read_write(config):
        config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
        trainer = get_trainer(config, logger)
        if idist.get_rank() == 0:
            test_dataset = data.DATASET.build_with(config.data.test.dataset)
            trainer.state.test_dataset = test_dataset
        try:
            trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
        except Exception:
            import traceback
            print(traceback.format_exc())
    else:
        raise NotImplementedError(f"invalid task: {task}")
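
# A minimal usage sketch (an assumption, not part of the original script): it
# presumes a YAML config providing the keys referenced above (model.*, loss.*,
# optimizers.*, data.*, interval.*, misc.random_seed, max_pairs,
# distributed.model); the file name "config.yaml" and the logger name are
# hypothetical.
#
#     import logging
#     logging.basicConfig(level=logging.INFO)
#     config = OmegaConf.load("config.yaml")
#     run("train", config, logging.getLogger("UGATIT"))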