from itertools import chain
import logging
from pathlib import Path
from typing import Tuple

import torch
import torchvision
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from omegaconf import read_write, OmegaConf

from util.image import make_2d_grid
from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
from engine.util.build import build_optimizer
import data


def build_lr_schedulers(optimizers, config):
    # TODO: support more scheduler types
    g_milestones_values = [
        (0, config.optimizers.generator.lr),
        (int(config.data.train.scheduler.start_proportion * config.max_iteration),
         config.optimizers.generator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    d_milestones_values = [
        (0, config.optimizers.discriminator.lr),
        (int(config.data.train.scheduler.start_proportion * config.max_iteration),
         config.optimizers.discriminator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    return dict(
        g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
        d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
    )


class TestEngineKernel(object):
    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger(config.name)
        self.generators = self.build_generators()

    def build_generators(self) -> dict:
        raise NotImplementedError

    def to_load(self):
        return {f"generator_{k}": self.generators[k] for k in self.generators}

    def inference(self, batch):
        raise NotImplementedError


class EngineKernel(object):
    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger(config.name)
        self.generators, self.discriminators = self.build_models()

    def build_models(self) -> Tuple[dict, dict]:
        raise NotImplementedError

    def to_save(self):
        to_save = {}
        to_save.update({f"generator_{k}": self.generators[k] for k in self.generators})
        to_save.update({f"discriminator_{k}": self.discriminators[k] for k in self.discriminators})
        return to_save

    def setup_before_d(self):
        raise NotImplementedError

    def setup_before_g(self):
        raise NotImplementedError

    def forward(self, batch, inference=False) -> dict:
        raise NotImplementedError

    def criterion_generators(self, batch, generated) -> dict:
        raise NotImplementedError

    def criterion_discriminators(self, batch, generated) -> dict:
        raise NotImplementedError

    def intermediate_images(self, batch, generated) -> dict:
        """
        The returned dict must look like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}

        :param batch: input batch
        :param generated: dict of generated images
        :return: dict like {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        raise NotImplementedError


def get_trainer(config, kernel: EngineKernel):
    logger = logging.getLogger(config.name)
    generators, discriminators = kernel.generators, kernel.discriminators

    optimizers = dict(
        g=build_optimizer(chain(*[m.parameters() for m in generators.values()]),
                          config.optimizers.generator),
        d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]),
                          config.optimizers.discriminator),
    )
    logger.info(f"build optimizers:\n{optimizers}")
    lr_schedulers = build_lr_schedulers(optimizers, config)
    logger.info(f"build lr_schedulers:\n{lr_schedulers}")

    image_per_iteration = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)

    def _step(engine, batch):
        batch = convert_tensor(batch, idist.device())
        generated = kernel.forward(batch)

        # generator update
        kernel.setup_before_g()
        optimizers["g"].zero_grad()
        loss_g = kernel.criterion_generators(batch, generated)
        sum(loss_g.values()).backward()
        optimizers["g"].step()

        # discriminator update
        kernel.setup_before_d()
        optimizers["d"].zero_grad()
        loss_d = kernel.criterion_discriminators(batch, generated)
        sum(loss_d.values()).backward()
        optimizers["d"].step()

        # periodically also return intermediate images for visualization
        if engine.state.iteration % image_per_iteration == 0:
            return {
                "loss": dict(g=loss_g, d=loss_d),
                "img": kernel.intermediate_images(batch, generated)
            }
        return {"loss": dict(g=loss_g, d=loss_d)}

    trainer = Engine(_step)
    trainer.logger = logger
    for lr_shd in lr_schedulers.values():
        trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)

    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")

    to_save = dict(trainer=trainer)
    to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
    to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
    to_save.update(kernel.to_save())
    setup_common_handlers(trainer, config, to_save=to_save,
                          clear_cuda_cache=config.handler.clear_cuda_cache,
                          set_epoch_for_dist_sampler=config.handler.set_epoch_for_dist_sampler,
                          end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
    tensorboard_handler = setup_tensorboard_handler(trainer, config, optimizers, step_type="item")

    if tensorboard_handler is not None:
        basic_image_event = Events.ITERATION_COMPLETED(every=image_per_iteration)
        pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()

        @trainer.on(basic_image_event)
        def show_images(engine):
            output = engine.state.output
            test_images = {}
            # log intermediate training images and prepare one bucket per
            # image slot for the test images collected below
            for k in output["img"]:
                image_list = output["img"][k]
                tensorboard_handler.writer.add_image(f"train/{k}",
                                                     make_2d_grid(image_list, range=(-1, 1)),
                                                     engine.state.iteration * pairs_per_iteration)
                test_images[k] = []
                for i in range(len(image_list)):
                    test_images[k].append([])

            # run the kernel on 10 consecutive test samples starting from a
            # seeded random index
            with torch.no_grad():
                g = torch.Generator()
                g.manual_seed(config.misc.random_seed)
                random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
                for i in range(random_start, random_start + 10):
                    batch = convert_tensor(engine.state.test_dataset[i], idist.device())
                    for k in batch:
                        batch[k] = batch[k].view(1, *batch[k].size())
                    generated = kernel.forward(batch)
                    images = kernel.intermediate_images(batch, generated)
                    for k in test_images:
                        for j in range(len(images[k])):
                            test_images[k][j].append(images[k][j])

            for k in test_images:
                tensorboard_handler.writer.add_image(
                    f"test/{k}",
                    make_2d_grid([torch.cat(ti) for ti in test_images[k]], range=(-1, 1)),
                    engine.state.iteration * pairs_per_iteration
                )

    return trainer


def get_tester(config, kernel: TestEngineKernel):
    logger = logging.getLogger(config.name)

    def _step(engine, batch):
        real_a, path = convert_tensor(batch, idist.device())
        fake = kernel.inference({"a": real_a})["a"]
        return {"path": path, "img": [real_a.detach(), fake.detach()]}

    tester = Engine(_step)
    tester.logger = logger
    setup_common_handlers(tester, config, use_profiler=True, to_save=kernel.to_load())

    @tester.on(Events.STARTED)
    def mkdir(engine):
        img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
        engine.state.img_output_dir = Path(img_output_dir)
        if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
            engine.logger.info(f"mkdir {engine.state.img_output_dir}")
            engine.state.img_output_dir.mkdir()

    @tester.on(Events.ITERATION_COMPLETED)
    def save_images(engine):
        img_tensors = engine.state.output["img"]
        paths = engine.state.output["path"]
        batch_size = img_tensors[0].size(0)
        for i in range(batch_size):
            image_name = Path(paths[i]).name
            torchvision.utils.save_image([img[i] for img in img_tensors],
                                         engine.state.img_output_dir / image_name,
                                         nrow=len(img_tensors), padding=0,
                                         normalize=True, range=(-1, 1))

    return tester


def run_kernel(task, config, kernel):
    logger = logging.getLogger(config.name)
    with read_write(config):
        real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size()
        config.max_iteration = config.max_pairs // real_batch_size + 1

    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        dataloader_kwargs = OmegaConf.to_container(config.data.train.dataloader)
        dataloader_kwargs["batch_size"] = dataloader_kwargs["batch_size"] * idist.get_world_size()
        train_data_loader = idist.auto_dataloader(train_dataset, **dataloader_kwargs)
        with read_write(config):
            config.iterations_per_epoch = len(train_data_loader)

        trainer = get_trainer(config, kernel)
        if idist.get_rank() == 0:
            test_dataset = data.DATASET.build_with(config.data.test.dataset)
            trainer.state.test_dataset = test_dataset
        try:
            trainer.run(train_data_loader,
                        max_epochs=config.max_iteration // len(train_data_loader) + 1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    elif task == "test":
        assert config.resume_from is not None
        test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
        logger.info(f"test with dataset:\n{test_dataset}")
        test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)

        tester = get_tester(config, kernel)
        try:
            tester.run(test_data_loader, max_epochs=1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    else:
        raise NotImplementedError(f"invalid task: {task}")
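

# -----------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: a minimal EngineKernel
# subclass showing how the hooks above are expected to fit together. The tiny
# conv "networks", the single "a" domain key, and the loss terms are
# hypothetical placeholders; a real kernel would build its models and losses
# from the experiment config instead.
# -----------------------------------------------------------------------------
class _ExampleKernel(EngineKernel):
    def build_models(self):
        # Placeholder one-layer models; a real kernel would read architecture
        # settings from self.config.
        generator = torch.nn.Conv2d(3, 3, 3, padding=1).to(idist.device())
        discriminator = torch.nn.Conv2d(3, 1, 4, stride=2, padding=1).to(idist.device())
        return {"a": generator}, {"a": discriminator}

    def setup_before_g(self):
        # Freeze discriminators so the generator step only updates G.
        for d in self.discriminators.values():
            d.requires_grad_(False)

    def setup_before_d(self):
        for d in self.discriminators.values():
            d.requires_grad_(True)

    def forward(self, batch, inference=False):
        return {"fake_a": self.generators["a"](batch["a"])}

    def criterion_generators(self, batch, generated):
        adv = -self.discriminators["a"](generated["fake_a"]).mean()
        rec = torch.nn.functional.l1_loss(generated["fake_a"], batch["a"])
        return {"adv": adv, "rec": rec}

    def criterion_discriminators(self, batch, generated):
        # The fake image must be detached here: its graph was already consumed
        # by the generator backward pass in `_step`.
        real = self.discriminators["a"](batch["a"])
        fake = self.discriminators["a"](generated["fake_a"].detach())
        return {"hinge": torch.relu(1.0 - real).mean() + torch.relu(1.0 + fake).mean()}

    def intermediate_images(self, batch, generated):
        return {"a": [batch["a"], generated["fake_a"]]}

# A kernel like this would then be driven by, e.g.:
#   run_kernel("train", config, _ExampleKernel(config))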