diff --git a/configs/synthesizers/cyclegan.yml b/configs/synthesizers/cyclegan.yml index 13b30fa..9647f63 100644 --- a/configs/synthesizers/cyclegan.yml +++ b/configs/synthesizers/cyclegan.yml @@ -61,6 +61,7 @@ optimizers: data: train: + buffer_size: 50 dataloader: batch_size: 16 shuffle: True diff --git a/engine/cyclegan.py b/engine/cyclegan.py index 7391d19..cacb6f9 100644 --- a/engine/cyclegan.py +++ b/engine/cyclegan.py @@ -3,59 +3,41 @@ from pathlib import Path import torch import torch.nn as nn -import torch.optim as optim import torchvision.utils import ignite.distributed as idist from ignite.engine import Events, Engine from ignite.contrib.handlers.param_scheduler import PiecewiseLinear from ignite.metrics import RunningAverage -from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan -from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar +from ignite.contrib.handlers import ProgressBar from ignite.utils import convert_tensor from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler from omegaconf import OmegaConf import data -from model import MODEL from loss.gan import GANLoss -from util.distributed import auto_model +from model.weight_init import generation_init_weights +from model.residual_generator import GANImageBuffer from util.image import make_2d_grid -from util.handler import Resumer - - -def _build_model(cfg, distributed_args=None): - cfg = OmegaConf.to_container(cfg) - model_distributed_config = cfg.pop("_distributed", dict()) - model = MODEL.build_with(cfg) - - if model_distributed_config.get("bn_to_syncbn"): - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - - distributed_args = {} if distributed_args is None or idist.get_world_size() == 1 else distributed_args - return auto_model(model, **distributed_args) - - -def _build_optimizer(params, cfg): - assert "_type" in cfg - cfg = OmegaConf.to_container(cfg) - optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg) - return idist.auto_optim(optimizer) +from util.handler import setup_common_handlers +from util.build import build_model, build_optimizer def get_trainer(config, logger): - generator_a = _build_model(config.model.generator, config.distributed.model) - generator_b = _build_model(config.model.generator, config.distributed.model) - discriminator_a = _build_model(config.model.discriminator, config.distributed.model) - discriminator_b = _build_model(config.model.discriminator, config.distributed.model) + generator_a = build_model(config.model.generator, config.distributed.model) + generator_b = build_model(config.model.generator, config.distributed.model) + discriminator_a = build_model(config.model.discriminator, config.distributed.model) + discriminator_b = build_model(config.model.discriminator, config.distributed.model) + for m in [generator_b, generator_a, discriminator_b, discriminator_a]: + generation_init_weights(m) logger.debug(discriminator_a) logger.debug(generator_a) - optimizer_g = _build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()), - config.optimizers.generator) - optimizer_d = _build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()), - config.optimizers.discriminator) + optimizer_g = build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()), + config.optimizers.generator) + optimizer_d = build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()), + config.optimizers.discriminator) milestones_values = [ (config.data.train.scheduler.start, config.optimizers.generator.lr), @@ -75,16 +57,21 @@ def get_trainer(config, logger): cycle_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss() id_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss() + image_buffer_a = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50) + image_buffer_b = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50) + def _step(engine, batch): batch = convert_tensor(batch, idist.device()) real_a, real_b = batch["a"], batch["b"] - optimizer_g.zero_grad() fake_b = generator_a(real_a) # G_A(A) rec_a = generator_b(fake_b) # G_B(G_A(A)) fake_a = generator_b(real_b) # G_B(B) rec_b = generator_a(fake_a) # G_A(G_B(B)) + optimizer_g.zero_grad() + discriminator_a.requires_grad_(False) + discriminator_b.requires_grad_(False) loss_g = dict( id_a=config.loss.id.weight * id_loss(generator_a(real_b), real_b), # G_A(B) id_b=config.loss.id.weight * id_loss(generator_b(real_a), real_a), # G_B(A) @@ -96,17 +83,19 @@ def get_trainer(config, logger): sum(loss_g.values()).backward() optimizer_g.step() + discriminator_a.requires_grad_(True) + discriminator_b.requires_grad_(True) optimizer_d.zero_grad() loss_d_a = dict( real=gan_loss(discriminator_a(real_b), True, is_discriminator=True), - fake=gan_loss(discriminator_a(fake_b.detach()), False, is_discriminator=True), + fake=gan_loss(discriminator_a(image_buffer_a.query(fake_b.detach())), False, is_discriminator=True), ) + (sum(loss_d_a.values()) * 0.5).backward() loss_d_b = dict( real=gan_loss(discriminator_b(real_a), True, is_discriminator=True), - fake=gan_loss(discriminator_b(fake_a.detach()), False, is_discriminator=True), + fake=gan_loss(discriminator_b(image_buffer_b.query(fake_a.detach())), False, is_discriminator=True), ) - loss_d = sum(loss_d_a.values()) / 2 + sum(loss_d_b.values()) / 2 - loss_d.backward() + (sum(loss_d_b.values()) * 0.5).backward() optimizer_d.step() return { @@ -129,27 +118,25 @@ def get_trainer(config, logger): trainer.logger = logger trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g) trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d) - trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) + RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g") RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a") RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b") - @trainer.on(Events.ITERATION_COMPLETED(every=10)) - def print_log(engine): - engine.logger.info(f"iter:[{engine.state.iteration}/{config.max_iteration}]" - f"loss_g={engine.state.metrics['loss_g']:.3f} " - f"loss_d_a={engine.state.metrics['loss_d_a']:.3f} " - f"loss_d_b={engine.state.metrics['loss_d_b']:.3f} ") - to_save = dict( generator_a=generator_a, generator_b=generator_b, discriminator_a=discriminator_a, discriminator_b=discriminator_b, optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer, lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g ) - trainer.add_event_handler(Events.STARTED, Resumer(to_save, config.resume_from)) - checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir), n_saved=None) - trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.checkpoints.interval), checkpoint_handler) + setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.ITERATION_COMPLETED(every=10), + metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"], to_save=to_save, + resume_from=config.resume_from, n_saved=5, filename_prefix=config.name, + save_interval_event=Events.ITERATION_COMPLETED(every=config.checkpoints.interval)) + + @trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration)) + def terminate(engine): + engine.terminate() if idist.get_rank() == 0: # Create a logger @@ -169,7 +156,6 @@ def get_trainer(config, logger): ), event_name=Events.ITERATION_COMPLETED(every=50) ) - # Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration tb_logger.attach( trainer, log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"), @@ -180,28 +166,18 @@ def get_trainer(config, logger): def show_images(engine): tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration) - # Create an object of the profiler and attach an engine to it - profiler = BasicTimeProfiler() - profiler.attach(trainer) - - @trainer.on(Events.EPOCH_COMPLETED(once=1)) - @idist.one_rank_only() - def log_intermediate_results(): - profiler.print_results(profiler.get_results()) - - @trainer.on(Events.COMPLETED) - @idist.one_rank_only() - def _(): - profiler.write_results(f"{config.output_dir}/time_profiling.csv") - # We need to close the logger with we are done - tb_logger.close() + @trainer.on(Events.COMPLETED) + @idist.one_rank_only() + def _(): + # We need to close the logger with we are done + tb_logger.close() return trainer def get_tester(config, logger): - generator_a = _build_model(config.model.generator, config.distributed.model) - generator_b = _build_model(config.model.generator, config.distributed.model) + generator_a = build_model(config.model.generator, config.distributed.model) + generator_b = build_model(config.model.generator, config.distributed.model) def _step(engine, batch): batch = convert_tensor(batch, idist.device()) @@ -225,7 +201,7 @@ def get_tester(config, logger): if idist.get_rank == 0: ProgressBar(ncols=0).attach(tester) to_load = dict(generator_a=generator_a, generator_b=generator_b) - tester.add_event_handler(Events.STARTED, Resumer(to_load, config.resume_from)) + setup_common_handlers(tester, use_profiler=False, to_save=to_load, resume_from=config.resume_from) @tester.on(Events.STARTED) @idist.one_rank_only() @@ -248,15 +224,16 @@ def get_tester(config, logger): def run(task, config, logger): + assert torch.backends.cudnn.enabled + torch.backends.cudnn.benchmark = True logger.info(f"start task {task}") if task == "train": train_dataset = data.DATASET.build_with(config.data.train.dataset) logger.info(f"train with dataset:\n{train_dataset}") train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader) trainer = get_trainer(config, logger) - trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1) try: - trainer.run(train_data_loader, max_epochs=1) + trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1) except Exception: import traceback print(traceback.format_exc()) diff --git a/model/residual_generator.py b/model/residual_generator.py index 413cdec..c3138be 100644 --- a/model/residual_generator.py +++ b/model/residual_generator.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn import functools from .registry import MODEL @@ -14,6 +15,60 @@ def _select_norm_layer(norm_type): raise NotImplemented(f'normalization layer {norm_type} is not found') +class GANImageBuffer(object): + """This class implements an image buffer that stores previously + generated images. + This buffer allows us to update the discriminator using a history of + generated images rather than the ones produced by the latest generator + to reduce model oscillation. + Args: + buffer_size (int): The size of image buffer. If buffer_size = 0, + no buffer will be created. + buffer_ratio (float): The chance / possibility to use the images + previously stored in the buffer. + """ + + def __init__(self, buffer_size, buffer_ratio=0.5): + self.buffer_size = buffer_size + # create an empty buffer + if self.buffer_size > 0: + self.img_num = 0 + self.image_buffer = [] + self.buffer_ratio = buffer_ratio + + def query(self, images): + """Query current image batch using a history of generated images. + Args: + images (Tensor): Current image batch without history information. + """ + if self.buffer_size == 0: # if the buffer size is 0, do nothing + return images + return_images = [] + for image in images: + image = torch.unsqueeze(image.data, 0) + # if the buffer is not full, keep inserting current images + if self.img_num < self.buffer_size: + self.img_num = self.img_num + 1 + self.image_buffer.append(image) + return_images.append(image) + else: + use_buffer = torch.rand(1) < self.buffer_ratio + # by self.buffer_ratio, the buffer will return a previously + # stored image, and insert the current image into the buffer + if use_buffer: + random_id = torch.randint(0, self.buffer_size, (1,)).item() + image_tmp = self.image_buffer[random_id].clone() + self.image_buffer[random_id] = image + return_images.append(image_tmp) + # by (1 - self.buffer_ratio), the buffer will return the + # current image + else: + return_images.append(image) + # collect all the images and return + return_images = torch.cat(return_images, 0) + return return_images + + @MODEL.register_module() class ResidualBlock(nn.Module): def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False): diff --git a/util/handler.py b/util/handler.py index fd2291e..db1d3d1 100644 --- a/util/handler.py +++ b/util/handler.py @@ -1,21 +1,86 @@ from pathlib import Path import torch -from ignite.engine import Engine -from ignite.handlers import Checkpoint + +import ignite.distributed as idist +from ignite.engine import Events +from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan +from ignite.contrib.handlers import BasicTimeProfiler -class Resumer: - def __init__(self, to_load, checkpoint_path): - self.to_load = to_load - if checkpoint_path is not None: - checkpoint_path = Path(checkpoint_path) - if not checkpoint_path.exists(): - raise ValueError(f"Checkpoint '{checkpoint_path}' is not found") - self.checkpoint_path = checkpoint_path +def setup_common_handlers( + trainer, + output_dir=None, + stop_on_nan=True, + use_profiler=True, + print_interval_event=None, + metrics_to_print=None, + to_save=None, + resume_from=None, + save_interval_event=None, + **checkpoint_kwargs +): + """ + Helper method to setup trainer with common handlers. + 1. TerminateOnNan + 2. BasicTimeProfiler + 3. Print + 4. Checkpoint + :param trainer: trainer engine. Output of trainer's `update_function` should be a dictionary + or sequence or a single tensor. + :param output_dir: output path to indicate where `to_save` objects are stored. Argument is mutually + :param stop_on_nan: if True, :class:`~ignite.handlers.TerminateOnNan` handler is added to the trainer. + :param use_profiler: + :param print_interval_event: + :param metrics_to_print: + :param to_save: + :param resume_from: + :param save_interval_event: + :param checkpoint_kwargs: + :return: + """ + if stop_on_nan: + trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) - def __call__(self, engine: Engine): - if self.checkpoint_path is not None: - ckp = torch.load(self.checkpoint_path.as_posix(), map_location="cpu") - Checkpoint.load_objects(to_load=self.to_load, checkpoint=ckp) - engine.logger.info(f"resume from a checkpoint {self.checkpoint_path}") + if use_profiler: + # Create an object of the profiler and attach an engine to it + profiler = BasicTimeProfiler() + profiler.attach(trainer) + + @trainer.on(Events.EPOCH_COMPLETED(once=1)) + @idist.one_rank_only() + def log_intermediate_results(): + profiler.print_results(profiler.get_results()) + + @trainer.on(Events.COMPLETED) + @idist.one_rank_only() + def _(): + profiler.print_results(profiler.get_results()) + # profiler.write_results(f"{output_dir}/time_profiling.csv") + + if metrics_to_print is not None: + if print_interval_event is None: + raise ValueError( + "If metrics_to_print argument is provided then print_interval_event arguments should be also defined" + ) + + @trainer.on(print_interval_event) + def print_interval(engine): + print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t" + for m in metrics_to_print: + print_str += f"{m}={engine.state.metrics[m]:.3f} " + engine.logger.info(print_str) + + if to_save is not None: + if resume_from is not None: + @trainer.on(Events.STARTED) + def resume(engine): + checkpoint_path = Path(resume_from) + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found") + ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu") + Checkpoint.load_objects(to_load=to_save, checkpoint=ckp) + engine.logger.info(f"resume from a checkpoint {checkpoint_path}") + if save_interval_event is not None: + checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir), **checkpoint_kwargs) + trainer.add_event_handler(save_interval_event, checkpoint_handler)