57 lines
2.3 KiB
Python
57 lines
2.3 KiB
Python
import torchvision.utils
|
|
import torch
|
|
import warnings
|
|
import numpy as np
|
|
import cv2
|
|
|
|
|
|
def attention_colored_map(attentions, size=None):
    """Colorize a batch of single-channel attention maps with the JET colormap.

    Each map is min-max normalized independently to [0, 1], quantized to
    uint8, optionally resized, and colored with ``cv2.applyColorMap``.
    The input tensor is NOT modified.

    :param attentions: tensor of shape (B, 1, H, W).
    :param size: optional (height, width) of the output maps. When given and
        different from (H, W), each map is resized to it with ``cv2.resize``.
    :return: float tensor of shape (B, 3, height, width) with values in
        [0, 1], on the same device as ``attentions``. The channel order is the
        one produced by ``cv2.applyColorMap``. NOTE(review): OpenCV colormaps
        are BGR; confirm callers blending with RGB images expect that.
    """
    assert attentions.dim() == 4 and attentions.size(1) == 1

    device = attentions.device
    batch_size = attentions.size(0)

    # Normalize a detached float copy. In-place `-=` / `/=` would overwrite
    # the caller's tensor (and raise for leaf tensors requiring grad, or for
    # integer tensors). `reshape` also tolerates non-contiguous input.
    flat = attentions.detach().float().reshape(batch_size, -1)
    flat = flat - flat.min(1, keepdim=True)[0]
    peak = flat.max(1, keepdim=True)[0]
    # A constant map has peak == 0; dividing by it would yield NaN, whose
    # uint8 cast is undefined. Divide by 1 instead so such maps become zeros.
    peak = torch.where(peak > 0, peak, torch.ones_like(peak))
    flat = flat / peak

    quantized = (flat.reshape(attentions.shape).cpu().numpy() * 255).astype(np.uint8)

    need_resize = False
    if size is not None and tuple(quantized.shape[-2:]) != tuple(size):
        assert len(size) == 2, "for interpolate, size must be (height, width), have two dim"
        need_resize = True
        # cv2.resize takes dsize as (width, height), whereas `size` is
        # (height, width) -- fuse_attention_map passes images.size()[-2:].
        # Passing it unswapped would transpose non-square outputs.
        dsize = (int(size[1]), int(size[0]))

    colored = []
    for amap in quantized[:, 0]:  # each amap: (H, W) uint8
        if need_resize:
            amap = cv2.resize(amap, dsize)  # -> (height, width)
        colored.append(cv2.applyColorMap(amap, cv2.COLORMAP_JET))  # (height, width, 3)

    stacked = np.stack(colored)  # (B, height, width, 3)
    return torch.from_numpy(stacked).permute(0, 3, 1, 2).contiguous().to(device).float() / 255
|
|
|
|
|
|
def fuse_attention_map(images, attentions, alpha=0.5):
    """Overlay JET-colored attention maps on a batch of images.

    :param images: tensor of shape (B, C, H, W). NOTE(review): C is
        presumably 3 (or 1, which broadcasts against the 3-channel colored
        maps), and values are presumably in [0, 1] to match them -- confirm
        with callers.
    :param attentions: tensor of shape (B, 1, Ha, Wa). It is passed to
        ``attention_colored_map`` with target size ``images.size()[-2:]``.
    :param alpha: blending weight of ``images``; the colored attention is
        weighted by ``1 - alpha``.
    :return: ``images * alpha + colored * (1 - alpha)``, or ``images``
        unchanged (after emitting a warning) when the batch sizes differ or
        ``attentions`` does not have exactly one channel.
    """
    if attentions.size(0) != images.size(0):
        # Call images.size(): the bare attribute would format as
        # "<built-in method size ...>" instead of the shape.
        warnings.warn(f"attentions: {attentions.size()} and images: {images.size()} do not have same batch_size")
        return images

    if attentions.size(1) != 1:
        warnings.warn(f"attentions's channels should be 1 but got {attentions.size(1)}")
        return images

    colored_attentions = attention_colored_map(attentions, images.size()[-2:])
    return images * alpha + colored_attentions * (1 - alpha)
|
|
|
|
|
|
def make_2d_grid(tensors, padding=0, normalize=True, range=None, scale_each=False, pad_value=0):
    """Tile an iterable of image batches into a single 2-D grid image.

    The images of each batch are stacked vertically (``nrow=1``) into one
    column; the columns are then placed side by side with no padding.

    :param tensors: iterable of image batches accepted by
        ``torchvision.utils.make_grid``. NOTE(review): the columns must have
        identical shapes to be joined, so presumably every batch has the same
        (B, C, H, W) -- confirm with callers.
    :param padding: padding between images inside a column.
    :param normalize: forwarded to ``make_grid`` for each column.
    :param range: optional (min, max) used for normalization. It is forwarded
        under whichever keyword the installed torchvision accepts
        (``value_range`` in newer versions, ``range`` in older ones). The name
        shadows the builtin but is kept for caller compatibility.
    :param scale_each: forwarded to ``make_grid`` for each column.
    :param pad_value: forwarded to ``make_grid`` for each column.
    :return: a single image tensor holding the whole grid.
    """
    import inspect

    value_range = range  # clearer local name for the builtin-shadowing param

    column_kwargs = dict(
        padding=padding,
        nrow=1,
        normalize=normalize,
        scale_each=scale_each,
        pad_value=pad_value,
    )
    if value_range is not None:
        # torchvision renamed make_grid's `range` keyword to `value_range` and
        # later removed the old name, so always passing `range=` raises
        # TypeError on current versions. Use the name this install accepts.
        try:
            accepted = inspect.signature(torchvision.utils.make_grid).parameters
        except (TypeError, ValueError):
            accepted = {}
        range_kw = "range" if "range" in accepted else "value_range"
        column_kwargs[range_kw] = value_range

    # Merge the images of each batch in the `y` direction first ...
    columns = [torchvision.utils.make_grid(batch, **column_kwargs) for batch in tensors]
    # ... then merge the resulting columns in the `x` direction.
    return torchvision.utils.make_grid(columns, padding=0, nrow=len(columns))
|