Ray Wong 2020-07-11 20:51:22 +08:00
parent 07c63abb30
commit 598bd9e0f1
4 changed files with 60 additions and 33 deletions

data/dataset.py
View File

@@ -40,15 +40,17 @@ class ImprovedImageFolder(ImageFolder):
         self.classes_list = defaultdict(list)
         for i in range(len(self)):
             self.classes_list[self.samples[i][-1]].append(i)
         assert len(self.classes_list) == len(self.classes)
 
     def __getitem__(self, item):
         return super().__getitem__(item)[0]
 
 
 class LMDBDataset(Dataset):
-    def __init__(self, lmdb_path):
-        self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), readonly=True, lock=False,
-                            readahead=False, meminit=False)
+    def __init__(self, lmdb_path, transform=None):
+        self.db = lmdb.open(lmdb_path, map_size=1099511627776, subdir=os.path.isdir(lmdb_path), readonly=True,
+                            lock=False, readahead=False, meminit=False)
+        self.transform = transform
         with self.db.begin(write=False) as txn:
             self.classes_list = pickle.loads(txn.get(b"classes_list"))
             self._len = pickle.loads(txn.get(b"__len__"))
@@ -58,7 +60,10 @@ class LMDBDataset(Dataset):
     def __getitem__(self, i):
         with self.db.begin(write=False) as txn:
-            return torch.load(BytesIO(txn.get("{}".format(i).encode())))
+            sample = torch.load(BytesIO(txn.get("{}".format(i).encode())))
+        if self.transform is not None:
+            sample = self.transform(sample)
+        return sample
 
 
 class EpisodicDataset(Dataset):
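
Note: the new `transform` hook lets augmentation run at read time instead of being baked into the stored tensors. A minimal usage sketch, assuming a tensor-level augmentation (the transform choice here is a made-up example, not from this repo; the path is taken from this commit's test list):

    import torchvision.transforms as T
    from data.dataset import LMDBDataset

    # Stored samples are already normalized tensors, so only
    # tensor-level transforms make sense at this point.
    train_set = LMDBDataset(
        "/data/few-shot/lmdb/CUB_200_2011/data.lmdb",
        transform=T.RandomErasing(p=0.5),  # hypothetical augmentation
    )
    sample = train_set[0]  # loaded from LMDB, then passed through the transform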
@@ -73,7 +78,7 @@ class EpisodicDataset(Dataset):
         return self.num_episodes
 
     def __getitem__(self, _):
-        random_classes = torch.randint(high=len(self.origin.classes_list), size=(self.num_class,)).tolist()
+        random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
         support_set_list = []
         query_set_list = []
         target_list = []
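
Note: this switch matters because `torch.randint` samples with replacement, so one episode could draw the same class twice; `torch.randperm` guarantees `num_class` distinct classes. A quick illustration:

    import torch

    # randint draws with replacement: the same class index can appear twice
    # in one episode, which breaks the N-way assumption.
    classes_with_repeats = torch.randint(high=10, size=(5,)).tolist()
    # randperm draws without replacement: five distinct class indices.
    distinct_classes = torch.randperm(10)[:5].tolist()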
@@ -83,8 +88,8 @@ class EpisodicDataset(Dataset):
                 idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist()
             else:
                 idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist()
-            support_set_list.extend([self.origin[idx] for idx in idx_list[:self.num_set]])
-            query_set_list.extend([self.origin[idx] for idx in idx_list[self.num_set:]])
+            support_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[:self.num_set]])
+            query_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[self.num_set:]])
             target_list.extend([i] * self.num_set)
         return {
             "support": torch.stack(support_set_list),

View File

@@ -1,17 +1,19 @@
-import torch
-import lmdb
 import os
 import pickle
 from io import BytesIO
+import argparse
+import torch
+import lmdb
 from data.dataset import CARS, ImprovedImageFolder
 import torchvision
 from tqdm import tqdm
 
 
 def dataset_to_lmdb(dataset, lmdb_path):
-    env = lmdb.open(lmdb_path, map_size=1099511627776 * 2, subdir=os.path.isdir(lmdb_path))
+    env = lmdb.open(lmdb_path, map_size=1099511627776, subdir=os.path.isdir(lmdb_path))
     with env.begin(write=True) as txn:
-        for i in tqdm(range(len(dataset))):
+        for i in tqdm(range(len(dataset)), ncols=50):
             buffer = BytesIO()
             torch.save(dataset[i], buffer)
             txn.put("{}".format(i).encode(), buffer.getvalue())
@@ -19,17 +21,23 @@ def dataset_to_lmdb(dataset, lmdb_path):
         txn.put(b"__len__", pickle.dumps(len(dataset)))
 
 
-def main():
-    data_transform = torchvision.transforms.Compose([
-        torchvision.transforms.Resize([int(224 * 1.15), int(224 * 1.15)]),
+def transform(save_path, dataset_path):
+    print(save_path, dataset_path)
+    dt = torchvision.transforms.Compose([
+        torchvision.transforms.Resize((256, 256)),
         torchvision.transforms.CenterCrop(224),
         torchvision.transforms.ToTensor(),
-        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     ])
-    origin_dataset = ImprovedImageFolder("/data/few-shot/CUB_200_2011/CUB_200_2011/images", transform=data_transform)
-    dataset_to_lmdb(origin_dataset, "/data/few-shot/lmdb/CUB_200_2011/data.lmdb")
+    # origin_dataset = CARS("/data/few-shot/STANFORD-CARS/", transform=dt)
+    origin_dataset = ImprovedImageFolder(dataset_path, transform=dt)
+    dataset_to_lmdb(origin_dataset, save_path)
 
 
 if __name__ == '__main__':
-    main()
+    parser = argparse.ArgumentParser(description="transform dataset to lmdb database")
+    parser.add_argument('--save', required=True)
+    parser.add_argument('--dataset', required=True)
+    args = parser.parse_args()
+    transform(args.save, args.dataset)
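
Note: with the hard-coded paths replaced by flags, the script can target any ImageFolder-style dataset. For example (this file's name is not shown in the diff, so `to_lmdb.py` is a stand-in; the paths come from the old hard-coded version):

    python to_lmdb.py \
        --dataset /data/few-shot/CUB_200_2011/CUB_200_2011/images \
        --save /data/few-shot/lmdb/CUB_200_2011/data.lmdb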

0 submit/__init__.py Executable file
View File

44 test.py
View File

@@ -1,10 +1,12 @@
 import torch
 from torch.utils.data import DataLoader
 from torchvision import transforms
 from data import dataset
+import argparse
 from ignite.utils import convert_tensor
 import time
+from importlib import import_module
 from tqdm import tqdm
@@ -35,7 +37,7 @@ def euclidean_dist(x, y):
 def evaluate(query, target, support):
     """
     :param query: B x NK x D vector
-    :param target: B x NK x 1 vector
+    :param target: B x NK vector
     :param support: B x N x K x D vector
     :return:
     """
@@ -45,11 +47,10 @@ def evaluate(query, target, support):
     return torch.eq(target, indices).float().mean()
 
 
-def test(lmdb_path):
+def test(lmdb_path, import_path):
     origin_dataset = dataset.LMDBDataset(lmdb_path)
-    N = torch.randint(5, 10, (1,)).tolist()[0]
-    K = torch.randint(1, 10, (1,)).tolist()[0]
+    N = 5
+    K = 5
     episodic_dataset = dataset.EpisodicDataset(
         origin_dataset,  # dataset to sample episodes from
         N,  # N
@@ -58,17 +59,21 @@ def test(lmdb_path):
     )
     print(episodic_dataset)
-    data_loader = DataLoader(episodic_dataset, batch_size=4, pin_memory=False)
+    data_loader = DataLoader(episodic_dataset, batch_size=16, pin_memory=False)
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    from submit import make_model
-    extractor = make_model()
+    submit = import_module(f"submit.{import_path}")
+    extractor = submit.make_model()
     extractor.to(device)
     accs = []
-    st = time.time()
+    load_st = time.time()
     with torch.no_grad():
-        for item in tqdm(data_loader, nrows=80):
+        for item in data_loader:
+            st = time.time()
+            print("load", time.time() - load_st)
             item = convert_tensor(item, device, non_blocking=True)
             # item["query"]: B x NK x 3 x W x H
             # item["support"]: B x NK x 3 x W x H
@@ -76,6 +81,8 @@ def test(lmdb_path):
             batch_size = item["target"].size(0)
             query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1)
             support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
+            print("compute", time.time() - st)
+            load_st = time.time()
             accs.append(evaluate(query_batch, item["target"], support_batch))
     print(torch.tensor(accs).mean().item())
@@ -84,8 +91,15 @@ def test(lmdb_path):
 if __name__ == '__main__':
     setup_seed(100)
-    for path in ["/data/few-shot/lmdb/CUB_200_2011/data.lmdb",
-                 "/data/few-shot/lmdb/mini-imagenet/train.lmdb",
-                 "/data/few-shot/lmdb/STANFORD-CARS/train.lmdb"]:
+    defined_path = ["/data/few-shot/lmdb/mini-imagenet/val.lmdb",
+                    "/data/few-shot/lmdb/CUB_200_2011/data.lmdb",
+                    "/data/few-shot/lmdb/STANFORD-CARS/train.lmdb",
+                    # "/data/few-shot/lmdb/Plantae/data.lmdb",
+                    # "/data/few-shot/lmdb/Places365/val.lmdb"
+                    ]
+    parser = argparse.ArgumentParser(description="test")
+    parser.add_argument('-i', "--import_path", required=True)
+    args = parser.parse_args()
+    for path in defined_path:
         print(path)
-        test(path)
+        test(path, args.import_path)
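
Note: together with the empty `submit/__init__.py` added in this commit, the harness now loads the model under test by module name via `import_module(f"submit.{import_path}")`. So if a candidate model lived at submit/baseline.py (a hypothetical name) and exposed `make_model()`, the run would be:

    python test.py -i baseline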