diff --git a/configs/synthesizers/UGATIT.yml b/configs/synthesizers/UGATIT.yml
index d9e8ad7..55214b1 100644
--- a/configs/synthesizers/UGATIT.yml
+++ b/configs/synthesizers/UGATIT.yml
@@ -12,13 +12,13 @@ misc:
   checkpoint:
     epoch_interval: 1  # one checkpoint every 1 epoch
-    n_saved: 5
+    n_saved: 2
 
   interval:
     print_per_iteration: 10  # print once per 10 iteration
     tensorboard:
       scalar: 10
-      image: 1000
+      image: 500
 
 
 model:
   generator:
diff --git a/engine/UGATIT.py b/engine/UGATIT.py
index d9bd5de..568758d 100644
--- a/engine/UGATIT.py
+++ b/engine/UGATIT.py
@@ -19,7 +19,7 @@ from loss.gan import GANLoss
 from model.weight_init import generation_init_weights
 from model.GAN.residual_generator import GANImageBuffer
 from model.GAN.UGATIT import RhoClipper
-from util.image import make_2d_grid
+from util.image import make_2d_grid, fuse_attention_map
 from util.handler import setup_common_handlers, setup_tensorboard_handler
 from util.build import build_model, build_optimizer
 
@@ -190,16 +190,23 @@ def get_trainer(config, logger):
 
     @trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
     def show_images(engine):
+        output = engine.state.output
         image_a_order = ["real_a", "fake_b", "rec_a", "id_a"]
         image_b_order = ["real_b", "fake_a", "rec_b", "id_b"]
+
+        output["img"]["generated"]["real_a"] = fuse_attention_map(
+            output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"])
+        output["img"]["generated"]["real_b"] = fuse_attention_map(
+            output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"])
+
         tensorboard_handler.writer.add_image(
             "train/a",
-            make_2d_grid([engine.state.output["img"]["generated"][o] for o in image_a_order]),
+            make_2d_grid([output["img"]["generated"][o] for o in image_a_order]),
             engine.state.iteration
         )
         tensorboard_handler.writer.add_image(
             "train/b",
-            make_2d_grid([engine.state.output["img"]["generated"][o] for o in image_b_order]),
+            make_2d_grid([output["img"]["generated"][o] for o in image_b_order]),
             engine.state.iteration
         )
 
diff --git a/util/handler.py b/util/handler.py
index 447b4be..a211522 100644
--- a/util/handler.py
+++ b/util/handler.py
@@ -12,7 +12,6 @@ from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, Output
 
 
 def empty_cuda_cache(_):
     torch.cuda.empty_cache()
     import gc
-    gc.collect()
 
@@ -35,6 +34,14 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
 
     :return:
     """
+    # if train_sampler is not None:
+    #     if not isinstance(train_sampler, DistributedSampler):
+    #         raise TypeError("Train sampler should be torch DistributedSampler and have `set_epoch` method")
+    #
+    #     @trainer.on(Events.EPOCH_STARTED)
+    #     def distrib_set_epoch(engine):
+    #         train_sampler.set_epoch(engine.state.epoch - 1)
+
     @trainer.on(Events.STARTED)
     @idist.one_rank_only()
     def print_dataloader_size(engine):
diff --git a/util/image.py b/util/image.py
index 126fbe6..41de01d 100644
--- a/util/image.py
+++ b/util/image.py
@@ -1,4 +1,38 @@
 import torchvision.utils
+from matplotlib.pyplot import get_cmap
+import torch
+import warnings
+from torch.nn.functional import interpolate
+
+
+def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
+    """
+    Overlay colorized attention maps on a batch of images.
+    :param images: B x C x H x W batch of images
+    :param attentions: B x 1 x Ha x Wa batch of attention maps
+    :param cmap_name: matplotlib colormap used to colorize the attention maps
+    :param alpha: blending weight given to the original images
+    :return: B x C x H x W batch of images blended with the colorized attention maps
+    """
+    if attentions.size(0) != images.size(0):
+        warnings.warn(f"attentions: {attentions.size()} and images: {images.size()} do not have the same batch size")
+        return images
+    if attentions.size(1) != 1:
+        warnings.warn(f"attentions should have 1 channel but got {attentions.size(1)}")
+        return images
+
+    min_attentions = attentions.view(attentions.size(0), -1).min(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
+    attentions -= min_attentions
+    attentions /= attentions.view(attentions.size(0), -1).max(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
+
+    if images.size() != attentions.size():
+        attentions = interpolate(attentions, images.size()[-2:])
+    colored_attentions = torch.zeros_like(images)
+    cmap = get_cmap(cmap_name)
+    for i, at in enumerate(attentions):
+        ca = cmap(at[0].cpu().numpy())[:, :, :3]
+        colored_attentions[i] = torch.from_numpy(ca).permute(2, 0, 1).view(colored_attentions[i].size())
+    return images * alpha + colored_attentions * (1 - alpha)
 
 
 def make_2d_grid(tensors, padding=0, normalize=True, range=None, scale_each=False, pad_value=0):