diff --git a/data/dataset.py b/data/dataset.py
index 1636e1a..01e8284 100755
--- a/data/dataset.py
+++ b/data/dataset.py
@@ -40,15 +40,17 @@ class ImprovedImageFolder(ImageFolder):
         self.classes_list = defaultdict(list)
         for i in range(len(self)):
             self.classes_list[self.samples[i][-1]].append(i)
+        assert len(self.classes_list) == len(self.classes)
 
     def __getitem__(self, item):
         return super().__getitem__(item)[0]
 
 
 class LMDBDataset(Dataset):
-    def __init__(self, lmdb_path):
-        self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), readonly=True, lock=False,
-                            readahead=False, meminit=False)
+    def __init__(self, lmdb_path, transform=None):
+        self.db = lmdb.open(lmdb_path, map_size=1099511627776, subdir=os.path.isdir(lmdb_path), readonly=True,
+                            lock=False, readahead=False, meminit=False)
+        self.transform = transform
         with self.db.begin(write=False) as txn:
             self.classes_list = pickle.loads(txn.get(b"classes_list"))
             self._len = pickle.loads(txn.get(b"__len__"))
@@ -58,7 +60,10 @@ class LMDBDataset(Dataset):
 
     def __getitem__(self, i):
         with self.db.begin(write=False) as txn:
-            return torch.load(BytesIO(txn.get("{}".format(i).encode())))
+            sample = torch.load(BytesIO(txn.get("{}".format(i).encode())))
+        if self.transform is not None:
+            sample = self.transform(sample)
+        return sample
 
 
 class EpisodicDataset(Dataset):
@@ -73,7 +78,7 @@ class EpisodicDataset(Dataset):
         return self.num_episodes
 
     def __getitem__(self, _):
-        random_classes = torch.randint(high=len(self.origin.classes_list), size=(self.num_class,)).tolist()
+        random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
         support_set_list = []
         query_set_list = []
         target_list = []
@@ -83,8 +88,8 @@ class EpisodicDataset(Dataset):
                 idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist()
             else:
                 idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist()
