import logging
from itertools import chain
from math import ceil
from pathlib import Path
from typing import Tuple

import torch
import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from omegaconf import OmegaConf

from model import MODEL
from util.image import make_2d_grid
from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_optimizer


def build_model(cfg):
    """Build a model from its config and wrap it for the current distributed setup."""
    cfg = OmegaConf.to_container(cfg)
    bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
    model = MODEL.build_with(cfg)
    if bn_to_sync_bn:
        # Replace BatchNorm layers with SyncBatchNorm so that batch statistics
        # are synchronized across processes.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return idist.auto_model(model)


def build_lr_schedulers(optimizers, config):
    # TODO: support more scheduler types
    # Hold the initial lr until `start_proportion` of training has elapsed,
    # then decay linearly to `target_lr` at `max_iteration`.
    decay_start = int(config.data.train.scheduler.start_proportion * config.max_iteration)
    g_milestones_values = [
        (0, config.optimizers.generator.lr),
        (decay_start, config.optimizers.generator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr),
    ]
    d_milestones_values = [
        (0, config.optimizers.discriminator.lr),
        (decay_start, config.optimizers.discriminator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr),
    ]
    return dict(
        g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
        d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values),
    )


class EngineKernel:
    """Hooks that a concrete GAN training kernel must implement for get_trainer()."""

    def __init__(self, config, logger):
        self.config = config
        self.logger = logger
        self.generators, self.discriminators = self.build_models()

    def build_models(self) -> Tuple[dict, dict]:
        """Return (generators, discriminators), each a dict of name -> model."""
        raise NotImplementedError

    def to_save(self):
        """Collect all models under checkpoint-friendly keys."""
        to_save = {}
        to_save.update({f"generator_{k}": v for k, v in self.generators.items()})
        to_save.update({f"discriminator_{k}": v for k, v in self.discriminators.items()})
        return to_save

    def setup_before_d(self):
        """Prepare models (e.g. toggle requires_grad) before the discriminator step."""
        raise NotImplementedError

    def setup_before_g(self):
        """Prepare models before the generator step."""
        raise NotImplementedError

    def forward(self, batch, inference=False) -> dict:
        raise NotImplementedError

    def criterion_generators(self, batch, generated) -> dict:
        """Return a dict of named generator loss terms."""
        raise NotImplementedError

    def criterion_discriminators(self, batch, generated) -> dict:
        """Return a dict of named discriminator loss terms.

        The generator graph is freed by the generator backward pass, so
        implementations must detach generated tensors before feeding them
        to the discriminators.
        """
        raise NotImplementedError

    def intermediate_images(self, batch, generated) -> dict:
        """
        Return images to visualize, keyed by tag.

        The returned dict must look like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}

        :param batch: the current input batch
        :param generated: dict of images produced by forward()
        :return: dict like {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        raise NotImplementedError


def get_trainer(config, ek: EngineKernel, iter_per_epoch):
    logger = logging.getLogger(config.name)
    generators, discriminators = ek.generators, ek.discriminators
    # One optimizer over all generator parameters, one over all discriminator parameters.
    optimizers = dict(
        g=build_optimizer(chain(*[m.parameters() for m in generators.values()]),
                          config.optimizers.generator),
        d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]),
                          config.optimizers.discriminator),
    )
    logger.info(f"build optimizers:\n{optimizers}")
    lr_schedulers = build_lr_schedulers(optimizers, config)
    logger.info(f"build lr_schedulers:\n{lr_schedulers}")

    def _step(engine, batch):
        batch = convert_tensor(batch, idist.device())
        generated = ek.forward(batch)

        # Generator step.
        ek.setup_before_g()
        optimizers["g"].zero_grad()
        loss_g = ek.criterion_generators(batch, generated)
        sum(loss_g.values()).backward()
        optimizers["g"].step()

        # Discriminator step; criterion_discriminators must detach `generated`.
        ek.setup_before_d()
        optimizers["d"].zero_grad()
        loss_d = ek.criterion_discriminators(batch, generated)
        sum(loss_d.values()).backward()
        optimizers["d"].step()

        return {
            "loss": dict(g=loss_g, d=loss_d),
            "img": ek.intermediate_images(batch, generated),
        }

    trainer = Engine(_step)
    trainer.logger = logger
    for lr_shd in lr_schedulers.values():
        trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")

    # Everything needed to resume training goes into the checkpoint.
    to_save = dict(trainer=trainer)
    to_save.update({f"lr_scheduler_{k}": v for k, v in lr_schedulers.items()})
    to_save.update({f"optimizer_{k}": v for k, v in optimizers.items()})
    to_save.update(ek.to_save())
    setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True,
                          set_epoch_for_dist_sampler=True,
                          end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))

    def output_transform(output):
        # Flatten the nested loss dict for logging,
        # e.g. {"g": {"adv": ...}} -> {"g_adv": ...}.
        loss = {}
        for side, terms in output["loss"].items():
            if isinstance(terms, dict):
                for name, value in terms.items():
                    loss[f"{side}_{name}"] = value
            else:
                loss[side] = terms
        return loss

    pairs_per_iteration = config.data.train.dataloader.batch_size
    tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
    if tensorboard_handler is not None:
        tensorboard_handler.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(
                every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1)),
        )

        @trainer.on(Events.ITERATION_COMPLETED(
            every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
        def show_images(engine):
            # Log the intermediate images of the current training batch.
            output = engine.state.output
            test_images = {}
            for k, image_list in output["img"].items():
                tensorboard_handler.writer.add_image(
                    f"train/{k}", make_2d_grid(image_list),
                    engine.state.iteration * pairs_per_iteration)
                test_images[k] = [[] for _ in image_list]

            # Run inference on 10 consecutive test samples; the generator is
            # re-seeded on every call, so the same window is shown each time.
            with torch.no_grad():
                g = torch.Generator()
                g.manual_seed(config.misc.random_seed)
                random_start = torch.randperm(
                    len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
                for i in range(random_start, random_start + 10):
                    batch = convert_tensor(engine.state.test_dataset[i], idist.device())
                    for k in batch:
                        # Add a batch dimension to each sample tensor.
                        batch[k] = batch[k].view(1, *batch[k].size())
                    generated = ek.forward(batch)
                    images = ek.intermediate_images(batch, generated)
                    for k in test_images:
                        for j in range(len(images[k])):
                            test_images[k][j].append(images[k][j])
            for k in test_images:
                tensorboard_handler.writer.add_image(
                    f"test/{k}",
                    make_2d_grid([torch.cat(ti) for ti in test_images[k]]),
                    engine.state.iteration * pairs_per_iteration,
                )

    return trainer
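

# ---------------------------------------------------------------------------
# Usage sketch (hypothetical, not used by the trainer): a minimal EngineKernel
# subclass illustrating the contract get_trainer() relies on. The model keys
# ("g_ab", "d_b"), config fields, and loss terms below are placeholder
# assumptions; a real kernel builds these from its own config.
# ---------------------------------------------------------------------------
class _ExampleKernel(EngineKernel):
    def build_models(self):
        # Keys become checkpoint names: "generator_g_ab", "discriminator_d_b".
        generators = {"g_ab": build_model(self.config.models.generator)}
        discriminators = {"d_b": build_model(self.config.models.discriminator)}
        return generators, discriminators

    def setup_before_g(self):
        # Freeze discriminators while the generator step runs.
        for d in self.discriminators.values():
            d.requires_grad_(False)
        for g in self.generators.values():
            g.requires_grad_(True)

    def setup_before_d(self):
        for g in self.generators.values():
            g.requires_grad_(False)
        for d in self.discriminators.values():
            d.requires_grad_(True)

    def forward(self, batch, inference=False):
        return {"fake_b": self.generators["g_ab"](batch["a"])}

    def criterion_generators(self, batch, generated):
        # Named loss terms; _step() sums them before calling backward().
        # Non-saturating GAN loss: softplus(-D(G(a))).
        pred_fake = self.discriminators["d_b"](generated["fake_b"])
        return {"adv": torch.nn.functional.softplus(-pred_fake).mean()}

    def criterion_discriminators(self, batch, generated):
        # Detach generated tensors: the generator graph was already freed
        # by the generator backward pass in _step().
        pred_fake = self.discriminators["d_b"](generated["fake_b"].detach())
        pred_real = self.discriminators["d_b"](batch["b"])
        return {
            "fake": torch.nn.functional.softplus(pred_fake).mean(),
            "real": torch.nn.functional.softplus(-pred_real).mean(),
        }

    def intermediate_images(self, batch, generated):
        # Each tag maps to a list of image batches and becomes one grid.
        return {"a_to_b": [batch["a"], generated["fake_b"]]}


# A training script would then wire things up roughly like this (names such
# as `test_dataset` and `train_dataloader` are assumptions):
#
#   kernel = _ExampleKernel(config, logging.getLogger(config.name))
#   trainer = get_trainer(config, kernel, iter_per_epoch)
#
#   @trainer.on(Events.STARTED)
#   def _attach_test_dataset(engine):
#       engine.state.test_dataset = test_dataset  # consumed by show_images()
#
#   trainer.run(train_dataloader,
#               max_epochs=ceil(config.max_iteration / iter_per_epoch))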