From 31aafb347041707f59ce52e198712c06749d7a92 Mon Sep 17 00:00:00 2001 From: budui Date: Mon, 24 Aug 2020 06:51:42 +0800 Subject: [PATCH] UGATIT version 0.1 --- configs/synthesizers/UGATIT.yml | 2 +- engine/UGATIT.py | 138 +++++++++++++++----------------- util/handler.py | 31 +++---- util/image.py | 28 ++++--- 4 files changed, 90 insertions(+), 109 deletions(-) diff --git a/configs/synthesizers/UGATIT.yml b/configs/synthesizers/UGATIT.yml index a92a541..05d5311 100644 --- a/configs/synthesizers/UGATIT.yml +++ b/configs/synthesizers/UGATIT.yml @@ -1,4 +1,4 @@ -name: selfie2anime-origin +name: selfie2anime engine: UGATIT result_dir: ./result max_pairs: 1000000 diff --git a/engine/UGATIT.py b/engine/UGATIT.py index 518cbfe..6eb96e7 100644 --- a/engine/UGATIT.py +++ b/engine/UGATIT.py @@ -4,7 +4,6 @@ from math import ceil import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils.data import DataLoader import ignite.distributed as idist from ignite.engine import Events, Engine @@ -20,11 +19,28 @@ from loss.gan import GANLoss from model.weight_init import generation_init_weights from model.GAN.residual_generator import GANImageBuffer from model.GAN.UGATIT import RhoClipper -from util.image import make_2d_grid, fuse_attention_map +from util.image import make_2d_grid, fuse_attention_map, attention_colored_map from util.handler import setup_common_handlers, setup_tensorboard_handler from util.build import build_model, build_optimizer +def build_lr_schedulers(optimizers, config): + g_milestones_values = [ + (0, config.optimizers.generator.lr), + (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr), + (config.max_iteration, config.data.train.scheduler.target_lr) + ] + d_milestones_values = [ + (0, config.optimizers.discriminator.lr), + (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr), + (config.max_iteration, config.data.train.scheduler.target_lr) + ] + return dict( + g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values), + d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values) + ) + + def get_trainer(config, logger): generators = dict( a2b=build_model(config.model.generator, config.distributed.model), @@ -42,23 +58,14 @@ def get_trainer(config, logger): logger.debug(discriminators["ga"]) logger.debug(generators["a2b"]) - optimizer_g = build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator) - optimizer_d = build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), - config.optimizers.discriminator) + optimizers = dict( + g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator), + d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator), + ) + logger.info(f"build optimizers:\n{optimizers}") - milestones_values = [ - (0, config.optimizers.generator.lr), - (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr), - (config.max_iteration, config.data.train.scheduler.target_lr) - ] - lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values) - - milestones_values = [ - (0, config.optimizers.discriminator.lr), - (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr), - (config.max_iteration, config.data.train.scheduler.target_lr) - ] - lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values) + lr_schedulers = build_lr_schedulers(optimizers, config) + logger.info(f"build lr_schedulers:\n{lr_schedulers}") gan_loss_cfg = OmegaConf.to_container(config.loss.gan) gan_loss_cfg.pop("weight") @@ -116,26 +123,26 @@ def get_trainer(config, logger): identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"]) identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"]) - optimizer_g.zero_grad() + optimizers["g"].zero_grad() loss_g = dict() for n in ["a", "b"]: loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n], cam_generator_pred[n], cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n])) sum(loss_g.values()).backward() - optimizer_g.step() + optimizers["g"].step() for generator in generators.values(): generator.apply(rho_clipper) for discriminator in discriminators.values(): discriminator.requires_grad_(True) - optimizer_d.zero_grad() + optimizers["d"].zero_grad() loss_d = dict() for k in discriminators.keys(): n = k[-1] # "a" or "b" loss_d.update( criterion_discriminator(k, discriminators[k], real[n], image_buffers[k].query(fake[n].detach()))) sum(loss_d.values()).backward() - optimizer_d.step() + optimizers["d"].step() for h in heatmap: heatmap[h] = heatmap[h].detach() @@ -157,19 +164,19 @@ def get_trainer(config, logger): trainer = Engine(_step) trainer.logger = logger - trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g) - trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d) + for lr_shd in lr_schedulers.values(): + trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd) RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g") RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d") - to_save = dict(optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer, lr_scheduler_d=lr_scheduler_d, - lr_scheduler_g=lr_scheduler_g) + to_save = dict(trainer=trainer) + to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers}) + to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers}) to_save.update({f"generator_{k}": generators[k] for k in generators}) to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators}) - - setup_common_handlers(trainer, config, to_save=to_save, metrics_to_print=["loss_g", "loss_d"], - clear_cuda_cache=False, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration)) + setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, + end_event=Events.ITERATION_COMPLETED(once=config.max_iteration)) def output_transform(output): loss = dict() @@ -185,46 +192,36 @@ def get_trainer(config, logger): if tensorboard_handler is not None: tensorboard_handler.attach( trainer, - log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"), + log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"), event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar) ) @trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image)) def show_images(engine): output = engine.state.output - image_a_order = ["real_a", "fake_b", "rec_a", "id_a"] - image_b_order = ["real_b", "fake_a", "rec_b", "id_b"] - + image_order = dict( + a=["real_a", "fake_b", "rec_a", "id_a"], + b=["real_b", "fake_a", "rec_b", "id_b"] + ) output["img"]["generated"]["real_a"] = fuse_attention_map( output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"]) output["img"]["generated"]["real_b"] = fuse_attention_map( output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"]) - - tensorboard_handler.writer.add_image( - "train/a", - make_2d_grid([output["img"]["generated"][o] for o in image_a_order]), - engine.state.iteration - ) - tensorboard_handler.writer.add_image( - "train/b", - make_2d_grid([output["img"]["generated"][o] for o in image_b_order]), - engine.state.iteration - ) + for k in "ab": + tensorboard_handler.writer.add_image( + f"train/{k}", + make_2d_grid([output["img"]["generated"][o] for o in image_order[k]]), + engine.state.iteration + ) with torch.no_grad(): g = torch.Generator() g.manual_seed(config.misc.random_seed) indices = torch.randperm(len(engine.state.test_dataset), generator=g).tolist()[:10] - - empty_grid = torch.zeros(0, config.model.generator.in_channels, config.model.generator.img_size, - config.model.generator.img_size) - fake = dict(a=empty_grid.clone(), b=empty_grid.clone()) - rec = dict(a=empty_grid.clone(), b=empty_grid.clone()) - heatmap = dict(a2b=torch.zeros(0, 1, config.model.generator.img_size, - config.model.generator.img_size), - b2a=torch.zeros(0, 1, config.model.generator.img_size, - config.model.generator.img_size)) - real = dict(a=empty_grid.clone(), b=empty_grid.clone()) + test_images = dict( + a=[[], [], [], []], + b=[[], [], [], []] + ) for i in indices: batch = convert_tensor(engine.state.test_dataset[i], idist.device()) @@ -234,27 +231,18 @@ def get_trainer(config, logger): rec_a = generators["b2a"](fake_b)[0] rec_b = generators["a2b"](fake_a)[0] - fake["a"] = torch.cat([fake["a"], fake_a.cpu()]) - fake["b"] = torch.cat([fake["b"], fake_b.cpu()]) - real["a"] = torch.cat([real["a"], real_a.cpu()]) - real["b"] = torch.cat([real["b"], real_b.cpu()]) - rec["a"] = torch.cat([rec["a"], rec_a.cpu()]) - rec["b"] = torch.cat([rec["b"], rec_b.cpu()]) - - heatmap["a2b"] = torch.cat( - [heatmap["a2b"], torch.nn.functional.interpolate(heatmap_a2b, real_a.size()[-2:]).cpu()]) - heatmap["b2a"] = torch.cat( - [heatmap["b2a"], torch.nn.functional.interpolate(heatmap_b2a, real_a.size()[-2:]).cpu()]) - tensorboard_handler.writer.add_image( - "test/a", - make_2d_grid([heatmap["a2b"].expand_as(real["a"]), real["a"], fake["b"], rec["a"]]), - engine.state.iteration - ) - tensorboard_handler.writer.add_image( - "test/b", - make_2d_grid([heatmap["b2a"].expand_as(real["a"]), real["b"], fake["a"], rec["b"]]), - engine.state.iteration - ) + for idx, im in enumerate( + [attention_colored_map(heatmap_a2b, real_a.size()[-2:]), real_a, fake_b, rec_a]): + test_images["a"][idx].append(im.cpu()) + for idx, im in enumerate( + [attention_colored_map(heatmap_b2a, real_b.size()[-2:]), real_b, fake_a, rec_b]): + test_images["b"][idx].append(im.cpu()) + for n in "ab": + tensorboard_handler.writer.add_image( + f"test/{n}", + make_2d_grid([torch.cat(ti) for ti in test_images[n]]), + engine.state.iteration + ) return trainer diff --git a/util/handler.py b/util/handler.py index 7c5e868..fc3df75 100644 --- a/util/handler.py +++ b/util/handler.py @@ -17,7 +17,7 @@ def empty_cuda_cache(_): def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True, - to_save=None, metrics_to_print=None, end_event=None): + to_save=None, end_event=None, set_epoch_for_dist_sampler=True): """ Helper method to setup trainer with common handlers. 1. TerminateOnNan @@ -30,21 +30,21 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_ :param clear_cuda_cache: :param use_profiler: :param to_save: - :param metrics_to_print: :param end_event: + :param set_epoch_for_dist_sampler: :return: """ - - if isinstance(trainer.state.dataloader.sampler, DistributedSampler): + if set_epoch_for_dist_sampler: @trainer.on(Events.EPOCH_STARTED) def distrib_set_epoch(engine): - trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler") - trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1) + if isinstance(trainer.state.dataloader.sampler, DistributedSampler): + trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler") + trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1) - @trainer.on(Events.STARTED) - @idist.one_rank_only() - def print_dataloader_size(engine): + @trainer.on(Events.STARTED | Events.EPOCH_COMPLETED(once=1)) + def print_info(engine): engine.logger.info(f"data loader length: {len(engine.state.dataloader)}") + engine.logger.info(f"- GPU util: \n{torch.cuda.memory_summary(0)}") if stop_on_nan: trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) @@ -62,20 +62,8 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_ def log_intermediate_results(): profiler.print_results(profiler.get_results()) - print_interval_event = Events.ITERATION_COMPLETED(every=config.interval.print_per_iteration) | Events.COMPLETED - ProgressBar(ncols=0).attach(trainer, "all") - if metrics_to_print is not None: - @trainer.on(print_interval_event) - def print_interval(engine): - print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t" - for m in metrics_to_print: - if m not in engine.state.metrics: - continue - print_str += f"{m}={engine.state.metrics[m]:.3f} " - engine.logger.debug(print_str) - if to_save is not None: checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False), n_saved=config.checkpoint.n_saved, filename_prefix=config.name) @@ -86,6 +74,7 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_ if not checkpoint_path.exists(): raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found") ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu") + trainer.logger.info(f"load state_dict for {ckp.keys()}") Checkpoint.load_objects(to_load=to_save, checkpoint=ckp) engine.logger.info(f"resume from a checkpoint {checkpoint_path}") trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED, diff --git a/util/image.py b/util/image.py index 41de01d..524d382 100644 --- a/util/image.py +++ b/util/image.py @@ -5,6 +5,21 @@ import warnings from torch.nn.functional import interpolate +def attention_colored_map(attentions, size=None, cmap_name="jet"): + assert attentions.dim() == 4 and attentions.size(1) == 1 + + min_attentions = attentions.view(attentions.size(0), -1).min(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1) + attentions -= min_attentions + attentions /= attentions.view(attentions.size(0), -1).max(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1) + + if size is not None and attentions.size()[-2:] != size: + assert len(size) == 2, "for interpolate, size must be (x, y), have two dim" + attentions = interpolate(attentions, size, mode="bilinear", align_corners=False) + cmap = get_cmap(cmap_name) + ca = cmap(attentions.squeeze(1).cpu())[:, :, :, :3] + return torch.from_numpy(ca).permute(0, 3, 1, 2).contiguous() + + def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5): """ @@ -20,18 +35,7 @@ def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5): if attentions.size(1) != 1: warnings.warn(f"attentions's channels should be 1 but got {attentions.size(1)}") return images - - min_attentions = attentions.view(attentions.size(0), -1).min(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1) - attentions -= min_attentions - attentions /= attentions.view(attentions.size(0), -1).max(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1) - - if images.size() != attentions.size(): - attentions = interpolate(attentions, images.size()[-2:]) - colored_attentions = torch.zeros_like(images) - cmap = get_cmap(cmap_name) - for i, at in enumerate(attentions): - ca = cmap(at[0].cpu().numpy())[:, :, :3] - colored_attentions[i] = torch.from_numpy(ca).permute(2, 0, 1).view(colored_attentions[i].size()) + colored_attentions = attention_colored_map(attentions, images.size()[-2:], cmap_name).to(images.device) return images * alpha + colored_attentions * (1 - alpha)