raycv/util/image.py

import warnings

import torch
import torchvision.utils
from matplotlib.pyplot import get_cmap
from torch.nn.functional import interpolate


def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
"""
:param images: B x H x W
:param attentions: B x Ha x Wa
:param cmap_name:
:param alpha:
:return:
"""
    if attentions.size(0) != images.size(0):
        warnings.warn(f"attentions: {attentions.size()} and images: {images.size()} do not have the same batch size")
        return images
    if attentions.size(1) != 1:
        warnings.warn(f"attentions should have 1 channel but got {attentions.size(1)}")
        return images
    # Normalize each attention map to [0, 1] per sample, without mutating the input.
    b = attentions.size(0)
    att_min = attentions.view(b, -1).min(1, keepdim=True)[0].view(b, 1, 1, 1)
    attentions = attentions - att_min
    att_max = attentions.view(b, -1).max(1, keepdim=True)[0].view(b, 1, 1, 1)
    attentions = attentions / att_max.clamp_min(1e-12)  # guard against all-constant maps
    # Upsample the attention maps to the image resolution if they differ.
    if images.size() != attentions.size():
        attentions = interpolate(attentions, images.size()[-2:])
    # Colorize each normalized map with the requested matplotlib colormap (RGBA -> RGB).
    colored_attentions = torch.zeros_like(images)
    cmap = get_cmap(cmap_name)
    for i, at in enumerate(attentions):
        ca = cmap(at[0].cpu().numpy())[:, :, :3]  # H x W x 3, values in [0, 1]
        colored_attentions[i] = torch.from_numpy(ca).float().permute(2, 0, 1)
    # Blend: `alpha` weights the original image, the remainder the colored attention.
    return images * alpha + colored_attentions * (1 - alpha)
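
# A minimal usage sketch (the shapes and tensors below are illustrative
# assumptions, not fixtures from this repository):
#
#     images = torch.rand(4, 3, 224, 224)    # B x 3 x H x W, RGB in [0, 1]
#     attentions = torch.rand(4, 1, 14, 14)  # B x 1 x Ha x Wa, e.g. CAM-style maps
#     fused = fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5)
#     # `fused` has the same shape as `images`; high-attention regions appear warm.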


def make_2d_grid(tensors, padding=0, normalize=True, value_range=None, scale_each=False, pad_value=0):
    """Lay out a list of image batches as a 2-D grid, one column per batch."""
    # Merge the images of each batch in the `y` direction first (one column per batch).
    # `value_range` is the current torchvision keyword (the old `range` name shadowed the builtin).
    grids = [torchvision.utils.make_grid(img_batch, padding=padding, nrow=1, normalize=normalize,
                                         value_range=value_range, scale_each=scale_each, pad_value=pad_value)
             for img_batch in tensors]
    # Then merge the per-batch columns in the `x` direction.
    return torchvision.utils.make_grid(grids, padding=0, nrow=len(grids))
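

if __name__ == "__main__":
    # Minimal smoke test, a sketch under assumed shapes (not part of the original
    # module): overlay random attention maps, then tile originals next to overlays.
    imgs = torch.rand(2, 3, 32, 32)  # two RGB images
    atts = torch.rand(2, 1, 8, 8)    # matching low-resolution attention maps
    fused = fuse_attention_map(imgs, atts)
    grid = make_2d_grid([imgs, fused])  # left column: originals, right column: overlays
    print(fused.shape, grid.shape)  # torch.Size([2, 3, 32, 32]) torch.Size([3, 64, 64])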