import itertools
import traceback
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils

import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from ignite.metrics import RunningAverage
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler

from omegaconf import OmegaConf

import data
from model import MODEL
from loss.gan import GANLoss
from util.distributed import auto_model
from util.image import make_2d_grid
from util.handler import Resumer


def _build_model(cfg, distributed_args=None):
    """Build a model from an OmegaConf node and wrap it with ``auto_model``.

    Args:
        cfg: OmegaConf node converted to a plain container and passed to
            ``MODEL.build_with``. An optional ``_distributed`` entry is popped first;
            when its ``bn_to_syncbn`` is truthy, BatchNorm layers are converted to
            ``torch.nn.SyncBatchNorm``.
        distributed_args: keyword arguments forwarded to ``auto_model``. Replaced by an
            empty dict when ``None`` or when the world size is 1.

    Returns:
        The result of ``util.distributed.auto_model`` for the built model.
    """
    cfg = OmegaConf.to_container(cfg)
    model_distributed_config = cfg.pop("_distributed", dict())
    model = MODEL.build_with(cfg)

    if model_distributed_config.get("bn_to_syncbn"):
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Distributed wrapper arguments are meaningless for a single process.
    distributed_args = {} if distributed_args is None or idist.get_world_size() == 1 else distributed_args
    return auto_model(model, **distributed_args)


def _build_optimizer(params, cfg):
    """Build a ``torch.optim`` optimizer named by ``cfg._type``.

    Args:
        params: iterable of parameters to optimize.
        cfg: OmegaConf node with a ``_type`` key naming a class in ``torch.optim``;
            every other key is passed to that class as a keyword argument.

    Returns:
        The optimizer, adapted for distributed training by ``idist.auto_optim``.
    """
    assert "_type" in cfg
    cfg = OmegaConf.to_container(cfg)
    optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg)
    return idist.auto_optim(optimizer)


