This commit is contained in:
Ray Wong 2020-07-16 16:07:03 +08:00
parent 598bd9e0f1
commit 7d720c181b
3 changed files with 7 additions and 7 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
*.pth *.pth
.idea/ .idea/
submit/

View File

@ -11,7 +11,7 @@ from tqdm import tqdm
def dataset_to_lmdb(dataset, lmdb_path): def dataset_to_lmdb(dataset, lmdb_path):
env = lmdb.open(lmdb_path, map_size=1099511627776, subdir=os.path.isdir(lmdb_path)) env = lmdb.open(lmdb_path, map_size=1099511627776*2, subdir=os.path.isdir(lmdb_path))
with env.begin(write=True) as txn: with env.begin(write=True) as txn:
for i in tqdm(range(len(dataset)), ncols=50): for i in tqdm(range(len(dataset)), ncols=50):
buffer = BytesIO() buffer = BytesIO()

View File

@ -73,7 +73,7 @@ def test(lmdb_path, import_path):
with torch.no_grad(): with torch.no_grad():
for item in data_loader: for item in data_loader:
st = time.time() st = time.time()
print("load", time.time() - load_st) # print("load", time.time() - load_st)
item = convert_tensor(item, device, non_blocking=True) item = convert_tensor(item, device, non_blocking=True)
# item["query"]: B x NK x 3 x W x H # item["query"]: B x NK x 3 x W x H
# item["support"]: B x NK x 3 x W x H # item["support"]: B x NK x 3 x W x H
@ -81,12 +81,11 @@ def test(lmdb_path, import_path):
batch_size = item["target"].size(0) batch_size = item["target"].size(0)
query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1) query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1)
support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1) support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
print("compute", time.time() - st) # print("compute", time.time() - st)
load_st = time.time() load_st = time.time()
accs.append(evaluate(query_batch, item["target"], support_batch)) accs.append(evaluate(query_batch, item["target"], support_batch))
print(torch.tensor(accs).mean().item()) print(torch.tensor(accs).mean().item())
print("time: ", time.time() - st)
if __name__ == '__main__': if __name__ == '__main__':
@ -94,8 +93,8 @@ if __name__ == '__main__':
defined_path = ["/data/few-shot/lmdb/mini-imagenet/val.lmdb", defined_path = ["/data/few-shot/lmdb/mini-imagenet/val.lmdb",
"/data/few-shot/lmdb/CUB_200_2011/data.lmdb", "/data/few-shot/lmdb/CUB_200_2011/data.lmdb",
"/data/few-shot/lmdb/STANFORD-CARS/train.lmdb", "/data/few-shot/lmdb/STANFORD-CARS/train.lmdb",
# "/data/few-shot/lmdb/Plantae/data.lmdb", "/data/few-shot/lmdb/Plantae/data.lmdb",
# "/data/few-shot/lmdb/Places365/val.lmdb" "/data/few-shot/lmdb/Places365/val.lmdb"
] ]
parser = argparse.ArgumentParser(description="test") parser = argparse.ArgumentParser(description="test")
parser.add_argument('-i', "--import_path", required=True) parser.add_argument('-i', "--import_path", required=True)