import warnings

import torch
import torchvision.utils
from matplotlib.pyplot import get_cmap
from torch.nn.functional import interpolate


def attention_colored_map(attentions, size=None, cmap_name="jet"):
    """
    Min-max normalize each attention map to [0, 1] and colorize it with a
    matplotlib colormap.

    :param attentions: B x 1 x Ha x Wa
    :param size: optional (h, w) target size for bilinear upsampling
    :param cmap_name: matplotlib colormap name
    :return: B x 3 x H x W RGB tensor (on CPU)
    """
    assert attentions.dim() == 4 and attentions.size(1) == 1
    # Per-sample min-max normalization, done out of place so the caller's
    # tensor is not mutated.
    flat = attentions.view(attentions.size(0), -1)
    min_attentions = flat.min(1, keepdim=True)[0].view(-1, 1, 1, 1)
    max_attentions = flat.max(1, keepdim=True)[0].view(-1, 1, 1, 1)
    # Clamp the denominator to avoid division by zero on constant maps.
    attentions = (attentions - min_attentions) / (max_attentions - min_attentions).clamp(min=1e-8)
    if size is not None and attentions.size()[-2:] != tuple(size):
        assert len(size) == 2, "for interpolate, size must be (h, w), i.e. have two dims"
        attentions = interpolate(attentions, size, mode="bilinear", align_corners=False)
    cmap = get_cmap(cmap_name)
    # The colormap returns B x H x W x 4 (RGBA); drop the alpha channel.
    ca = cmap(attentions.squeeze(1).detach().cpu().numpy())[:, :, :, :3]
    return torch.from_numpy(ca).permute(0, 3, 1, 2).contiguous()


def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
    """
    Overlay colorized attention maps on a batch of images.

    :param images: B x 3 x H x W, values in [0, 1]
    :param attentions: B x 1 x Ha x Wa
    :param cmap_name: matplotlib colormap name
    :param alpha: blending weight of the images; the overlay gets 1 - alpha
    :return: B x 3 x H x W blended tensor, or `images` unchanged on a shape mismatch
    """
    if attentions.size(0) != images.size(0):
        warnings.warn(f"attentions: {attentions.size()} and images: {images.size()} "
                      f"do not have the same batch size")
        return images
    if attentions.size(1) != 1:
        warnings.warn(f"attentions should have 1 channel but got {attentions.size(1)}")
        return images
    # `.to(images)` matches both the device and the dtype of the images.
    colored_attentions = attention_colored_map(attentions, images.size()[-2:], cmap_name).to(images)
    return images * alpha + colored_attentions * (1 - alpha)


def make_2d_grid(tensors, padding=0, normalize=True, range=None, scale_each=False, pad_value=0):
    # Note: `range` shadows the builtin to match the keyword of older
    # torchvision versions; newer releases renamed it to `value_range`.
    # Merge the images of each batch in the `y` direction first
    # (nrow=1 makes one column per batch).
    grids = [torchvision.utils.make_grid(img_batch, padding=padding, nrow=1, normalize=normalize,
                                         range=range, scale_each=scale_each, pad_value=pad_value)
             for img_batch in tensors]
    # Merge the columns in the `x` direction (a single row of len(grids) columns).
    return torchvision.utils.make_grid(grids, padding=0, nrow=len(grids))
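

# --- Usage sketch (illustrative assumptions: random tensors stand in for
# real images and attention maps, and the output filename is arbitrary). ---
if __name__ == "__main__":
    # A batch of 4 RGB images in [0, 1] with matching low-resolution,
    # single-channel attention maps.
    images = torch.rand(4, 3, 64, 64)
    attentions = torch.rand(4, 1, 16, 16)

    # Upsample, colorize, and blend the attention maps onto the images.
    fused = fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5)
    print(fused.shape)  # torch.Size([4, 3, 64, 64])

    # Lay out the original and fused batches as two side-by-side columns.
    grid = make_2d_grid([images, fused], padding=2)
    torchvision.utils.save_image(grid, "attention_grid.png")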