add attention image fuse
This commit is contained in:
parent
ccc3d7614a
commit
58ed4524bf
@ -12,13 +12,13 @@ misc:
|
|||||||
|
|
||||||
checkpoint:
|
checkpoint:
|
||||||
epoch_interval: 1 # one checkpoint every 1 epoch
|
epoch_interval: 1 # one checkpoint every 1 epoch
|
||||||
n_saved: 5
|
n_saved: 2
|
||||||
|
|
||||||
interval:
|
interval:
|
||||||
print_per_iteration: 10 # print once per 10 iteration
|
print_per_iteration: 10 # print once per 10 iteration
|
||||||
tensorboard:
|
tensorboard:
|
||||||
scalar: 10
|
scalar: 10
|
||||||
image: 1000
|
image: 500
|
||||||
|
|
||||||
model:
|
model:
|
||||||
generator:
|
generator:
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from loss.gan import GANLoss
|
|||||||
from model.weight_init import generation_init_weights
|
from model.weight_init import generation_init_weights
|
||||||
from model.GAN.residual_generator import GANImageBuffer
|
from model.GAN.residual_generator import GANImageBuffer
|
||||||
from model.GAN.UGATIT import RhoClipper
|
from model.GAN.UGATIT import RhoClipper
|
||||||
from util.image import make_2d_grid
|
from util.image import make_2d_grid, fuse_attention_map
|
||||||
from util.handler import setup_common_handlers, setup_tensorboard_handler
|
from util.handler import setup_common_handlers, setup_tensorboard_handler
|
||||||
from util.build import build_model, build_optimizer
|
from util.build import build_model, build_optimizer
|
||||||
|
|
||||||
@ -190,16 +190,23 @@ def get_trainer(config, logger):
|
|||||||
|
|
||||||
@trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
|
@trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
|
||||||
def show_images(engine):
|
def show_images(engine):
|
||||||
|
output = engine.state.output
|
||||||
image_a_order = ["real_a", "fake_b", "rec_a", "id_a"]
|
image_a_order = ["real_a", "fake_b", "rec_a", "id_a"]
|
||||||
image_b_order = ["real_b", "fake_a", "rec_b", "id_b"]
|
image_b_order = ["real_b", "fake_a", "rec_b", "id_b"]
|
||||||
|
|
||||||
|
output["img"]["generated"]["real_a"] = fuse_attention_map(
|
||||||
|
output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"])
|
||||||
|
output["img"]["generated"]["real_b"] = fuse_attention_map(
|
||||||
|
output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"])
|
||||||
|
|
||||||
tensorboard_handler.writer.add_image(
|
tensorboard_handler.writer.add_image(
|
||||||
"train/a",
|
"train/a",
|
||||||
make_2d_grid([engine.state.output["img"]["generated"][o] for o in image_a_order]),
|
make_2d_grid([output["img"]["generated"][o] for o in image_a_order]),
|
||||||
engine.state.iteration
|
engine.state.iteration
|
||||||
)
|
)
|
||||||
tensorboard_handler.writer.add_image(
|
tensorboard_handler.writer.add_image(
|
||||||
"train/b",
|
"train/b",
|
||||||
make_2d_grid([engine.state.output["img"]["generated"][o] for o in image_b_order]),
|
make_2d_grid([output["img"]["generated"][o] for o in image_b_order]),
|
||||||
engine.state.iteration
|
engine.state.iteration
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -12,7 +12,6 @@ from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, Output
|
|||||||
def empty_cuda_cache(_):
|
def empty_cuda_cache(_):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
@ -35,6 +34,14 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# if train_sampler is not None:
|
||||||
|
# if not isinstance(train_sampler, DistributedSampler):
|
||||||
|
# raise TypeError("Train sampler should be torch DistributedSampler and have `set_epoch` method")
|
||||||
|
#
|
||||||
|
# @trainer.on(Events.EPOCH_STARTED)
|
||||||
|
# def distrib_set_epoch(engine):
|
||||||
|
# train_sampler.set_epoch(engine.state.epoch - 1)
|
||||||
|
|
||||||
@trainer.on(Events.STARTED)
|
@trainer.on(Events.STARTED)
|
||||||
@idist.one_rank_only()
|
@idist.one_rank_only()
|
||||||
def print_dataloader_size(engine):
|
def print_dataloader_size(engine):
|
||||||
|
|||||||
@ -1,4 +1,38 @@
|
|||||||
import torchvision.utils
|
import torchvision.utils
|
||||||
|
from matplotlib.pyplot import get_cmap
|
||||||
|
import torch
|
||||||
|
import warnings
|
||||||
|
from torch.nn.functional import interpolate
|
||||||
|
|
||||||
|
|
||||||
|
def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
|
||||||
|
"""
|
||||||
|
|
||||||
|
:param images: B x H x W
|
||||||
|
:param attentions: B x Ha x Wa
|
||||||
|
:param cmap_name:
|
||||||
|
:param alpha:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if attentions.size(0) != images.size(0):
|
||||||
|
warnings.warn(f"attentions: {attentions.size()} and images: {images.size} do not have same batch_size")
|
||||||
|
return images
|
||||||
|
if attentions.size(1) != 1:
|
||||||
|
warnings.warn(f"attentions's channels should be 1 but got {attentions.size(1)}")
|
||||||
|
return images
|
||||||
|
|
||||||
|
min_attentions = attentions.view(attentions.size(0), -1).min(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
|
||||||
|
attentions -= min_attentions
|
||||||
|
attentions /= attentions.view(attentions.size(0), -1).max(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
|
||||||
|
|
||||||
|
if images.size() != attentions.size():
|
||||||
|
attentions = interpolate(attentions, images.size()[-2:])
|
||||||
|
colored_attentions = torch.zeros_like(images)
|
||||||
|
cmap = get_cmap(cmap_name)
|
||||||
|
for i, at in enumerate(attentions):
|
||||||
|
ca = cmap(at[0].cpu().numpy())[:, :, :3]
|
||||||
|
colored_attentions[i] = torch.from_numpy(ca).permute(2, 0, 1).view(colored_attentions[i].size())
|
||||||
|
return images * alpha + colored_attentions * (1 - alpha)
|
||||||
|
|
||||||
|
|
||||||
def make_2d_grid(tensors, padding=0, normalize=True, range=None, scale_each=False, pad_value=0):
|
def make_2d_grid(tensors, padding=0, normalize=True, range=None, scale_each=False, pad_value=0):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user