vectorify

This commit is contained in:
Ray Wong 2020-07-06 13:52:39 +08:00
parent 0c58965641
commit 48aadf0c31

47
test.py
View File

@ -16,18 +16,18 @@ def euclidean_dist(x, y):
""" """
Compute euclidean distance between two tensors Compute euclidean distance between two tensors
""" """
# x: N x D # x: B x N x D
# y: M x D # y: B x M x D
n = x.size(0) n = x.size(-2)
m = y.size(0) m = y.size(-2)
d = x.size(1) d = x.size(-1)
if d != y.size(1): if d != y.size(-1):
raise Exception raise Exception
x = x.unsqueeze(1).expand(n, m, d) x = x.unsqueeze(2).expand(x.size(0), n, m, d) # B x N x M x D
y = y.unsqueeze(0).expand(n, m, d) y = y.unsqueeze(1).expand(x.size(0), n, m, d)
return torch.pow(x - y, 2).sum(2) return torch.pow(x - y, 2).sum(-1)
def evaluate(query, target, support): def evaluate(query, target, support):
@ -37,12 +37,13 @@ def evaluate(query, target, support):
:param support: B x N x K x D vector :param support: B x N x K x D vector
:return: :return:
""" """
K = support.size(1) K = support.size(-2)
prototypes = support.mean(1) prototypes = support.mean(-2) # B x N x D
distance = euclidean_dist(query, prototypes) distance = euclidean_dist(query, prototypes) # B x NK x N
indices = distance.argmin(1) print(distance.shape)
y_hat = torch.tensor([target[i*K-1] for i in indices]) indices = distance.argmin(-1) # B x NK
return torch.eq(target.to("cpu"), y_hat).float().mean() print(indices, target)
return torch.eq(target, indices).float().mean()
class Flatten(nn.Module): class Flatten(nn.Module):
@ -105,6 +106,7 @@ def test():
N = torch.randint(5, 10, (1,)).tolist()[0] N = torch.randint(5, 10, (1,)).tolist()[0]
K = torch.randint(1, 10, (1,)).tolist()[0] K = torch.randint(1, 10, (1,)).tolist()[0]
batch_size = 2
episodic_dataset = dataset.EpisodicDataset( episodic_dataset = dataset.EpisodicDataset(
origin_dataset, # 抽取数据集 origin_dataset, # 抽取数据集
N, # N N, # N
@ -113,17 +115,20 @@ def test():
) )
print(episodic_dataset) print(episodic_dataset)
data_loader = DataLoader(episodic_dataset, batch_size=2) data_loader = DataLoader(episodic_dataset, batch_size=batch_size)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
extractor = make_extractor() extractor = make_extractor()
accs = [] accs = []
for item in data_loader: for item in data_loader:
item = convert_tensor(item, device) item = convert_tensor(item, device)
query_batch = [extractor(images) for images in item["query"]] # item["query"]: B x NK x 3 x W x H
support_batch = [torch.stack(torch.split(extractor(images), K)) for images in item["support"]] # item["support"]: B x NK x 3 x W x H
for i in range(len(query_batch)): # item["target"]: B x NK
accs.append(evaluate(query_batch[i], item["target"][i], support_batch[i])) print(item["support"].shape, item["target"].shape)
query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N*K, -1)
support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
accs.append(evaluate(query_batch, item["target"], support_batch))
print(torch.tensor(accs).mean().item()) print(torch.tensor(accs).mean().item())