import torch


def euclidean_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute pairwise *squared* Euclidean distances between two point sets.

    Args:
        x: Tensor of shape ``(..., N, D)``.
        y: Tensor of shape ``(..., M, D)``. Its leading (batch) dimensions
            must be broadcast-compatible with those of ``x``.

    Returns:
        Tensor of shape ``(..., N, M)`` where entry ``[..., i, j]`` is
        ``sum_k (x[..., i, k] - y[..., j, k]) ** 2``. No square root is
        taken; the result is the squared distance.

    Raises:
        ValueError: if either input has fewer than 2 dimensions, or if the
            feature sizes ``D`` of ``x`` and ``y`` differ.
    """
    if x.dim() < 2 or y.dim() < 2:
        raise ValueError(
            f"expected tensors with at least 2 dims, got {x.dim()} and {y.dim()}"
        )
    if x.size(-1) != y.size(-1):
        raise ValueError(
            f"feature size mismatch: x has {x.size(-1)}, y has {y.size(-1)}"
        )
    # (..., N, 1, D) - (..., 1, M, D) broadcasts to (..., N, M, D) without
    # hard-coding the number of batch dimensions or requiring equal batch sizes.
    diff = x.unsqueeze(-2) - y.unsqueeze(-3)
    return diff.pow(2).sum(-1)