import logging
from itertools import chain
from math import ceil
from pathlib import Path
from typing import Tuple

import ignite.distributed as idist
import torch
import torchvision
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from omegaconf import read_write, OmegaConf

import data
from engine.util.build import build_optimizer
from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
from util.image import make_2d_grid


def _constant_then_linear_milestones(initial_lr, config):
    """Return PiecewiseLinear milestones: hold ``initial_lr``, then move linearly to the target lr.

    The lr stays at ``initial_lr`` until ``start_proportion * max_iteration`` and reaches
    ``config.data.train.scheduler.target_lr`` at ``config.max_iteration``.
    """
    scheduler_config = config.data.train.scheduler
    decay_start = int(scheduler_config.start_proportion * config.max_iteration)
    return [
        (0, initial_lr),
        (decay_start, initial_lr),
        (config.max_iteration, scheduler_config.target_lr),
    ]


def build_lr_schedulers(optimizers, config):
    """Build one ``PiecewiseLinear`` lr scheduler per optimizer.

    :param optimizers: mapping with keys ``"g"`` and ``"d"`` holding the generator and
        discriminator optimizers
    :param config: config providing ``optimizers.{generator,discriminator}.lr``,
        ``data.train.scheduler.{start_proportion,target_lr}`` and ``max_iteration``
    :return: dict with keys ``"g"`` and ``"d"`` mapping to the schedulers
    """
    # TODO: support more scheduler type
    return dict(
        g=PiecewiseLinear(
            optimizers["g"], param_name="lr",
            milestones_values=_constant_then_linear_milestones(config.optimizers.generator.lr, config),
        ),
        d=PiecewiseLinear(
            optimizers["d"], param_name="lr",
            milestones_values=_constant_then_linear_milestones(config.optimizers.discriminator.lr, config),
        ),
    )


class TestEngineKernel(object):
    """Base class describing the model-specific parts of a test (inference) run.

    Subclasses must implement :meth:`build_generators` and :meth:`inference`.
    """

    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger(config.name)
        self.generators = self.build_generators()

    def build_generators(self) -> dict:
        """Return a dict of generator models keyed by name."""
        raise NotImplementedError

    def to_load(self):
        """Return the objects to restore from a checkpoint, keyed as ``generator_<name>``."""
        return {f"generator_{k}": self.generators[k] for k in self.generators}

    def inference(self, batch):
        """Run the generators on ``batch``; the result is stored as the ``"generated"`` step output."""
        raise NotImplementedError


class EngineKernel(object):
    """Base class describing the model-specific parts of a GAN training run.

    ``get_trainer`` drives the training loop and calls back into the kernel for model
    construction, the forward pass, the losses and the images to visualise.
    """

    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger(config.name)
        self.generators, self.discriminators = self.build_models()
        # When True the generators are updated before the discriminators in each step.
        self.train_generator_first = True
        self.engine = None

    def bind_engine(self, engine):
        """Remember the trainer engine so subclasses can access it."""
        self.engine = engine

    def build_models(self) -> Tuple[dict, dict]:
        """Return ``(generators, discriminators)``, each a dict of models keyed by name."""
        raise NotImplementedError

    def to_save(self):
        """Return all models to checkpoint, keyed as ``generator_<name>`` / ``discriminator_<name>``."""
        to_save = {}
        to_save.update({f"generator_{k}": self.generators[k] for k in self.generators})
        to_save.update({f"discriminator_{k}": self.discriminators[k] for k in self.discriminators})
        return to_save

    def setup_after_g(self):
        """Hook called right after the generator optimizer step."""
        raise NotImplementedError

    def setup_before_g(self):
        """Hook called right before the generator gradients are zeroed."""
        raise NotImplementedError

    def forward(self, batch, inference=False) -> dict:
        """Run the generators on ``batch`` and return the generated outputs as a dict."""
        raise NotImplementedError

    def criterion_generators(self, batch, generated) -> dict:
        """Return the generator losses as a dict; the trainer sums the values and backpropagates."""
        raise NotImplementedError

    def criterion_discriminators(self, batch, generated) -> dict:
        """Return the discriminator losses as a dict; the trainer sums the values and backpropagates."""
        raise NotImplementedError

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        raise NotImplementedError

    def change_engine(self, config, engine: Engine):
        """Optional hook to customise the trainer engine; does nothing by default."""
        pass


def _add_batch_dim(sample):
    """Unsqueeze, in place, every tensor found in ``sample`` (one level of nested dicts supported)."""
    for key in sample:
        value = sample[key]
        if isinstance(value, torch.Tensor):
            sample[key] = value.unsqueeze(0)
        elif isinstance(value, dict):
            for sub_key in value:
                if isinstance(value[sub_key], torch.Tensor):
                    value[sub_key] = value[sub_key].unsqueeze(0)
    return sample


