"""Ignite engines for training and testing a CycleGAN-style model.

Two generators (``generator_a`` maps domain A to B, ``generator_b`` maps B to A) and two
discriminators are trained with adversarial, cycle-consistency and optional identity losses.
"""
import itertools
import traceback
from pathlib import Path

import torch
import torch.nn as nn
import torchvision.utils

import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
from omegaconf import OmegaConf

import data
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer
from util.image import make_2d_grid
from util.handler import setup_common_handlers
from util.build import build_model, build_optimizer

# Size passed to GANImageBuffer when ``config.data.train.buffer_size`` is None.
_DEFAULT_BUFFER_SIZE = 50


def _reconstruction_loss(level):
    """Return ``nn.L1Loss`` when ``level == 1``, otherwise ``nn.MSELoss``."""
    return nn.L1Loss() if level == 1 else nn.MSELoss()


def _linear_decay_scheduler(optimizer, initial_lr, target_lr, decay_start=100, decay_end=200):
    """Build an LR scheduler that holds ``initial_lr`` and then decays it linearly.

    The LR stays at ``initial_lr`` until scheduler step ``decay_start``. It then decays
    linearly to ``target_lr`` at step ``decay_end``. ``get_trainer`` calls the scheduler
    once per completed epoch.
    """
    return PiecewiseLinear(
        optimizer,
        param_name="lr",
        milestones_values=[(0, initial_lr), (decay_start, initial_lr), (decay_end, target_lr)],
    )