def get_trainer(config, logger):
    """Create the CycleGAN training ``Engine`` with schedulers, metrics, checkpoints and logging.

    Builds two generators (A->B as ``generator_a``, B->A as ``generator_b``) and two
    discriminators (``discriminator_a`` scores B-domain images, ``discriminator_b``
    scores A-domain images), their optimizers and linear LR schedules.

    Args:
        config: OmegaConf configuration for model, optimizers, loss, data scheduler,
            checkpoints, output directory, ``max_iteration`` and ``resume_from``.
        logger: logger assigned to the engine and used for debug/progress output.

    Returns:
        The configured ``ignite.engine.Engine``. Each step's output is a dict with
        ``"loss"`` (nested dicts of Python floats) and ``"img"`` (six detached tensors:
        real_a, fake_b, rec_a, real_b, fake_a, rec_b).
    """
    generator_a = _build_model(config.model.generator, config.distributed.model)
    generator_b = _build_model(config.model.generator, config.distributed.model)
    discriminator_a = _build_model(config.model.discriminator, config.distributed.model)
    discriminator_b = _build_model(config.model.discriminator, config.distributed.model)
    logger.debug(discriminator_a)
    logger.debug(generator_a)

    optimizer_g = _build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
                                   config.optimizers.generator)
    optimizer_d = _build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
                                   config.optimizers.discriminator)

    # LR stays at the configured value until `scheduler.start`, then decays linearly to
    # `target_lr` at `max_iteration`.
    milestones_values = [
        (config.data.train.scheduler.start, config.optimizers.generator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr),
    ]
    lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)

    milestones_values = [
        (config.data.train.scheduler.start, config.optimizers.discriminator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr),
    ]
    lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)

    # `weight` is applied manually below, so it is not a GANLoss constructor argument.
    gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
    gan_loss_cfg.pop("weight")
    gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
    # NOTE(review): `config.loss.cycle` and `config.loss.id` are config nodes (their `.weight`
    # is read in `_step`), so comparing the node itself to 1 is presumably always False and
    # both losses end up as MSELoss. `id_loss` also reads the `cycle` node instead of `id`.
    # The intended selector key is not visible here — TODO confirm and fix.
    cycle_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()
    id_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()

    def _step(engine, batch):
        """Run one generator update followed by one discriminator update."""
        batch = convert_tensor(batch, idist.device())
        real_a, real_b = batch["a"], batch["b"]

        optimizer_g.zero_grad()
        fake_b = generator_a(real_a)  # G_A(A)
        rec_a = generator_b(fake_b)  # G_B(G_A(A))
        fake_a = generator_b(real_b)  # G_B(B)
        rec_b = generator_a(fake_a)  # G_A(G_B(B))

        loss_g = dict(
            id_a=config.loss.id.weight * id_loss(generator_a(real_b), real_b),  # G_A(B)
            id_b=config.loss.id.weight * id_loss(generator_b(real_a), real_a),  # G_B(A)
            cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a),
            cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b),
            gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True),
            gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True),
        )
        sum(loss_g.values()).backward()
        optimizer_g.step()

        # Discriminator grads accumulated during the generator backward pass are cleared
        # here; fakes are detached so this pass does not reach the generators.
        optimizer_d.zero_grad()
        loss_d_a = dict(
            real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
            fake=gan_loss(discriminator_a(fake_b.detach()), False, is_discriminator=True),
        )
        loss_d_b = dict(
            real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
            fake=gan_loss(discriminator_b(fake_a.detach()), False, is_discriminator=True),
        )
        loss_d = sum(loss_d_a.values()) / 2 + sum(loss_d_b.values()) / 2
        loss_d.backward()
        optimizer_d.step()

        return {
            "loss": {
                "g": {ln: loss_g[ln].mean().item() for ln in loss_g},
                "d_a": {ln: loss_d_a[ln].mean().item() for ln in loss_d_a},
                "d_b": {ln: loss_d_b[ln].mean().item() for ln in loss_d_b},
            },
            "img": [
                real_a.detach(),
                fake_b.detach(),
                rec_a.detach(),
                real_b.detach(),
                fake_a.detach(),
                rec_b.detach(),
            ],
        }

    trainer = Engine(_step)
    trainer.logger = logger
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b")

    @trainer.on(Events.ITERATION_COMPLETED(every=10))
    def print_log(engine):
        engine.logger.info(f"iter:[{engine.state.iteration}/{config.max_iteration}] "
                           f"loss_g={engine.state.metrics['loss_g']:.3f} "
                           f"loss_d_a={engine.state.metrics['loss_d_a']:.3f} "
                           f"loss_d_b={engine.state.metrics['loss_d_b']:.3f} ")

    to_save = dict(
        generator_a=generator_a,
        generator_b=generator_b,
        discriminator_a=discriminator_a,
        discriminator_b=discriminator_b,
        optimizer_d=optimizer_d,
        optimizer_g=optimizer_g,
        trainer=trainer,
        lr_scheduler_d=lr_scheduler_d,
        lr_scheduler_g=lr_scheduler_g,
    )

    trainer.add_event_handler(Events.STARTED, Resumer(to_save, config.resume_from))
    checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir), n_saved=None)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.checkpoints.interval), checkpoint_handler)

    # TensorBoard logging and profiling only on rank 0.
    if idist.get_rank() == 0:
        # Create a logger
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_writer = tb_logger.writer

        # Log against the global iteration count rather than the per-event counter.
        def global_step_transform(*args, **kwargs):
            return trainer.state.iteration

        # Attach the logger to the trainer to log training losses every 50 iterations
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="loss",
                metric_names=["loss_g", "loss_d_a", "loss_d_b"],
                global_step_transform=global_step_transform,
            ),
            event_name=Events.ITERATION_COMPLETED(every=50),
        )
        # Attach the logger to the trainer to log the generator optimizer's parameters
        # (e.g. learning rate) every 50 iterations
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(every=50),
        )

        @trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
        def show_images(engine):
            tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration)

        # Create an object of the profiler and attach an engine to it
        profiler = BasicTimeProfiler()
        profiler.attach(trainer)

        @trainer.on(Events.EPOCH_COMPLETED(once=1))
        @idist.one_rank_only()
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())

        @trainer.on(Events.COMPLETED)
        @idist.one_rank_only()
        def _():
            profiler.write_results(f"{config.output_dir}/time_profiling.csv")
            # We need to close the logger when we are done
            tb_logger.close()

    return trainer