def _select_test_window(dataset_size, num_images, generator):
    """Choose the contiguous run of test samples to visualise.

    :return: ``(start, count)``. When the dataset holds more than ``num_images`` samples the
        start index is drawn with ``generator``; otherwise every available sample is used,
        starting at index 0, instead of failing on an empty permutation.
    """
    num_starts = dataset_size - num_images
    if num_starts <= 0:
        return 0, max(dataset_size, 0)
    start = torch.randperm(num_starts, generator=generator).tolist()[0]
    return start, num_images


def get_trainer(config, kernel: EngineKernel):
    """Build the training ``Engine`` for ``kernel``.

    Creates the optimizers and lr schedulers, defines the per-iteration update of the
    generators and discriminators, attaches running-average loss metrics, the common
    handlers (with everything to checkpoint) and, when a tensorboard handler is available,
    periodic logging of train and test images.

    :param config: run configuration
    :param kernel: model-specific training logic
    :return: the configured trainer engine
    """
    logger = logging.getLogger(config.name)

    generators, discriminators = kernel.generators, kernel.discriminators
    optimizers = dict(
        g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
        d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
    )
    logger.info(f"build optimizers:\n{optimizers}")

    lr_schedulers = build_lr_schedulers(optimizers, config)
    logger.info(f"build lr_schedulers:\n{lr_schedulers}")

    # Number of iterations between two image dumps (at least 1).
    iteration_per_image = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)

    def train_generators(batch, generated):
        kernel.setup_before_g()
        optimizers["g"].zero_grad()
        loss_g = kernel.criterion_generators(batch, generated)
        sum(loss_g.values()).backward()
        optimizers["g"].step()
        kernel.setup_after_g()
        return loss_g

    def train_discriminators(batch, generated):
        optimizers["d"].zero_grad()
        loss_d = kernel.criterion_discriminators(batch, generated)
        sum(loss_d.values()).backward()
        optimizers["d"].step()
        return loss_d

    def _step(engine, batch):
        batch = convert_tensor(batch, idist.device())

        generated = kernel.forward(batch)

        if kernel.train_generator_first:
            # simultaneous, train G with simultaneous D
            loss_g = train_generators(batch, generated)
            loss_d = train_discriminators(batch, generated)
        else:
            # update discriminators first, not simultaneous.
            # train G with updated discriminators
            loss_d = train_discriminators(batch, generated)
            loss_g = train_generators(batch, generated)

        output = {"loss": dict(g=loss_g, d=loss_d)}
        # Only compute images on the iterations where the image handler will consume them.
        if engine.state.iteration % iteration_per_image == 0:
            output["img"] = kernel.intermediate_images(batch, generated)
        return output

    trainer = Engine(_step)
    trainer.logger = logger
    for lr_shd in lr_schedulers.values():
        trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)

    kernel.change_engine(config, trainer)
    kernel.bind_engine(trainer)

    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values()), epoch_bound=False).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values()), epoch_bound=False).attach(trainer, "loss_d")

    to_save = dict(trainer=trainer)
    to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
    to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
    to_save.update(kernel.to_save())

    setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=config.handler.clear_cuda_cache,
                          set_epoch_for_dist_sampler=config.handler.set_epoch_for_dist_sampler,
                          end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))

    tensorboard_handler = setup_tensorboard_handler(trainer, config, optimizers, step_type="item")
    if tensorboard_handler is not None:
        basic_image_event = Events.ITERATION_COMPLETED(every=iteration_per_image)
        pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()

        @trainer.on(basic_image_event)
        def show_images(engine):
            output = engine.state.output
            global_step = engine.state.iteration * pairs_per_iteration

            # test_images[name][j] collects the j-th visualised image of every test sample.
            test_images = {}
            for name in output["img"]:
                image_list = output["img"][name]
                tensorboard_handler.writer.add_image(f"train/{name}", make_2d_grid(image_list, range=(-1, 1)),
                                                     global_step)
                test_images[name] = [[] for _ in image_list]

            with torch.no_grad():
                g = torch.Generator()
                if config.handler.test.random:
                    # A different window of test samples every epoch.
                    seed = config.misc.random_seed + engine.state.epoch
                else:
                    seed = config.misc.random_seed
                g.manual_seed(seed)

                test_dataset = engine.state.test_dataset
                random_start, num_samples = _select_test_window(len(test_dataset), config.handler.test.images, g)
                if num_samples <= 0:
                    # Nothing to show; torch.cat would fail on empty lists below.
                    return

                for i in range(random_start, random_start + num_samples):
                    batch = _add_batch_dim(convert_tensor(test_dataset[i], idist.device()))

                    generated = kernel.forward(batch)
                    images = kernel.intermediate_images(batch, generated)

                    for name in test_images:
                        for j, image in enumerate(images[name]):
                            test_images[name][j].append(image)

                for name in test_images:
                    tensorboard_handler.writer.add_image(
                        f"test/{name}",
                        make_2d_grid([torch.cat(ti) for ti in test_images[name]], range=(-1, 1)),
                        global_step
                    )

    return trainer