def get_trainer(config, logger):
    """Build the Ignite training engine for the two generator/discriminator pairs.

    Args:
        config: OmegaConf config. Reads ``model.generator``, ``model.discriminator``,
            ``distributed.model``, ``optimizers.{generator,discriminator}``,
            ``data.train.scheduler.target_lr``, ``data.train.buffer_size``,
            ``loss.{gan,cycle,id}``, ``output_dir``, ``resume_from``, ``name``,
            ``checkpoints.interval`` and ``max_iteration``.
        logger: logger assigned to the engine; also used to print two of the models.

    Returns:
        An ``Engine`` whose step output is ``{"loss": {"g": ..., "d_a": ..., "d_b": ...},
        "img": [real_a, fake_b, rec_a, real_b, fake_a, rec_b]}``.
    """
    generator_a = build_model(config.model.generator, config.distributed.model)
    generator_b = build_model(config.model.generator, config.distributed.model)
    discriminator_a = build_model(config.model.discriminator, config.distributed.model)
    discriminator_b = build_model(config.model.discriminator, config.distributed.model)
    for m in [generator_b, generator_a, discriminator_b, discriminator_a]:
        generation_init_weights(m)

    logger.info(discriminator_a)
    logger.info(generator_a)

    optimizer_g = build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
                                  config.optimizers.generator)
    optimizer_d = build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
                                  config.optimizers.discriminator)

    target_lr = config.data.train.scheduler.target_lr
    lr_scheduler_g = _linear_decay_scheduler(optimizer_g, config.optimizers.generator.lr, target_lr)
    lr_scheduler_d = _linear_decay_scheduler(optimizer_d, config.optimizers.discriminator.lr, target_lr)

    gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
    # "weight" is applied manually in _step; it is not a GANLoss constructor argument.
    gan_loss_cfg.pop("weight")
    gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
    cycle_loss = _reconstruction_loss(config.loss.cycle.level)
    # The identity loss uses its own ``loss.id.level`` when the config defines one. Otherwise it
    # falls back to ``loss.cycle.level``, which the identity loss always used before.
    id_level = config.loss.id.level if "level" in config.loss.id else config.loss.cycle.level
    id_loss = _reconstruction_loss(id_level)

    buffer_size = config.data.train.buffer_size
    if buffer_size is None:
        buffer_size = _DEFAULT_BUFFER_SIZE
    image_buffer_a = GANImageBuffer(buffer_size)
    image_buffer_b = GANImageBuffer(buffer_size)

    def _step(engine, batch):
        """Run one generator update followed by one discriminator update."""
        batch = convert_tensor(batch, idist.device())
        real_a, real_b = batch["a"], batch["b"]

        fake_b = generator_a(real_a)  # G_A(A)
        rec_a = generator_b(fake_b)  # G_B(G_A(A))
        fake_a = generator_b(real_b)  # G_B(B)
        rec_b = generator_a(fake_a)  # G_A(G_B(B))

        # Generator update. The discriminators are frozen so backward does not accumulate
        # gradients into their parameters.
        optimizer_g.zero_grad()
        discriminator_a.requires_grad_(False)
        discriminator_b.requires_grad_(False)
        loss_g = dict(
            cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a),
            cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b),
            gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True),
            gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True),
        )
        if config.loss.id.weight > 0:
            # Plain tensors here: a trailing comma would store a 1-tuple and break
            # ``sum(loss_g.values())`` as well as ``.mean()`` in the returned dict.
            loss_g["id_a"] = config.loss.id.weight * id_loss(generator_a(real_b), real_b)  # G_A(B)
            loss_g["id_b"] = config.loss.id.weight * id_loss(generator_b(real_a), real_a)  # G_B(A)
        sum(loss_g.values()).backward()
        optimizer_g.step()

        # Discriminator update on real images and on detached fakes drawn through the buffers.
        discriminator_a.requires_grad_(True)
        discriminator_b.requires_grad_(True)
        optimizer_d.zero_grad()
        loss_d_a = dict(
            real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
            fake=gan_loss(discriminator_a(image_buffer_a.query(fake_b.detach())), False, is_discriminator=True),
        )
        loss_d_b = dict(
            real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
            fake=gan_loss(discriminator_b(image_buffer_b.query(fake_a.detach())), False, is_discriminator=True),
        )
        (sum(loss_d_a.values()) * 0.5).backward()
        (sum(loss_d_b.values()) * 0.5).backward()
        optimizer_d.step()

        return {
            "loss": {
                "g": {ln: loss_g[ln].mean().item() for ln in loss_g},
                "d_a": {ln: loss_d_a[ln].mean().item() for ln in loss_d_a},
                "d_b": {ln: loss_d_b[ln].mean().item() for ln in loss_d_b},
            },
            "img": [
                real_a.detach(),
                fake_b.detach(),
                rec_a.detach(),
                real_b.detach(),
                fake_a.detach(),
                rec_b.detach(),
            ],
        }

    trainer = Engine(_step)
    trainer.logger = logger
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_g)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_d)

    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b")

    to_save = dict(
        generator_a=generator_a,
        generator_b=generator_b,
        discriminator_a=discriminator_a,
        discriminator_b=discriminator_b,
        optimizer_d=optimizer_d,
        optimizer_g=optimizer_g,
        trainer=trainer,
        lr_scheduler_d=lr_scheduler_d,
        lr_scheduler_g=lr_scheduler_g,
    )

    setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
                          filename_prefix=config.name, to_save=to_save,
                          print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
                          metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"],
                          save_interval_event=Events.ITERATION_COMPLETED(
                              every=config.checkpoints.interval) | Events.COMPLETED)

    @trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
    def terminate(engine):
        # Stop exactly at ``max_iteration``; ``run`` requests enough epochs to get there.
        engine.terminate()

    if idist.get_rank() == 0:
        # TensorBoard logging is done from rank 0 only.
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_writer = tb_logger.writer

        def global_step_transform(*args, **kwargs):
            return trainer.state.iteration

        def output_transform(output):
            """Flatten ``output["loss"]`` into ``{"<group>_<name>": value}`` for TensorBoard."""
            loss = dict()
            for group in output["loss"]:
                if isinstance(output["loss"][group], dict):
                    for name in output["loss"][group]:
                        loss[f"{group}_{name}"] = output["loss"][group][name]
                else:
                    loss[group] = output["loss"][group]
            return loss

        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="loss",
                metric_names=["loss_g", "loss_d_a", "loss_d_b"],
                global_step_transform=global_step_transform,
                output_transform=output_transform,
            ),
            event_name=Events.ITERATION_COMPLETED(every=50),
        )
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(every=50),
        )

        @trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
        def show_images(engine):
            tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration)

        @trainer.on(Events.COMPLETED)
        @idist.one_rank_only()
        def close_tb_logger():
            # Close the TensorBoard logger once training is finished.
            tb_logger.close()

    return trainer


