few-shot/loss/prototypical.py
2020-07-23 22:32:28 +08:00

19 lines
412 B
Python
Executable File

import torch
def euclidean_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute pairwise squared Euclidean distances between two point sets.

    Note that no square root is taken: each entry of the result is the sum
    of squared coordinate differences.

    Args:
        x: Tensor of shape ``(..., N, D)``; typically ``B x N x D``.
        y: Tensor of shape ``(..., M, D)``; typically ``B x M x D``. Leading
            (batch) dimensions must be broadcastable against those of ``x``.

    Returns:
        Tensor of shape ``(..., N, M)`` where entry ``[..., i, j]`` is the
        squared distance between ``x[..., i, :]`` and ``y[..., j, :]``.

    Raises:
        ValueError: If either input has fewer than two dimensions, or if the
            feature dimension ``D`` of ``x`` and ``y`` differs.
    """
    if x.dim() < 2 or y.dim() < 2:
        raise ValueError(
            "euclidean_dist expects inputs with at least 2 dimensions, got "
            f"shapes {tuple(x.shape)} and {tuple(y.shape)}"
        )
    if x.size(-1) != y.size(-1):
        raise ValueError(
            "feature dimension mismatch: "
            f"x has D={x.size(-1)} but y has D={y.size(-1)}"
        )

    # Insert singleton axes so the subtraction broadcasts every x-point
    # against every y-point:
    #   x: (..., N, 1, D)
    #   y: (..., 1, M, D)
    # This yields the same values as explicitly expanding both operands to
    # (..., N, M, D), without hard-coding the number of batch dimensions.
    diff = x.unsqueeze(-2) - y.unsqueeze(-3)

    # Sum squared differences over the feature axis -> (..., N, M).
    return torch.pow(diff, 2).sum(-1)