raycv/loss/fewshot/prototypical.py
2020-08-21 16:14:30 +08:00

53 lines
2.1 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
class PrototypicalLoss(nn.Module):
    """Prototypical-network loss for episodic few-shot classification.

    Each class prototype is the mean of that class's support embeddings.
    Queries are scored by the negative squared Euclidean distance to every
    prototype, and the loss is the mean negative log-probability of the
    true class under a softmax over those scores.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _prototype_distances(query, support):
        """Return squared distances from each query to each class prototype.

        :param query: Tensor - batch_size x N_class*N_query x D
        :param support: Tensor - batch_size x N_class x N_support x D
        :return: Tensor - batch_size x N_class*N_query x N_class
        """
        prototypes = support.mean(-2)  # batch_size x N_class x D
        return PrototypicalLoss.euclidean_dist(query, prototypes)

    @staticmethod
    def _accuracy(distance, target):
        """Return the fraction of queries whose nearest prototype is the target."""
        predicted = distance.argmin(-1)  # index of the closest prototype
        return torch.eq(target, predicted).float().mean().item()

    @staticmethod
    def acc(query, target, support):
        """Compute nearest-prototype classification accuracy.

        :param query: Tensor - batch_size x N_class*N_query x D
        :param target: Tensor - batch_size x N_class*N_query, class ids in [0, N_class)
        :param support: Tensor - batch_size x N_class x N_support x D, ordered by class id
        :return: accuracy as a Python float
        """
        distance = PrototypicalLoss._prototype_distances(query, support)
        return PrototypicalLoss._accuracy(distance, target)

    @staticmethod
    def euclidean_dist(x, y):
        """Return pairwise squared Euclidean distances between two point sets.

        :param x: Tensor - B x N x D
        :param y: Tensor - B x M x D
        :return: Tensor - B x N x M, entry [b, i, j] is ||x[b, i] - y[b, j]||^2
        """
        assert x.size(-1) == y.size(-1) and x.size(0) == y.size(0), \
            "x and y must share batch size and feature dimension"
        # Broadcasting (B x N x 1 x D) - (B x 1 x M x D) yields B x N x M x D.
        diff = x.unsqueeze(2) - y.unsqueeze(1)
        return diff.pow(2).sum(-1)

    def forward(self, query, target, support):
        """
        calculate prototypical loss
        :param query: Tensor - batch_size x N_class*N_query x D
        :param target: Tensor - batch_size x N_class*N_query, target id set, value must in [0, N_class)
        :param support: Tensor - batch_size x N_class x N_support x D, must be ordered by class id
        :return: loss item and accuracy
        """
        distance = self._prototype_distances(query, support)
        acc = self._accuracy(distance, target)

        log_p_y = F.log_softmax(-distance, dim=-1)
        # Select the log-probability of each query's true class straight from
        # `target`. This keeps the loss consistent with `acc`, does not require
        # the queries to be sorted by class, and stays on the inputs' device.
        true_class = target.long().unsqueeze(-1)  # batch_size x N_class*N_query x 1
        loss = -log_p_y.gather(2, true_class).mean()
        return loss, acc