1
This commit is contained in:
parent
598bd9e0f1
commit
7d720c181b
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
*.pth
|
||||
.idea/
|
||||
submit/
|
||||
@ -11,7 +11,7 @@ from tqdm import tqdm
|
||||
|
||||
|
||||
def dataset_to_lmdb(dataset, lmdb_path):
|
||||
env = lmdb.open(lmdb_path, map_size=1099511627776, subdir=os.path.isdir(lmdb_path))
|
||||
env = lmdb.open(lmdb_path, map_size=1099511627776*2, subdir=os.path.isdir(lmdb_path))
|
||||
with env.begin(write=True) as txn:
|
||||
for i in tqdm(range(len(dataset)), ncols=50):
|
||||
buffer = BytesIO()
|
||||
|
||||
9
test.py
9
test.py
@ -73,7 +73,7 @@ def test(lmdb_path, import_path):
|
||||
with torch.no_grad():
|
||||
for item in data_loader:
|
||||
st = time.time()
|
||||
print("load", time.time() - load_st)
|
||||
# print("load", time.time() - load_st)
|
||||
item = convert_tensor(item, device, non_blocking=True)
|
||||
# item["query"]: B x NK x 3 x W x H
|
||||
# item["support"]: B x NK x 3 x W x H
|
||||
@ -81,12 +81,11 @@ def test(lmdb_path, import_path):
|
||||
batch_size = item["target"].size(0)
|
||||
query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1)
|
||||
support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
|
||||
print("compute", time.time() - st)
|
||||
# print("compute", time.time() - st)
|
||||
load_st = time.time()
|
||||
|
||||
accs.append(evaluate(query_batch, item["target"], support_batch))
|
||||
print(torch.tensor(accs).mean().item())
|
||||
print("time: ", time.time() - st)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -94,8 +93,8 @@ if __name__ == '__main__':
|
||||
defined_path = ["/data/few-shot/lmdb/mini-imagenet/val.lmdb",
|
||||
"/data/few-shot/lmdb/CUB_200_2011/data.lmdb",
|
||||
"/data/few-shot/lmdb/STANFORD-CARS/train.lmdb",
|
||||
# "/data/few-shot/lmdb/Plantae/data.lmdb",
|
||||
# "/data/few-shot/lmdb/Places365/val.lmdb"
|
||||
"/data/few-shot/lmdb/Plantae/data.lmdb",
|
||||
"/data/few-shot/lmdb/Places365/val.lmdb"
|
||||
]
|
||||
parser = argparse.ArgumentParser(description="test")
|
||||
parser.add_argument('-i', "--import_path", required=True)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user