-            support_set_list.extend([self.origin[idx] for idx in idx_list[:self.num_set]])
-            query_set_list.extend([self.origin[idx] for idx in idx_list[self.num_set:]])
+            support_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[:self.num_set]])
+            query_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[self.num_set:]])
             target_list.extend([i] * self.num_set)
         return {
             "support": torch.stack(support_set_list),
diff --git a/data/lmdbify.py b/data/lmdbify.py
index 7e874a2..4d87313 100755
--- a/data/lmdbify.py
+++ b/data/lmdbify.py
@@ -1,17 +1,19 @@
-import torch
-import lmdb
 import os
 import pickle
 from io import BytesIO
+import argparse
+
+import torch
+import lmdb
 from data.dataset import CARS, ImprovedImageFolder
 import torchvision
 from tqdm import tqdm
 
 
 def dataset_to_lmdb(dataset, lmdb_path):
-    env = lmdb.open(lmdb_path, map_size=1099511627776 * 2, subdir=os.path.isdir(lmdb_path))
+    env = lmdb.open(lmdb_path, map_size=1099511627776, subdir=os.path.isdir(lmdb_path))
     with env.begin(write=True) as txn:
-        for i in tqdm(range(len(dataset))):
+        for i in tqdm(range(len(dataset)), ncols=50):
             buffer = BytesIO()
             torch.save(dataset[i], buffer)
             txn.put("{}".format(i).encode(), buffer.getvalue())
@@ -19,17 +21,23 @@ def dataset_to_lmdb(dataset, lmdb_path):
         txn.put(b"__len__", pickle.dumps(len(dataset)))
 
 
-def main():
-    data_transform = torchvision.transforms.Compose([
-        torchvision.transforms.Resize([int(224 * 1.15), int(224 * 1.15)]),
+def transform(save_path, dataset_path):
+    print(save_path, dataset_path)
+    dt = torchvision.transforms.Compose([
+        torchvision.transforms.Resize((256, 256)),
         torchvision.transforms.CenterCrop(224),
         torchvision.transforms.ToTensor(),
-        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     ])
-    origin_dataset = ImprovedImageFolder("/data/few-shot/CUB_200_2011/CUB_200_2011/images", transform=data_transform)
-    dataset_to_lmdb(origin_dataset, "/data/few-shot/lmdb/CUB_200_2011/data.lmdb")
+    # origin_dataset = CARS("/data/few-shot/STANFORD-CARS/", transform=dt)
+    origin_dataset = ImprovedImageFolder(dataset_path, transform=dt)
+    dataset_to_lmdb(origin_dataset, save_path)
 
 
 if __name__ == '__main__':
-    main()
+    parser = argparse.ArgumentParser(description="transform a dataset into an LMDB database")
+    parser.add_argument('--save', required=True)
+    parser.add_argument('--dataset', required=True)
+    args = parser.parse_args()
+    transform(args.save, args.dataset)
 
diff --git a/submit/__init__.py b/submit/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/test.py b/test.py
index 2100f72..f064268 100755
--- a/test.py
+++ b/test.py
@@ -1,10 +1,12 @@
 import torch
 from torch.utils.data import DataLoader
-
+from torchvision import transforms
 from data import dataset
+import argparse
 
 from ignite.utils import convert_tensor
 import time
+from importlib import import_module
 from tqdm import tqdm
 
 
@@ -35,7 +37,7 @@ def euclidean_dist(x, y):
 def evaluate(query, target, support):
     """
     :param query: B x NK x D vector
-    :param target: B x NK x 1 vector
+    :param target: B x NK vector
     :param support: B x N x K x D vector
     :return:
     """
@@ -45,11 +47,10 @@ def evaluate(query, target, support):
     return torch.eq(target, indices).float().mean()
 
 
-def test(lmdb_path):
+def test(lmdb_path, import_path):
     origin_dataset = dataset.LMDBDataset(lmdb_path)
-
-    N = torch.randint(5, 10, (1,)).tolist()[0]
-    K = torch.randint(1, 10, (1,)).tolist()[0]
+    N = 5
+    K = 5
     episodic_dataset = dataset.EpisodicDataset(
         origin_dataset,  # dataset to sample episodes from
         N,  # N
@@ -58,17 +59,21 @@ def test(lmdb_path):
     )
     print(episodic_dataset)
 
-    data_loader = DataLoader(episodic_dataset, batch_size=4, pin_memory=False)
+    data_loader = DataLoader(episodic_dataset, batch_size=16, pin_memory=False)
 
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    from submit import make_model
-    extractor = make_model()
+    submit = import_module(f"submit.{import_path}")
+
+    extractor = submit.make_model()
     extractor.to(device)
 
     accs = []
-    st = time.time()
+
+    load_st = time.time()
     with torch.no_grad():
-        for item in tqdm(data_loader, nrows=80):
+        for item in data_loader:
+            st = time.time()
+            print("load", time.time() - load_st)
             item = convert_tensor(item, device, non_blocking=True)
             # item["query"]: B x NK x 3 x W x H
             # item["support"]: B x NK x 3 x W x H
@@ -76,6 +81,8 @@ def test(lmdb_path):
             batch_size = item["target"].size(0)
             query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1)
             support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
+            print("compute", time.time() - st)
+            load_st = time.time()
             accs.append(evaluate(query_batch, item["target"], support_batch))
 
     print(torch.tensor(accs).mean().item())
@@ -84,8 +91,15 @@ def test(lmdb_path):
 
 if __name__ == '__main__':
     setup_seed(100)
-    for path in ["/data/few-shot/lmdb/CUB_200_2011/data.lmdb",
-                 "/data/few-shot/lmdb/mini-imagenet/train.lmdb",
-                 "/data/few-shot/lmdb/STANFORD-CARS/train.lmdb"]:
+    defined_path = ["/data/few-shot/lmdb/mini-imagenet/val.lmdb",
+                    "/data/few-shot/lmdb/CUB_200_2011/data.lmdb",
+                    "/data/few-shot/lmdb/STANFORD-CARS/train.lmdb",
+                    # "/data/few-shot/lmdb/Plantae/data.lmdb",
+                    # "/data/few-shot/lmdb/Places365/val.lmdb"
+                    ]
+    parser = argparse.ArgumentParser(description="test")
+    parser.add_argument('-i', "--import_path", required=True)
+    args = parser.parse_args()
+    for path in defined_path:
         print(path)
-        test(path)
+        test(path, args.import_path)