19 lines
412 B
Python
Executable File
19 lines
412 B
Python
Executable File
import torch
|
|
|
|
|
|
def euclidean_dist(x, y):
    """Compute pairwise *squared* Euclidean distances between two point sets.

    Despite the name, no square root is taken: each output entry is
    ``sum((x_i - y_j) ** 2)`` over the last (feature) dimension.

    Args:
        x: Tensor of shape ``(..., N, D)``, e.g. ``B x N x D``.
        y: Tensor of shape ``(..., M, D)``, e.g. ``B x M x D``. Leading
            (batch) dimensions must be broadcast-compatible with those of ``x``.

    Returns:
        Tensor of shape ``(..., N, M)`` where ``out[..., i, j]`` is the squared
        distance between ``x[..., i, :]`` and ``y[..., j, :]``.

    Raises:
        ValueError: if ``x`` and ``y`` have different feature sizes ``D``.
    """
    if x.size(-1) != y.size(-1):
        raise ValueError(
            f"feature dimension mismatch: x has {x.size(-1)}, y has {y.size(-1)}"
        )

    # (..., N, 1, D) - (..., 1, M, D) broadcasts to (..., N, M, D). Broadcasting
    # replaces an explicit expand() that hard-coded 3-D input and x's batch size.
    diff = x.unsqueeze(-2) - y.unsqueeze(-3)
    return diff.pow(2).sum(-1)