vectorify
parent 0c58965641, commit 48aadf0c31
test.py
@@ -16,18 +16,18 @@ def euclidean_dist(x, y):
     """
     Compute euclidean distance between two tensors
     """
-    # x: N x D
-    # y: M x D
-    n = x.size(0)
-    m = y.size(0)
-    d = x.size(1)
-    if d != y.size(1):
+    # x: B x N x D
+    # y: B x M x D
+    n = x.size(-2)
+    m = y.size(-2)
+    d = x.size(-1)
+    if d != y.size(-1):
         raise Exception
 
-    x = x.unsqueeze(1).expand(n, m, d)
-    y = y.unsqueeze(0).expand(n, m, d)
+    x = x.unsqueeze(2).expand(x.size(0), n, m, d) # B x N x M x D
+    y = y.unsqueeze(1).expand(x.size(0), n, m, d)
 
-    return torch.pow(x - y, 2).sum(2)
+    return torch.pow(x - y, 2).sum(-1)
 
 
 def evaluate(query, target, support):
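For context, this hunk turns the N x D / M x D pairwise distance into a batched B x N x D / B x M x D one. Below is a minimal sanity-check sketch (plain PyTorch, made-up shapes, not part of the repo) showing that the expand-and-sum formulation matches torch.cdist squared:

import torch

# Made-up sizes: B episodes, N query points, M prototypes, D features.
B, N, M, D = 2, 3, 4, 5
x = torch.randn(B, N, D)
y = torch.randn(B, M, D)

# Same recipe as the batched euclidean_dist: expand both tensors to
# B x N x M x D and sum squared differences over the feature axis.
xe = x.unsqueeze(2).expand(B, N, M, D)
ye = y.unsqueeze(1).expand(B, N, M, D)
dist = torch.pow(xe - ye, 2).sum(-1)  # B x N x M

# torch.cdist gives the (non-squared) batched Euclidean distance directly.
print(torch.allclose(dist, torch.cdist(x, y) ** 2, atol=1e-5))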
@@ -37,12 +37,13 @@ def evaluate(query, target, support):
     :param support: B x N x K x D vector
     :return:
     """
-    K = support.size(1)
-    prototypes = support.mean(1)
-    distance = euclidean_dist(query, prototypes)
-    indices = distance.argmin(1)
-    y_hat = torch.tensor([target[i*K-1] for i in indices])
-    return torch.eq(target.to("cpu"), y_hat).float().mean()
+    K = support.size(-2)
+    prototypes = support.mean(-2) # B x N x D
+    distance = euclidean_dist(query, prototypes) # B x NK x N
+    print(distance.shape)
+    indices = distance.argmin(-1) # B x NK
+    print(indices, target)
+    return torch.eq(target, indices).float().mean()
 
 
 class Flatten(nn.Module):
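Shape-wise, the rewritten evaluate now scores whole batches of episodes at once: class prototypes are the mean over the K shots, and each query is assigned the nearest prototype. A toy walk-through with invented numbers (using torch.cdist squared in place of the project's euclidean_dist, and assuming target holds class indices aligned with the support ordering):

import torch

B, N, K, D = 2, 4, 3, 8                          # episodes, classes, shots, feature dim
support = torch.randn(B, N, K, D)                # B x N x K x D
query = torch.randn(B, N * K, D)                 # B x NK x D
target = torch.randint(0, N, (B, N * K))         # B x NK class indices

prototypes = support.mean(-2)                    # B x N x D, one prototype per class
distance = torch.cdist(query, prototypes) ** 2   # B x NK x N, squared Euclidean
indices = distance.argmin(-1)                    # B x NK predicted class per query
print(torch.eq(target, indices).float().mean().item())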
@@ -105,6 +106,7 @@ def test():
 
     N = torch.randint(5, 10, (1,)).tolist()[0]
     K = torch.randint(1, 10, (1,)).tolist()[0]
+    batch_size = 2
     episodic_dataset = dataset.EpisodicDataset(
         origin_dataset, # dataset to sample from
         N, # N
@@ -113,17 +115,20 @@ def test():
     )
     print(episodic_dataset)
 
-    data_loader = DataLoader(episodic_dataset, batch_size=2)
+    data_loader = DataLoader(episodic_dataset, batch_size=batch_size)
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     extractor = make_extractor()
     accs = []
     for item in data_loader:
 
         item = convert_tensor(item, device)
-        query_batch = [extractor(images) for images in item["query"]]
-        support_batch = [torch.stack(torch.split(extractor(images), K)) for images in item["support"]]
-        for i in range(len(query_batch)):
-            accs.append(evaluate(query_batch[i], item["target"][i], support_batch[i]))
+        # item["query"]: B x NK x 3 x W x H
+        # item["support"]: B x NK x 3 x W x H
+        # item["target"]: B x NK
+        print(item["support"].shape, item["target"].shape)
+        query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N*K, -1)
+        support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
 
+        accs.append(evaluate(query_batch, item["target"], support_batch))
     print(torch.tensor(accs).mean().item())
 
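This last hunk replaces per-episode feature extraction with one fused pass: fold the episode dimension into the batch dimension, run the extractor once, then view the result back into B x NK x D (queries) or B x N x K x D (support). A self-contained shape sketch with a stand-in extractor (nothing below comes from the repo's make_extractor):

import torch
import torch.nn as nn

# Hypothetical stand-in for make_extractor(): any module that maps C x H x W
# images to a flat feature vector works for this shape exercise.
extractor = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())

B, N, K, C, H, W = 2, 5, 3, 3, 28, 28
support = torch.randn(B, N * K, C, H, W)      # B x NK x 3 x H x W, like item["support"]

# Fold the episode dimension into the batch dimension, extract features once,
# then unfold back to B x N x K x D -- the same view/extract/view pattern as the commit.
flat = support.view(-1, *support.shape[-3:])  # (B*NK) x 3 x H x W
features = extractor(flat)                    # (B*NK) x D
support_batch = features.view(B, N, K, -1)    # B x N x K x D
print(support_batch.shape)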