53 lines
2.1 KiB
Python
53 lines
2.1 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class PrototypicalLoss(nn.Module):
    """Prototypical-network loss for batched few-shot episodes.

    Each class prototype is the mean of that class's support embeddings.
    Queries are scored by the negative squared Euclidean distance to every
    prototype. The loss is the mean negative log-probability of each query's
    target class.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _distance_to_prototypes(query, support):
        """Return squared distances from every query to every class prototype.

        :param query: Tensor - batch_size x N_total_query x D
        :param support: Tensor - batch_size x N_class x N_support x D
        :return: Tensor - batch_size x N_total_query x N_class
        """
        prototypes = support.mean(-2)  # batch_size x N_class x D
        return PrototypicalLoss.euclidean_dist(query, prototypes)

    @staticmethod
    def acc(query, target, support):
        """Return the fraction of queries whose nearest prototype is their target.

        :param query: Tensor - batch_size x N_class*N_query x D
        :param target: Tensor - batch_size x N_class*N_query, class id per query
        :param support: Tensor - batch_size x N_class x N_support x D, ordered
            by class id
        :return: float accuracy in [0, 1]
        """
        distance = PrototypicalLoss._distance_to_prototypes(query, support)
        predictions = distance.argmin(-1)  # index of the nearest prototype
        return torch.eq(target, predictions).float().mean().item()

    @staticmethod
    def euclidean_dist(x, y):
        """Return pairwise squared Euclidean distances between two batched sets.

        No square root is taken; the squared distance is used directly as
        the (negated) logit.

        :param x: Tensor - B x N x D
        :param y: Tensor - B x M x D
        :return: Tensor - B x N x M, out[b, i, j] = ||x[b, i] - y[b, j]||^2
        """
        assert x.size(-1) == y.size(-1) and x.size(0) == y.size(0)
        # B x N x 1 x D broadcast against B x 1 x M x D -> B x N x M x D.
        diff = x.unsqueeze(2) - y.unsqueeze(1)
        return diff.pow(2).sum(-1)  # B x N x M

    def forward(self, query, target, support):
        """Compute the prototypical loss and accuracy.

        :param query: Tensor - batch_size x N_class*N_query x D
        :param target: Tensor - batch_size x N_class*N_query, class id of each
            query; values must be in [0, N_class)
        :param support: Tensor - batch_size x N_class x N_support x D, must be
            ordered by class id (index k along dim 1 is class k)
        :return: (loss, acc) - scalar loss tensor and Python float accuracy
        """
        distance = self._distance_to_prototypes(query, support)  # B x NQ x N_class

        acc = torch.eq(target, distance.argmin(-1)).float().mean().item()

        log_p_y = F.log_softmax(-distance, dim=-1)
        # Pick each query's true-class log-probability using its own target
        # id. This keeps the loss consistent with `acc` and does not require
        # queries to be grouped by class. The index also lives on target's
        # device rather than always on the CPU.
        target_index = target.long().unsqueeze(-1)  # B x NQ x 1
        loss = -log_p_y.gather(-1, target_index).mean()

        return loss, acc
|