v1
This commit is contained in:
parent
48aadf0c31
commit
a89f9226e8
44
test.py
44
test.py
@ -4,6 +4,7 @@ import torchvision
|
|||||||
from data import dataset
|
from data import dataset
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from ignite.utils import convert_tensor
|
from ignite.utils import convert_tensor
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
def setup_seed(seed):
|
def setup_seed(seed):
|
||||||
@ -37,12 +38,9 @@ def evaluate(query, target, support):
|
|||||||
:param support: B x N x K x D vector
|
:param support: B x N x K x D vector
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
K = support.size(-2)
|
|
||||||
prototypes = support.mean(-2) # B x N x D
|
prototypes = support.mean(-2) # B x N x D
|
||||||
distance = euclidean_dist(query, prototypes) # B x NK x N
|
distance = euclidean_dist(query, prototypes) # B x NK x N
|
||||||
print(distance.shape)
|
|
||||||
indices = distance.argmin(-1) # B x NK
|
indices = distance.argmin(-1) # B x NK
|
||||||
print(indices, target)
|
|
||||||
return torch.eq(target, indices).float().mean()
|
return torch.eq(target, indices).float().mean()
|
||||||
|
|
||||||
|
|
||||||
@ -54,29 +52,29 @@ class Flatten(nn.Module):
|
|||||||
return x.view(x.size(0), -1)
|
return x.view(x.size(0), -1)
|
||||||
|
|
||||||
|
|
||||||
# def make_extractor():
|
|
||||||
# resnet50 = torchvision.models.resnet50(pretrained=True)
|
|
||||||
# resnet50.to(torch.device("cuda"))
|
|
||||||
# resnet50.fc = torch.nn.Identity()
|
|
||||||
# resnet50.eval()
|
|
||||||
#
|
|
||||||
# def extract(images):
|
|
||||||
# with torch.no_grad():
|
|
||||||
# return resnet50(images)
|
|
||||||
# return extract
|
|
||||||
|
|
||||||
|
|
||||||
def make_extractor():
|
def make_extractor():
|
||||||
model = resnet18()
|
resnet50 = torchvision.models.resnet50(pretrained=True)
|
||||||
model.to(torch.device("cuda"))
|
resnet50.to(torch.device("cuda"))
|
||||||
model.eval()
|
resnet50.fc = torch.nn.Identity()
|
||||||
|
resnet50.eval()
|
||||||
|
|
||||||
def extract(images):
|
def extract(images):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
return model(images)
|
return resnet50(images)
|
||||||
return extract
|
return extract
|
||||||
|
|
||||||
|
|
||||||
|
# def make_extractor():
|
||||||
|
# model = resnet18()
|
||||||
|
# model.to(torch.device("cuda"))
|
||||||
|
# model.eval()
|
||||||
|
#
|
||||||
|
# def extract(images):
|
||||||
|
# with torch.no_grad():
|
||||||
|
# return model(images)
|
||||||
|
# return extract
|
||||||
|
|
||||||
|
|
||||||
def resnet18(model_path="ResNet18Official.pth"):
|
def resnet18(model_path="ResNet18Official.pth"):
|
||||||
"""Constructs a ResNet-18 model.
|
"""Constructs a ResNet-18 model.
|
||||||
Args:
|
Args:
|
||||||
@ -115,20 +113,22 @@ def test():
|
|||||||
)
|
)
|
||||||
print(episodic_dataset)
|
print(episodic_dataset)
|
||||||
|
|
||||||
data_loader = DataLoader(episodic_dataset, batch_size=batch_size)
|
data_loader = DataLoader(episodic_dataset, batch_size=batch_size, pin_memory=True)
|
||||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
extractor = make_extractor()
|
extractor = make_extractor()
|
||||||
accs = []
|
accs = []
|
||||||
|
st = time.time()
|
||||||
for item in data_loader:
|
for item in data_loader:
|
||||||
item = convert_tensor(item, device)
|
item = convert_tensor(item, device, non_blocking=True)
|
||||||
# item["query"]: B x NK x 3 x W x H
|
# item["query"]: B x NK x 3 x W x H
|
||||||
# item["support"]: B x NK x 3 x W x H
|
# item["support"]: B x NK x 3 x W x H
|
||||||
# item["target"]: B x NK
|
# item["target"]: B x NK
|
||||||
print(item["support"].shape, item["target"].shape)
|
|
||||||
query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N*K, -1)
|
query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N*K, -1)
|
||||||
support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
|
support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
|
||||||
|
|
||||||
accs.append(evaluate(query_batch, item["target"], support_batch))
|
accs.append(evaluate(query_batch, item["target"], support_batch))
|
||||||
|
print("time: ", time.time()-st)
|
||||||
|
st = time.time()
|
||||||
print(torch.tensor(accs).mean().item())
|
print(torch.tensor(accs).mean().item())
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user