def save_images_helper(output_dir, paths, images_list):
    """Save one image file per sample of a batch.

    For sample ``i`` the ``i``-th entry of every element of ``images_list`` is written side
    by side into ``output_dir`` under the file name of ``paths[i]``.

    :param output_dir: directory to write into
    :param paths: source paths of the batch samples; only their file names are used
    :param images_list: list of batched images, each indexable by sample index
    """
    for i, path in enumerate(paths):
        img_list = [img[i] for img in images_list]
        # NOTE(review): newer torchvision releases name this keyword `value_range`;
        # confirm against the torchvision version this project pins.
        torchvision.utils.save_image(img_list, Path(output_dir) / Path(path).name, nrow=len(img_list), padding=0,
                                     normalize=True, range=(-1, 1))


def get_tester(config, kernel: TestEngineKernel):
    """Build the test ``Engine`` that runs ``kernel.inference`` and saves the results as images.

    :param config: run configuration
    :param kernel: model-specific inference logic
    :return: the configured tester engine
    """
    logger = logging.getLogger(config.name)

    def _step(engine, batch):
        batch = convert_tensor(batch, idist.device())
        return {"batch": batch, "generated": kernel.inference(batch)}

    tester = Engine(_step)
    tester.logger = logger

    setup_common_handlers(tester, config, use_profiler=True, to_save=kernel.to_load())

    @tester.on(Events.STARTED)
    def mkdir(engine):
        img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
        engine.state.img_output_dir = Path(img_output_dir)
        if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
            engine.logger.info(f"mkdir {engine.state.img_output_dir}")
            # Tolerate missing parents and a directory created concurrently.
            engine.state.img_output_dir.mkdir(parents=True, exist_ok=True)

    @tester.on(Events.ITERATION_COMPLETED)
    def save_images(engine):
        # Use the directory resolved in `mkdir`: config.img_output_dir may be unset, in
        # which case the fallback under config.output_dir is the one that was created.
        output_dir = engine.state.img_output_dir
        output = engine.state.output
        if engine.state.dataloader.dataset.__class__.__name__ == "SingleFolderDataset":
            images, paths = output["batch"]
            save_images_helper(output_dir, paths, [images, output["generated"]])
        else:
            for k in output['generated']:
                images, paths = output["batch"][k]
                save_images_helper(output_dir, paths, [images, output["generated"][k]])

    return tester


def run_kernel(task, config, kernel):
    """Run ``kernel`` for the given task.

    Derives ``config.max_iteration`` from ``config.max_pairs`` and the effective batch size
    (per-process batch size times world size), then runs training or testing. Exceptions
    raised while the engine runs are printed with their traceback instead of propagating.

    :param task: ``"train"`` or ``"test"``
    :param config: run configuration (modified in place: ``max_iteration`` and, when
        training, ``iterations_per_epoch``)
    :param kernel: an ``EngineKernel`` for training or a ``TestEngineKernel`` for testing
    :raises NotImplementedError: if ``task`` is neither ``"train"`` nor ``"test"``
    """
    logger = logging.getLogger(config.name)
    with read_write(config):
        real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size()
        config.max_iteration = ceil(config.max_pairs / real_batch_size)

    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        dataloader_kwargs = OmegaConf.to_container(config.data.train.dataloader)
        # The configured batch size is scaled by the world size before being handed to auto_dataloader.
        dataloader_kwargs["batch_size"] = dataloader_kwargs["batch_size"] * idist.get_world_size()
        train_data_loader = idist.auto_dataloader(train_dataset, **dataloader_kwargs)
        with read_write(config):
            config.iterations_per_epoch = len(train_data_loader)

        trainer = get_trainer(config, kernel)
        if idist.get_rank() == 0:
            # The test dataset is only attached on rank 0, for the test-image visualisation.
            test_dataset = data.DATASET.build_with(config.data.test.dataset)
            trainer.state.test_dataset = test_dataset
        try:
            trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
        except Exception:
            import traceback
            print(traceback.format_exc())
    elif task == "test":
        assert config.resume_from is not None
        test_dataset = data.DATASET.build_with(config.data.test[config.data.test.which])
        logger.info(f"test with dataset:\n{test_dataset}")
        test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
        tester = get_tester(config, kernel)
        try:
            tester.run(test_data_loader, max_epochs=1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    else:
        # `NotImplemented` is a non-callable constant; raise the proper exception instead.
        raise NotImplementedError(f"invalid task: {task}")