1
This commit is contained in:
parent
598bd9e0f1
commit
7d720c181b
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,2 +1,3 @@
|
|||||||
*.pth
|
*.pth
|
||||||
.idea/
|
.idea/
|
||||||
|
submit/
|
||||||
@ -11,7 +11,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
|
|
||||||
def dataset_to_lmdb(dataset, lmdb_path):
|
def dataset_to_lmdb(dataset, lmdb_path):
|
||||||
env = lmdb.open(lmdb_path, map_size=1099511627776, subdir=os.path.isdir(lmdb_path))
|
env = lmdb.open(lmdb_path, map_size=1099511627776*2, subdir=os.path.isdir(lmdb_path))
|
||||||
with env.begin(write=True) as txn:
|
with env.begin(write=True) as txn:
|
||||||
for i in tqdm(range(len(dataset)), ncols=50):
|
for i in tqdm(range(len(dataset)), ncols=50):
|
||||||
buffer = BytesIO()
|
buffer = BytesIO()
|
||||||
|
|||||||
9
test.py
9
test.py
@ -73,7 +73,7 @@ def test(lmdb_path, import_path):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for item in data_loader:
|
for item in data_loader:
|
||||||
st = time.time()
|
st = time.time()
|
||||||
print("load", time.time() - load_st)
|
# print("load", time.time() - load_st)
|
||||||
item = convert_tensor(item, device, non_blocking=True)
|
item = convert_tensor(item, device, non_blocking=True)
|
||||||
# item["query"]: B x NK x 3 x W x H
|
# item["query"]: B x NK x 3 x W x H
|
||||||
# item["support"]: B x NK x 3 x W x H
|
# item["support"]: B x NK x 3 x W x H
|
||||||
@ -81,12 +81,11 @@ def test(lmdb_path, import_path):
|
|||||||
batch_size = item["target"].size(0)
|
batch_size = item["target"].size(0)
|
||||||
query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1)
|
query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1)
|
||||||
support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
|
support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
|
||||||
print("compute", time.time() - st)
|
# print("compute", time.time() - st)
|
||||||
load_st = time.time()
|
load_st = time.time()
|
||||||
|
|
||||||
accs.append(evaluate(query_batch, item["target"], support_batch))
|
accs.append(evaluate(query_batch, item["target"], support_batch))
|
||||||
print(torch.tensor(accs).mean().item())
|
print(torch.tensor(accs).mean().item())
|
||||||
print("time: ", time.time() - st)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -94,8 +93,8 @@ if __name__ == '__main__':
|
|||||||
defined_path = ["/data/few-shot/lmdb/mini-imagenet/val.lmdb",
|
defined_path = ["/data/few-shot/lmdb/mini-imagenet/val.lmdb",
|
||||||
"/data/few-shot/lmdb/CUB_200_2011/data.lmdb",
|
"/data/few-shot/lmdb/CUB_200_2011/data.lmdb",
|
||||||
"/data/few-shot/lmdb/STANFORD-CARS/train.lmdb",
|
"/data/few-shot/lmdb/STANFORD-CARS/train.lmdb",
|
||||||
# "/data/few-shot/lmdb/Plantae/data.lmdb",
|
"/data/few-shot/lmdb/Plantae/data.lmdb",
|
||||||
# "/data/few-shot/lmdb/Places365/val.lmdb"
|
"/data/few-shot/lmdb/Places365/val.lmdb"
|
||||||
]
|
]
|
||||||
parser = argparse.ArgumentParser(description="test")
|
parser = argparse.ArgumentParser(description="test")
|
||||||
parser.add_argument('-i', "--import_path", required=True)
|
parser.add_argument('-i', "--import_path", required=True)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user