From 48aadf0c3103a87ce8e0fb0b5131d2ecd57d1419 Mon Sep 17 00:00:00 2001 From: Ray Wong Date: Mon, 6 Jul 2020 13:52:39 +0800 Subject: [PATCH] vectorify --- test.py | 47 ++++++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/test.py b/test.py index 9a45c8d..c940296 100755 --- a/test.py +++ b/test.py @@ -16,18 +16,18 @@ def euclidean_dist(x, y): """ Compute euclidean distance between two tensors """ - # x: N x D - # y: M x D - n = x.size(0) - m = y.size(0) - d = x.size(1) - if d != y.size(1): + # x: B x N x D + # y: B x M x D + n = x.size(-2) + m = y.size(-2) + d = x.size(-1) + if d != y.size(-1): raise Exception - x = x.unsqueeze(1).expand(n, m, d) - y = y.unsqueeze(0).expand(n, m, d) + x = x.unsqueeze(2).expand(x.size(0), n, m, d) # B x N x M x D + y = y.unsqueeze(1).expand(x.size(0), n, m, d) - return torch.pow(x - y, 2).sum(2) + return torch.pow(x - y, 2).sum(-1) def evaluate(query, target, support): @@ -37,12 +37,13 @@ def evaluate(query, target, support): :param support: B x N x K x D vector :return: """ - K = support.size(1) - prototypes = support.mean(1) - distance = euclidean_dist(query, prototypes) - indices = distance.argmin(1) - y_hat = torch.tensor([target[i*K-1] for i in indices]) - return torch.eq(target.to("cpu"), y_hat).float().mean() + K = support.size(-2) + prototypes = support.mean(-2) # B x N x D + distance = euclidean_dist(query, prototypes) # B x NK x N + print(distance.shape) + indices = distance.argmin(-1) # B x NK + print(indices, target) + return torch.eq(target, indices).float().mean() class Flatten(nn.Module): @@ -105,6 +106,7 @@ def test(): N = torch.randint(5, 10, (1,)).tolist()[0] K = torch.randint(1, 10, (1,)).tolist()[0] + batch_size = 2 episodic_dataset = dataset.EpisodicDataset( origin_dataset, # 抽取数据集 N, # N @@ -113,17 +115,20 @@ def test(): ) print(episodic_dataset) - data_loader = DataLoader(episodic_dataset, batch_size=2) + data_loader = DataLoader(episodic_dataset, batch_size=batch_size) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") extractor = make_extractor() accs = [] for item in data_loader: - item = convert_tensor(item, device) - query_batch = [extractor(images) for images in item["query"]] - support_batch = [torch.stack(torch.split(extractor(images), K)) for images in item["support"]] - for i in range(len(query_batch)): - accs.append(evaluate(query_batch[i], item["target"][i], support_batch[i])) + # item["query"]: B x NK x 3 x W x H + # item["support"]: B x NK x 3 x W x H + # item["target"]: B x NK + print(item["support"].shape, item["target"].shape) + query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N*K, -1) + support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1) + + accs.append(evaluate(query_batch, item["target"], support_batch)) print(torch.tensor(accs).mean().item())