def get_tester(config, logger):
    """Build an Ignite engine that runs both generators and saves image grids.

    Generator weights are restored via ``setup_common_handlers(..., resume_from=config.resume_from)``.
    For every sample in a batch, one image is written to
    ``<output_dir>/test_images/<iteration>_<index>.jpg``. Its row is
    ``real_a, fake_b, rec_a, real_b, fake_a, rec_b``.

    Args:
        config: OmegaConf config. Reads ``model.generator``, ``distributed.model``,
            ``resume_from`` and ``output_dir``.
        logger: logger assigned to the engine.

    Returns:
        The test ``Engine``.
    """
    generator_a = build_model(config.model.generator, config.distributed.model)
    generator_b = build_model(config.model.generator, config.distributed.model)

    def _step(engine, batch):
        batch = convert_tensor(batch, idist.device())
        real_a, real_b = batch["a"], batch["b"]
        with torch.no_grad():
            fake_b = generator_a(real_a)  # G_A(A)
            rec_a = generator_b(fake_b)  # G_B(G_A(A))
            fake_a = generator_b(real_b)  # G_B(B)
            rec_b = generator_a(fake_a)  # G_A(G_B(B))
        return [
            real_a.detach(),
            fake_b.detach(),
            rec_a.detach(),
            real_b.detach(),
            fake_a.detach(),
            rec_b.detach(),
        ]

    tester = Engine(_step)
    tester.logger = logger
    # ``get_rank`` must be called; comparing the function object itself to 0 is always False.
    if idist.get_rank() == 0:
        ProgressBar(ncols=0).attach(tester)
    to_load = dict(generator_a=generator_a, generator_b=generator_b)
    setup_common_handlers(tester, use_profiler=False, to_save=to_load, resume_from=config.resume_from)

    @tester.on(Events.STARTED)
    @idist.one_rank_only()
    def mkdir(engine):
        img_output_dir = Path(config.output_dir) / "test_images"
        if not img_output_dir.exists():
            engine.logger.info(f"mkdir {img_output_dir}")
            img_output_dir.mkdir()

    @tester.on(Events.ITERATION_COMPLETED)
    def save_images(engine):
        img_tensors = engine.state.output
        batch_size = img_tensors[0].size(0)
        for i in range(batch_size):
            torchvision.utils.save_image([img[i] for img in img_tensors],
                                         Path(config.output_dir) / f"test_images/{engine.state.iteration}_{i}.jpg",
                                         nrow=len(img_tensors))

    return tester


def run(task, config, logger):
    """Run ``task`` ("train" or "test") with the given config.

    Exceptions raised while an engine runs are printed as a traceback and not re-raised.

    Raises:
        NotImplementedError: if ``task`` is neither "train" nor "test".
    """
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
        trainer = get_trainer(config, logger)
        try:
            # One extra epoch so ``max_iteration`` is always reached; the trainer's
            # ``terminate`` handler stops exactly at that iteration.
            trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
        except Exception:
            print(traceback.format_exc())
    elif task == "test":
        assert config.resume_from is not None
        test_dataset = data.DATASET.build_with(config.data.test.dataset)
        logger.info(f"test with dataset:\n{test_dataset}")
        test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
        tester = get_tester(config, logger)
        try:
            tester.run(test_data_loader, max_epochs=1)
        except Exception:
            print(traceback.format_exc())
    else:
        # ``NotImplemented`` is a constant, not an exception; calling it raised a TypeError.
        raise NotImplementedError(f"invalid task: {task}")