def get_tester(config, logger):
    """Create an ``Engine`` that translates test batches and saves image grids to disk.

    Args:
        config: OmegaConf configuration providing the generator definition,
            ``distributed.model``, ``resume_from`` and ``output_dir``.
        logger: logger assigned to the engine.

    Returns:
        The configured ``ignite.engine.Engine``. Each step writes one JPEG per sample to
        ``<output_dir>/test_images/<iteration>_<index>.jpg`` with the six images
        (real_a, fake_b, rec_a, real_b, fake_a, rec_b) laid out in a single row.
    """
    generator_a = _build_model(config.model.generator, config.distributed.model)
    generator_b = _build_model(config.model.generator, config.distributed.model)

    def _step(engine, batch):
        """Run both translation directions and both cycles without gradients."""
        batch = convert_tensor(batch, idist.device())
        real_a, real_b = batch["a"], batch["b"]
        with torch.no_grad():
            fake_b = generator_a(real_a)  # G_A(A)
            rec_a = generator_b(fake_b)  # G_B(G_A(A))
            fake_a = generator_b(real_b)  # G_B(B)
            rec_b = generator_a(fake_a)  # G_A(G_B(B))
        return [
            real_a.detach(),
            fake_b.detach(),
            rec_a.detach(),
            real_b.detach(),
            fake_a.detach(),
            rec_b.detach(),
        ]

    tester = Engine(_step)
    tester.logger = logger
    # `get_rank` must be called; comparing the function object to 0 is always False.
    if idist.get_rank() == 0:
        ProgressBar(ncols=0).attach(tester)
    to_load = dict(generator_a=generator_a, generator_b=generator_b)
    tester.add_event_handler(Events.STARTED, Resumer(to_load, config.resume_from))

    @tester.on(Events.STARTED)
    @idist.one_rank_only()
    def mkdir(engine):
        img_output_dir = Path(config.output_dir) / "test_images"
        if not img_output_dir.exists():
            engine.logger.info(f"mkdir {img_output_dir}")
            # parents/exist_ok: do not fail if output_dir is missing or the dir appears concurrently.
            img_output_dir.mkdir(parents=True, exist_ok=True)

    @tester.on(Events.ITERATION_COMPLETED)
    def save_images(engine):
        img_tensors = engine.state.output
        batch_size = img_tensors[0].size(0)
        for i in range(batch_size):
            torchvision.utils.save_image([img[i] for img in img_tensors],
                                         Path(config.output_dir) / f"test_images/{engine.state.iteration}_{i}.jpg",
                                         nrow=len(img_tensors))

    return tester


def run(task, config, logger):
    """Run the requested task.

    Args:
        task: ``"train"`` or ``"test"``.
        config: OmegaConf configuration (data, model, optimizer, loss and output settings).
        logger: logger used for progress messages.

    Raises:
        NotImplementedError: if ``task`` is neither ``"train"`` nor ``"test"``.

    Exceptions raised while the engine runs are caught and their traceback is printed
    (best-effort top-level boundary).
    """
    logger.info(f"start task {task}")
    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
        trainer = get_trainer(config, logger)
        # Enough epochs to reach `max_iteration` iterations (may overshoot by up to one epoch).
        max_epochs = config.max_iteration // len(train_data_loader) + 1
        try:
            trainer.run(train_data_loader, max_epochs=max_epochs)
        except Exception:
            print(traceback.format_exc())
    elif task == "test":
        test_dataset = data.DATASET.build_with(config.data.test.dataset)
        logger.info(f"test with dataset:\n{test_dataset}")
        test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
        tester = get_tester(config, logger)
        try:
            tester.run(test_data_loader, max_epochs=1)
        except Exception:
            print(traceback.format_exc())
    else:
        raise NotImplementedError(f"invalid task: {task}")