diff --git a/configs/synthesizers/UGATIT.yml b/configs/synthesizers/UGATIT.yml
index cee8eeb..d9e8ad7 100644
--- a/configs/synthesizers/UGATIT.yml
+++ b/configs/synthesizers/UGATIT.yml
@@ -1,7 +1,7 @@
-name: selfie2anime
+name: selfie2anime-origin
 engine: UGATIT
 result_dir: ./result
-max_iteration: 100000
+max_pairs: 1000000
 
 distributed:
   model:
@@ -10,8 +10,15 @@ distributed:
 misc:
   random_seed: 324
 
-checkpoints:
-  interval: 1000
+checkpoint:
+  epoch_interval: 1  # one checkpoint per epoch
+  n_saved: 5
+
+interval:
+  print_per_iteration: 10  # print once every 10 iterations
+  tensorboard:
+    scalar: 10
+    image: 1000
 
 model:
   generator:
@@ -26,12 +33,12 @@ model:
       _type: UGATIT-Discriminator
       in_channels: 3
       base_channels: 64
-      num_blocks: 3
+      num_blocks: 5
     global_discriminator:
       _type: UGATIT-Discriminator
       in_channels: 3
       base_channels: 64
-      num_blocks: 5
+      num_blocks: 7
 
 loss:
   gan:
@@ -62,9 +69,12 @@ optimizers:
 
 data:
   train:
+    scheduler:
+      start_proportion: 0.5
+      target_lr: 0
     buffer_size: 50
     dataloader:
-      batch_size: 8
+      batch_size: 4
       shuffle: True
       num_workers: 2
       pin_memory: True
@@ -85,9 +95,6 @@ data:
       - Normalize:
           mean: [0.5, 0.5, 0.5]
           std: [0.5, 0.5, 0.5]
-    scheduler:
-      start: 50000
-      target_lr: 0
   test:
     dataloader:
       batch_size: 4
diff --git a/engine/UGATIT.py b/engine/UGATIT.py
index 007b0d7..d9bd5de 100644
--- a/engine/UGATIT.py
+++ b/engine/UGATIT.py
@@ -1,20 +1,18 @@
 from itertools import chain
-from pathlib import Path
+from math import ceil
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torchvision.utils
 import ignite.distributed as idist
 from ignite.engine import Events, Engine
-from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
 from ignite.metrics import RunningAverage
-from ignite.contrib.handlers import ProgressBar
 from ignite.utils import convert_tensor
-from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
+from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
+from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
 
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, read_write
 
 import data
 from loss.gan import GANLoss
@@ -22,7 +20,7 @@ from model.weight_init import generation_init_weights
 from model.GAN.residual_generator import GANImageBuffer
 from model.GAN.UGATIT import RhoClipper
 from util.image import make_2d_grid
-from util.handler import setup_common_handlers
+from util.handler import setup_common_handlers, setup_tensorboard_handler
 from util.build import build_model, build_optimizer
 
 
@@ -49,14 +47,14 @@ def get_trainer(config, logger):
     milestones_values = [
         (0, config.optimizers.generator.lr),
-        (config.data.train.scheduler.start, config.optimizers.generator.lr),
+        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
         (config.max_iteration, config.data.train.scheduler.target_lr)
     ]
     lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
 
     milestones_values = [
         (0, config.optimizers.discriminator.lr),
-        (config.data.train.scheduler.start, config.optimizers.discriminator.lr),
+        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
         (config.max_iteration, config.data.train.scheduler.target_lr)
     ]
     lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
 
@@ -66,18 +64,18 @@ def get_trainer(config, logger):
     gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
     cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
     id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
-    bce_loss = nn.BCEWithLogitsLoss().to(idist.device())
-    mse_loss = lambda x, t: F.mse_loss(x, x.new_ones(x.size()) if t else x.new_zeros(x.size()))
-    bce_loss = lambda x, t: F.binary_cross_entropy_with_logits(x, x.new_ones(x.size()) if t else x.new_zeros(x.size()))
-    image_buffers = {
-        k: GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50) for k in
-        discriminators.keys()}
+    def mse_loss(x, target_flag):
+        return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
+    def bce_loss(x, target_flag):
+        return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
+
+    image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()}
 
     rho_clipper = RhoClipper(0, 1)
 
-    def cal_generator_loss(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l,
-                           discriminator_g):
+    def criterion_generator(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l,
+                            discriminator_g):
         discriminator_g.requires_grad_(False)
         discriminator_l.requires_grad_(False)
         pred_fake_g, cam_gd_pred = discriminator_g(fake)
@@ -92,7 +90,7 @@ def get_trainer(config, logger):
            f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True),
         }
 
-    def cal_discriminator_loss(name, discriminator, real, fake):
+    def criterion_discriminator(name, discriminator, real, fake):
         pred_real, cam_real = discriminator(real)
         pred_fake, cam_fake = discriminator(fake)
         # TODO: origin do not divide 2, but I think it better to divide 2.
@@ -100,9 +98,8 @@ def get_trainer(config, logger):
         loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False)
         return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam}
 
-    def _step(engine, batch):
-        batch = convert_tensor(batch, idist.device())
-        real_a, real_b = batch["a"], batch["b"]
+    def _step(engine, real):
+        real = convert_tensor(real, idist.device())
 
         fake = dict()
         cam_generator_pred = dict()
@@ -111,18 +108,18 @@ def get_trainer(config, logger):
         cam_identity_pred = dict()
         heatmap = dict()
 
-        fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real_a)
-        fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real_b)
+        fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real["a"])
+        fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real["b"])
         rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"])
         rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"])
-        identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real_a)
-        identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real_b)
+        identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"])
+        identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"])
 
         optimizer_g.zero_grad()
         loss_g = dict()
         for n in ["a", "b"]:
-            loss_g.update(cal_generator_loss(n, batch[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
-                                             cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
+            loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
+                                              cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
         sum(loss_g.values()).backward()
         optimizer_g.step()
         for generator in generators.values():
@@ -135,13 +132,14 @@ def get_trainer(config, logger):
         for k in discriminators.keys():
             n = k[-1]  # "a" or "b"
             loss_d.update(
-                cal_discriminator_loss(k, discriminators[k], batch[n], image_buffers[k].query(fake[n].detach())))
+                criterion_discriminator(k, discriminators[k], real[n], image_buffers[k].query(fake[n].detach())))
         sum(loss_d.values()).backward()
         optimizer_d.step()
 
         for h in heatmap:
             heatmap[h] = heatmap[h].detach()
-        generated_img = {f"fake_{k}": fake[k].detach() for k in fake}
+        generated_img = {f"real_{k}": real[k].detach() for k in real}
+        generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
         generated_img.update({f"id_{k}": identity[k].detach() for k in identity})
         generated_img.update({f"rec_{k}": rec[k].detach() for k in rec})
 
@@ -169,64 +167,41 @@ def get_trainer(config, logger):
     to_save.update({f"generator_{k}": generators[k] for k in generators})
     to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
 
-    setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
-                          filename_prefix=config.name, to_save=to_save,
-                          print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
-                          metrics_to_print=["loss_g", "loss_d"],
-                          save_interval_event=Events.ITERATION_COMPLETED(
-                              every=config.checkpoints.interval) | Events.COMPLETED)
+    setup_common_handlers(trainer, config, to_save=to_save, metrics_to_print=["loss_g", "loss_d"],
+                          clear_cuda_cache=True, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
 
-    @trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
-    def terminate(engine):
-        engine.terminate()
+    def output_transform(output):
+        loss = dict()
+        for tl in output["loss"]:
+            if isinstance(output["loss"][tl], dict):
+                for l in output["loss"][tl]:
+                    loss[f"{tl}_{l}"] = output["loss"][tl][l]
+            else:
+                loss[tl] = output["loss"][tl]
+        return loss
 
-    if idist.get_rank() == 0:
-        # Create a logger
-        tb_logger = TensorboardLogger(log_dir=config.output_dir)
-        tb_writer = tb_logger.writer
-
-        # Attach the logger to the trainer to log training loss at each iteration
-        def global_step_transform(*args, **kwargs):
-            return trainer.state.iteration
-
-        def output_transform(output):
-            loss = dict()
-            for tl in output["loss"]:
-                if isinstance(output["loss"][tl], dict):
-                    for l in output["loss"][tl]:
-                        loss[f"{tl}_{l}"] = output["loss"][tl][l]
-                else:
-                    loss[tl] = output["loss"][tl]
-            return loss
-
-        tb_logger.attach(
-            trainer,
-            log_handler=OutputHandler(
-                tag="loss",
-                metric_names=["loss_g", "loss_d"],
-                global_step_transform=global_step_transform,
-                output_transform=output_transform
-            ),
-            event_name=Events.ITERATION_COMPLETED(every=50)
-        )
-        tb_logger.attach(
+    tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform)
+    if tensorboard_handler is not None:
+        tensorboard_handler.attach(
             trainer,
             log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
-            event_name=Events.ITERATION_STARTED(every=50)
+            event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
        )
 
-        @trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
+        @trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
         def show_images(engine):
-            tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]["generated"].values()),
-                                engine.state.iteration)
-            tb_writer.add_image("train/heatmap", make_2d_grid(engine.state.output["img"]["heatmap"].values()),
-                                engine.state.iteration)
-
-        @trainer.on(Events.COMPLETED)
-        @idist.one_rank_only()
-        def _():
-            # We need to close the logger with we are done
-            tb_logger.close()
+            image_a_order = ["real_a", "fake_b", "rec_a", "id_a"]
+            image_b_order = ["real_b", "fake_a", "rec_b", "id_b"]
+            tensorboard_handler.writer.add_image(
+                "train/a",
+                make_2d_grid([engine.state.output["img"]["generated"][o] for o in image_a_order]),
+                engine.state.iteration
+            )
+            tensorboard_handler.writer.add_image(
+                "train/b",
+                make_2d_grid([engine.state.output["img"]["generated"][o] for o in image_b_order]),
+                engine.state.iteration
+            )
 
     return trainer
 
@@ -235,13 +210,16 @@ def run(task, config, logger):
     assert torch.backends.cudnn.enabled
     torch.backends.cudnn.benchmark = True
     logger.info(f"start task {task}")
+    with read_write(config):
+        config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
+
     if task == "train":
         train_dataset = data.DATASET.build_with(config.data.train.dataset)
         logger.info(f"train with dataset:\n{train_dataset}")
         train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
         trainer = get_trainer(config, logger)
         try:
-            trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
+            trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
         except Exception:
             import traceback
             print(traceback.format_exc())
diff --git a/model/GAN/UGATIT.py b/model/GAN/UGATIT.py
index e02c1cb..5b857fc 100644
--- a/model/GAN/UGATIT.py
+++ b/model/GAN/UGATIT.py
@@ -90,22 +90,6 @@ class Generator(nn.Module):
                                  padding_mode="reflect", bias=False), nn.Tanh()]
         self.up_decoder = nn.Sequential(*up_decoder)
-        # self.up_decoder = nn.ModuleDict({
-        #     "up_1": nn.Upsample(scale_factor=2, mode='nearest'),
-        #     "up_conv_1": nn.Sequential(
-        #         nn.Conv2d(base_channels * 4, base_channels * 4 // 2, kernel_size=3, stride=1,
-        #                   padding=1, padding_mode="reflect", bias=False),
-        #         ILN(base_channels * 4 // 2),
-        #         nn.ReLU(True)),
-        #     "up_2": nn.Upsample(scale_factor=2, mode='nearest'),
-        #     "up_conv_2": nn.Sequential(
-        #         nn.Conv2d(base_channels * 2, base_channels * 2 // 2, kernel_size=3, stride=1,
-        #                   padding=1, padding_mode="reflect", bias=False),
-        #         ILN(base_channels * 2 // 2),
-        #         nn.ReLU(True)),
-        #     "up_end": nn.Sequential(nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3,
-        #                                       padding_mode="reflect", bias=False), nn.Tanh())
-        # })
 
     def forward(self, x):
         x = self.down_encoder(x)
diff --git a/run.sh b/run.sh
index 02ab01c..c93222d 100644
--- a/run.sh
+++ b/run.sh
@@ -16,5 +16,5 @@ PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --
 CUDA_VISIBLE_DEVICES=$GPUS \
 PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
-    main.py "$TASK" "$CONFIG" "$MORE_ARG" --backup_config --setup_output_dir --setup_random_seed
+    main.py "$TASK" "$CONFIG" $MORE_ARG --backup_config --setup_output_dir --setup_random_seed
 
diff --git a/util/handler.py b/util/handler.py
index 15fd41f..447b4be 100644
--- a/util/handler.py
+++ b/util/handler.py
@@ -5,38 +5,33 @@ import torch
 import ignite.distributed as idist
 from ignite.engine import Events, Engine
 from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
-from ignite.contrib.handlers import BasicTimeProfiler
+from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
+from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler
 
 
-def setup_common_handlers(
-        trainer: Engine,
-        output_dir=None,
-        stop_on_nan=True,
-        use_profiler=True,
-        print_interval_event=None,
-        metrics_to_print=None,
-        to_save=None,
-        resume_from=None,
-        save_interval_event=None,
-        **checkpoint_kwargs
-):
+def empty_cuda_cache(_):
+    torch.cuda.empty_cache()
+    import gc
+
+    gc.collect()
+
+
+def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
+                          to_save=None, metrics_to_print=None, end_event=None):
     """
     Helper method to setup trainer with common handlers.
     1. TerminateOnNan
     2. BasicTimeProfiler
     3. Print
     4. Checkpoint
-    :param trainer: trainer engine. Output of trainer's `update_function` should be a dictionary
-        or sequence or a single tensor.
-    :param output_dir: output path to indicate where `to_save` objects are stored. Argument is mutually
-    :param stop_on_nan: if True, :class:`~ignite.handlers.TerminateOnNan` handler is added to the trainer.
+    :param trainer:
+    :param config:
+    :param stop_on_nan:
+    :param clear_cuda_cache:
     :param use_profiler:
-    :param print_interval_event:
-    :param metrics_to_print:
     :param to_save:
-    :param resume_from:
-    :param save_interval_event:
-    :param checkpoint_kwargs:
+    :param metrics_to_print:
+    :param end_event:
     :return:
     """
 
@@ -48,28 +43,24 @@ def setup_common_handlers(
     if stop_on_nan:
         trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
 
+    if torch.cuda.is_available() and clear_cuda_cache:
+        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)
+
     if use_profiler:
         # Create an object of the profiler and attach an engine to it
         profiler = BasicTimeProfiler()
         profiler.attach(trainer)
 
-        @trainer.on(Events.EPOCH_COMPLETED(once=1))
+        @trainer.on(Events.EPOCH_COMPLETED(once=1) | Events.COMPLETED)
         @idist.one_rank_only()
         def log_intermediate_results():
             profiler.print_results(profiler.get_results())
 
-        @trainer.on(Events.COMPLETED)
-        @idist.one_rank_only()
-        def _():
-            profiler.print_results(profiler.get_results())
-            # profiler.write_results(f"{output_dir}/time_profiling.csv")
+    print_interval_event = Events.ITERATION_COMPLETED(every=config.interval.print_per_iteration) | Events.COMPLETED
+
+    ProgressBar(ncols=0).attach(trainer, "all")
 
     if metrics_to_print is not None:
-        if print_interval_event is None:
-            raise ValueError(
-                "If metrics_to_print argument is provided then print_interval_event arguments should be also defined"
-            )
-
         @trainer.on(print_interval_event)
         def print_interval(engine):
             print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
@@ -77,19 +68,44 @@ def setup_common_handlers(
                 if m not in engine.state.metrics:
                     continue
                 print_str += f"{m}={engine.state.metrics[m]:.3f} "
-            engine.logger.info(print_str)
+            engine.logger.debug(print_str)
 
     if to_save is not None:
-        checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir, require_empty=False),
-                                        **checkpoint_kwargs)
-        if resume_from is not None:
+        checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False),
+                                        n_saved=config.checkpoint.n_saved, filename_prefix=config.name)
+        if config.resume_from is not None:
             @trainer.on(Events.STARTED)
             def resume(engine):
-                checkpoint_path = Path(resume_from)
+                checkpoint_path = Path(config.resume_from)
                 if not checkpoint_path.exists():
                     raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
                 ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
                 Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
                 engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
-        if save_interval_event is not None:
-            trainer.add_event_handler(save_interval_event, checkpoint_handler)
+        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
+                                  checkpoint_handler)
+    if end_event is not None:
+        @trainer.on(end_event)
+        def terminate(engine):
+            engine.terminate()
+
+
+def setup_tensorboard_handler(trainer: Engine, config, output_transform):
+    if config.interval.tensorboard is None:
+        return None
+    if idist.get_rank() == 0:
+        # Create a logger
+        tb_logger = TensorboardLogger(log_dir=config.output_dir)
+        tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"),
+                         event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
+        tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform),
+                         event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
+
+        @trainer.on(Events.COMPLETED)
+        @idist.one_rank_only()
+        def _():
+            # We need to close the logger when we are done
+            tb_logger.close()
+
+        return tb_logger
+    return None
diff --git a/util/misc.py b/util/misc.py
index eac66f3..8462271 100644
--- a/util/misc.py
+++ b/util/misc.py
@@ -69,7 +69,7 @@ def setup_logger(
     if distributed_rank > 0:
         logger.addHandler(logging.NullHandler())
     else:
-        logger.setLevel(level)
+        logger.setLevel(logging.DEBUG)
 
         ch = logging.StreamHandler()
         ch.setLevel(level)
@@ -78,7 +78,7 @@ def setup_logger(
 
     if filepath is not None:
         fh = logging.FileHandler(filepath)
-        fh.setLevel(file_level)
+        fh.setLevel(logging.DEBUG)
         fh.setFormatter(formatter)
         logger.addHandler(fh)
 
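
Note on the new iteration bookkeeping (max_pairs replacing max_iteration, and scheduler.start_proportion replacing the absolute scheduler.start): the sketch below is illustrative only. derive_schedule is a hypothetical helper, and the dataloader length of 3400 batches and base learning rate of 1e-4 are assumed values, not taken from the repository.

from math import ceil

def derive_schedule(max_pairs, batch_size, loader_len, base_lr, start_proportion=0.5, target_lr=0.0):
    """Mirror the patch's bookkeeping: training pairs -> iterations -> epochs,
    plus the PiecewiseLinear milestones (hold the LR, then decay linearly)."""
    max_iteration = ceil(max_pairs / batch_size)           # run(): config.max_iteration
    max_epochs = ceil(max_iteration / loader_len)          # trainer.run(..., max_epochs=...)
    milestones_values = [
        (0, base_lr),
        (int(start_proportion * max_iteration), base_lr),  # constant until the start point
        (max_iteration, target_lr),                        # then linear decay to target_lr
    ]
    return max_iteration, max_epochs, milestones_values

# With the config values above (max_pairs=1000000, batch_size=4) and the assumed
# dataloader length and base LR:
print(derive_schedule(1_000_000, 4, 3400, 1e-4))
# -> (250000, 74, [(0, 0.0001), (125000, 0.0001), (250000, 0.0)])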
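
The patch also swaps the lambda-based GAN target losses for named helpers built on torch.ones_like/torch.zeros_like. A minimal, self-contained check (not part of the patch) that the new helpers reproduce the old x.new_ones(x.size()) behaviour, including dtype and device inheritance from x:

import torch
import torch.nn.functional as F

def mse_loss(x, target_flag):
    # New form: an all-ones or all-zeros target with the same shape/dtype/device as x.
    return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))

def bce_loss(x, target_flag):
    return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))

# Old form removed by the patch, kept here only for the comparison.
old_mse = lambda x, t: F.mse_loss(x, x.new_ones(x.size()) if t else x.new_zeros(x.size()))

x = torch.randn(2, 1, 4, 4)
assert torch.allclose(mse_loss(x, True), old_mse(x, True))
assert torch.allclose(mse_loss(x, False), old_mse(x, False))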
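
One behavioural note on GANImageBuffer(config.data.train.buffer_size or 50): unlike the removed "... if ... is not None else 50", the "or" form also maps a configured buffer_size of 0 to 50. That is probably irrelevant here (the config sets 50, and a zero-size buffer may not be a supported setting), but the difference is easy to demonstrate:

for buffer_size in (None, 50, 17, 0):
    old = buffer_size if buffer_size is not None else 50
    new = buffer_size or 50
    print(buffer_size, old, new)  # only the 0 case differs: old=0, new=50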