From a89f9226e834b9a5535042e9ba3aeab4564cc095 Mon Sep 17 00:00:00 2001 From: Ray Wong Date: Mon, 6 Jul 2020 15:13:46 +0800 Subject: [PATCH] v1 --- test.py | 44 ++++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/test.py b/test.py index c940296..5831a40 100755 --- a/test.py +++ b/test.py @@ -4,6 +4,7 @@ import torchvision from data import dataset import torch.nn as nn from ignite.utils import convert_tensor +import time def setup_seed(seed): @@ -37,12 +38,9 @@ def evaluate(query, target, support): :param support: B x N x K x D vector :return: """ - K = support.size(-2) prototypes = support.mean(-2) # B x N x D distance = euclidean_dist(query, prototypes) # B x NK x N - print(distance.shape) indices = distance.argmin(-1) # B x NK - print(indices, target) return torch.eq(target, indices).float().mean() @@ -54,29 +52,29 @@ class Flatten(nn.Module): return x.view(x.size(0), -1) -# def make_extractor(): -# resnet50 = torchvision.models.resnet50(pretrained=True) -# resnet50.to(torch.device("cuda")) -# resnet50.fc = torch.nn.Identity() -# resnet50.eval() -# -# def extract(images): -# with torch.no_grad(): -# return resnet50(images) -# return extract - - def make_extractor(): - model = resnet18() - model.to(torch.device("cuda")) - model.eval() + resnet50 = torchvision.models.resnet50(pretrained=True) + resnet50.to(torch.device("cuda")) + resnet50.fc = torch.nn.Identity() + resnet50.eval() def extract(images): with torch.no_grad(): - return model(images) + return resnet50(images) return extract +# def make_extractor(): +# model = resnet18() +# model.to(torch.device("cuda")) +# model.eval() +# +# def extract(images): +# with torch.no_grad(): +# return model(images) +# return extract + + def resnet18(model_path="ResNet18Official.pth"): """Constructs a ResNet-18 model. Args: @@ -115,20 +113,22 @@ def test(): ) print(episodic_dataset) - data_loader = DataLoader(episodic_dataset, batch_size=batch_size) + data_loader = DataLoader(episodic_dataset, batch_size=batch_size, pin_memory=True) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") extractor = make_extractor() accs = [] + st = time.time() for item in data_loader: - item = convert_tensor(item, device) + item = convert_tensor(item, device, non_blocking=True) # item["query"]: B x NK x 3 x W x H # item["support"]: B x NK x 3 x W x H # item["target"]: B x NK - print(item["support"].shape, item["target"].shape) query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N*K, -1) support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1) accs.append(evaluate(query_batch, item["target"], support_batch)) + print("time: ", time.time()-st) + st = time.time() print(torch.tensor(accs).mean().item())