import torch
import torch.nn as nn
import torch.nn.functional as F


class PrototypicalLoss(nn.Module):
    def __init__(self):
        super().__init__()

    @staticmethod
    def acc(query, target, support):
        prototypes = support.mean(-2)  # batch_size x N_class x D
        distance = PrototypicalLoss.euclidean_dist(query, prototypes)  # batch_size x N_class*N_query x N_class
        indices = distance.argmin(-1)  # index of the nearest prototype for each query
        acc = torch.eq(target, indices).float().mean().item()
        return acc

    @staticmethod
    def euclidean_dist(x, y):
        # x: B x N x D
        # y: B x M x D
        assert x.size(-1) == y.size(-1) and x.size(0) == y.size(0)
        n = x.size(-2)
        m = y.size(-2)
        d = x.size(-1)
        x = x.unsqueeze(2).expand(x.size(0), n, m, d)  # B x N x M x D
        y = y.unsqueeze(1).expand(x.size(0), n, m, d)  # B x N x M x D
        return torch.pow(x - y, 2).sum(-1)  # B x N x M

    def forward(self, query, target, support):
        """
        Calculate the prototypical loss.
        :param query: Tensor - batch_size x N_class*N_query x D, queries grouped by class id
        :param target: Tensor - batch_size x N_class*N_query, class ids; every value must lie in [0, N_class)
        :param support: Tensor - batch_size x N_class x N_support x D, must be ordered by class id
        :return: loss tensor and accuracy
        """
        prototypes = support.mean(-2)  # batch_size x N_class x D
        distance = self.euclidean_dist(query, prototypes)  # batch_size x N_class*N_query x N_class
        indices = distance.argmin(-1)  # index of the nearest prototype for each query
        acc = torch.eq(target, indices).float().mean().item()

        log_p_y = F.log_softmax(-distance, dim=-1)  # batch_size x N_class*N_query x N_class
        n_class = support.size(1)
        n_query = query.size(1) // n_class
        batch_size = support.size(0)
        # True-class index for each query, assuming queries are grouped by class
        # (the first N_query rows belong to class 0, the next N_query to class 1, ...),
        # replicated across the batch dimension: batch_size x N_class*N_query x 1
        target_log_indices = torch.arange(n_class, device=query.device).view(n_class, 1, 1) \
            .expand(n_class, n_query, 1).reshape(n_class * n_query, 1) \
            .view(1, n_class * n_query, 1).expand(batch_size, n_class * n_query, 1)
        loss = -log_p_y.gather(2, target_log_indices).mean()  # mean negative log-probability of the true class
        return loss, acc
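

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): exercises the loss on random
    # tensors with assumed episode sizes -- batch_size=4, N_class=5, N_support=5, N_query=3, D=64.
    # Queries are assumed to be grouped by class (first N_query rows belong to class 0, etc.),
    # matching the index construction in forward().
    batch_size, n_class, n_support, n_query, dim = 4, 5, 5, 3, 64
    support = torch.randn(batch_size, n_class, n_support, dim, requires_grad=True)
    query = torch.randn(batch_size, n_class * n_query, dim, requires_grad=True)
    target = torch.arange(n_class).repeat_interleave(n_query).unsqueeze(0).expand(batch_size, -1)

    criterion = PrototypicalLoss()
    loss, acc = criterion(query, target, support)
    loss.backward()  # loss is a differentiable scalar; acc is a plain float
    print(f"loss={loss.item():.4f}  acc={acc:.4f}")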