From 3a72dcb5f0112a10d59f132118122d9f1b728d09 Mon Sep 17 00:00:00 2001
From: Ray Wong
Date: Mon, 20 Jul 2020 11:02:39 +0800
Subject: [PATCH] change line ending

---
 .gitignore      |   4 +-
 data/dataset.py | 209 +++++++++++++++++++++++++-----------------------
 data/lmdbify.py |  81 +++++++++----------
 test.py         | 208 +++++++++++++++++++++++------------------------
 4 files changed, 252 insertions(+), 250 deletions(-)

diff --git a/.gitignore b/.gitignore
index 0303433..568e017 100755
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,3 @@
-*.pth
-.idea/
+*.pth
+.idea/
 submit/
\ No newline at end of file
diff --git a/data/dataset.py b/data/dataset.py
index 01e8284..f08608b 100755
--- a/data/dataset.py
+++ b/data/dataset.py
@@ -1,101 +1,108 @@
-from scipy.io import loadmat
-import torch
-import lmdb
-import os
-import pickle
-from io import BytesIO
-from torch.utils.data import Dataset
-from torchvision.datasets.folder import default_loader
-from torchvision.datasets import ImageFolder
-from pathlib import Path
-from collections import defaultdict
-
-
-class CARS(Dataset):
-    def __init__(self, root, loader=default_loader, transform=None):
-        self.root = Path(root)
-        self.transform = transform
-        self.loader = loader
-        self.annotations = loadmat(self.root / "devkit/cars_train_annos.mat")["annotations"][0]
-        self.annotations = {d[-1].item(): d[-2].item() - 1 for d in self.annotations}
-        self.classes_list = defaultdict(list)
-        for i in range(len(self.annotations)):
-            self.classes_list[self.annotations["{:05d}.jpg".format(i + 1)]].append(i)
-
-    def __len__(self):
-        return len(self.annotations)
-
-    def __getitem__(self, item):
-        file_name = "{:05d}.jpg".format(item + 1)
-        target = self.annotations[file_name]
-        sample = self.loader(self.root / "cars_train" / file_name)
-        if self.transform is not None:
-            sample = self.transform(sample)
-        return sample
-
-
-class ImprovedImageFolder(ImageFolder):
-    def __init__(self, root, loader=default_loader, transform=None):
-        super().__init__(root, transform, loader=loader)
-        self.classes_list = defaultdict(list)
-        for i in range(len(self)):
-            self.classes_list[self.samples[i][-1]].append(i)
-        assert len(self.classes_list) == len(self.classes)
-
-    def __getitem__(self, item):
-        return super().__getitem__(item)[0]
-
-
-class LMDBDataset(Dataset):
-    def __init__(self, lmdb_path, transform=None):
-        self.db = lmdb.open(lmdb_path, map_size=1099511627776, subdir=os.path.isdir(lmdb_path), readonly=True,
-                            lock=False, readahead=False, meminit=False)
-        self.transform = transform
-        with self.db.begin(write=False) as txn:
-            self.classes_list = pickle.loads(txn.get(b"classes_list"))
-            self._len = pickle.loads(txn.get(b"__len__"))
-
-    def __len__(self):
-        return self._len
-
-    def __getitem__(self, i):
-        with self.db.begin(write=False) as txn:
-            sample = torch.load(BytesIO(txn.get("{}".format(i).encode())))
-        if self.transform is not None:
-            sample = self.transform(sample)
-        return sample
-
-
-class EpisodicDataset(Dataset):
-    def __init__(self, origin_dataset, num_class, num_set, num_episodes):
-        self.origin = origin_dataset
-        self.num_class = num_class
-        assert self.num_class < len(self.origin.classes_list)
-        self.num_set = num_set  # K
-        self.num_episodes = num_episodes
-
-    def __len__(self):
-        return self.num_episodes
-
-    def __getitem__(self, _):
-        random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
-        support_set_list = []
-        query_set_list = []
-        target_list = []
-        for i, c in enumerate(random_classes):
-            image_list = self.origin.classes_list[c]
-            if len(image_list) > self.num_set * 2:
-                idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist()
-            else:
-                idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist()
-            support_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[:self.num_set]])
-            query_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[self.num_set:]])
-            target_list.extend([i] * self.num_set)
-        return {
-            "support": torch.stack(support_set_list),
-            "query": torch.stack(query_set_list),
-            "target": torch.tensor(target_list)
-        }
-
-    def __repr__(self):
-        return "<EpisodicDataset: num_class={}, num_set={}, num_episodes={}>".format(self.num_class, self.num_set, self.num_episodes)
+from scipy.io import loadmat
+import torch
+import lmdb
+import os
+import pickle
+from PIL import Image
+from io import BytesIO
+from torch.utils.data import Dataset
+from torchvision.datasets.folder import default_loader
+from torchvision.datasets import ImageFolder
+from pathlib import Path
+from collections import defaultdict
+
+
+class CARS(Dataset):
+    def __init__(self, root, loader=default_loader, transform=None):
+        self.root = Path(root)
+        self.transform = transform
+        self.loader = loader
+        self.annotations = loadmat(self.root / "devkit/cars_train_annos.mat")["annotations"][0]
+        self.annotations = {d[-1].item(): d[-2].item() - 1 for d in self.annotations}
+        self.classes_list = defaultdict(list)
+        for i in range(len(self.annotations)):
+            self.classes_list[self.annotations["{:05d}.jpg".format(i + 1)]].append(i)
+
+    def __len__(self):
+        return len(self.annotations)
+
+    def __getitem__(self, item):
+        file_name = "{:05d}.jpg".format(item + 1)
+        target = self.annotations[file_name]
+        sample = self.loader(self.root / "cars_train" / file_name)
+        if self.transform is not None:
+            sample = self.transform(sample)
+        return sample
+
+
+class ImprovedImageFolder(ImageFolder):
+    def __init__(self, root, loader=default_loader, transform=None):
+        super().__init__(root, transform, loader=loader)
+        self.classes_list = defaultdict(list)
+        for i in range(len(self)):
+            self.classes_list[self.samples[i][-1]].append(i)
+        assert len(self.classes_list) == len(self.classes)
+
+    def __getitem__(self, item):
+        return super().__getitem__(item)[0]
+
+
+class LMDBDataset(Dataset):
+    def __init__(self, lmdb_path, transform=None):
+        self.db = lmdb.open(lmdb_path, map_size=1099511627776, subdir=os.path.isdir(lmdb_path), readonly=True,
+                            lock=False, readahead=False, meminit=False)
+        self.transform = transform
+        with self.db.begin(write=False) as txn:
+            self.classes_list = pickle.loads(txn.get(b"classes_list"))
+            self._len = pickle.loads(txn.get(b"__len__"))
+
+    def __len__(self):
+        return self._len
+
+    def __getitem__(self, i):
+        with self.db.begin(write=False) as txn:
+            sample = Image.open(BytesIO(txn.get("{}".format(i).encode())))
+        if sample.mode != "RGB":
+            sample = sample.convert("RGB")
+        if self.transform is not None:
+            try:
+                sample = self.transform(sample)
+            except RuntimeError as re:
+                print(sample.format, sample.size, sample.mode)
+                raise re
+        return sample
+
+
+class EpisodicDataset(Dataset):
+    def __init__(self, origin_dataset, num_class, num_set, num_episodes):
+        self.origin = origin_dataset
+        self.num_class = num_class
+        assert self.num_class < len(self.origin.classes_list)
+        self.num_set = num_set  # K
+        self.num_episodes = num_episodes
+
+    def __len__(self):
+        return self.num_episodes
+
+    def __getitem__(self, _):
+        random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
+        support_set_list = []
+        query_set_list = []
+        target_list = []
+        for i, c in enumerate(random_classes):
+            image_list = self.origin.classes_list[c]
+            if len(image_list) > self.num_set * 2:
+                idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist()
+            else:
+                idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist()
+            support_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[:self.num_set]])
+            query_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[self.num_set:]])
+            target_list.extend([i] * self.num_set)
+        return {
+            "support": torch.stack(support_set_list),
+            "query": torch.stack(query_set_list),
+            "target": torch.tensor(target_list)
+        }
+
+    def __repr__(self):
+        return "<EpisodicDataset: num_class={}, num_set={}, num_episodes={}>".format(self.num_class, self.num_set, self.num_episodes)
diff --git a/data/lmdbify.py b/data/lmdbify.py
index 1752795..97f1c9e 100755
--- a/data/lmdbify.py
+++ b/data/lmdbify.py
@@ -1,43 +1,38 @@
-import os
-import pickle
-from io import BytesIO
-import argparse
-
-import torch
-import lmdb
-from data.dataset import CARS, ImprovedImageFolder
-import torchvision
-from tqdm import tqdm
-
-
-def dataset_to_lmdb(dataset, lmdb_path):
-    env = lmdb.open(lmdb_path, map_size=1099511627776*2, subdir=os.path.isdir(lmdb_path))
-    with env.begin(write=True) as txn:
-        for i in tqdm(range(len(dataset)), ncols=50):
-            buffer = BytesIO()
-            torch.save(dataset[i], buffer)
-            txn.put("{}".format(i).encode(), buffer.getvalue())
-        txn.put(b"classes_list", pickle.dumps(dataset.classes_list))
-        txn.put(b"__len__", pickle.dumps(len(dataset)))
-
-
-def transform(save_path, dataset_path):
-    print(save_path, dataset_path)
-    dt = torchvision.transforms.Compose([
-        torchvision.transforms.Resize((256, 256)),
-        torchvision.transforms.CenterCrop(224),
-        torchvision.transforms.ToTensor(),
-        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-    ])
-    # origin_dataset = CARS("/data/few-shot/STANFORD-CARS/", transform=dt)
-    origin_dataset = ImprovedImageFolder(dataset_path, transform=dt)
-    dataset_to_lmdb(origin_dataset, save_path)
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description="transform dataset to lmdb database")
-    parser.add_argument('--save', required=True)
-    parser.add_argument('--dataset', required=True)
-    args = parser.parse_args()
-    transform(args.save, args.dataset)
-
+import os
+import pickle
+from io import BytesIO
+import argparse
+
+import lmdb
+from data.dataset import CARS, ImprovedImageFolder
+from tqdm import tqdm
+
+
+def content_loader(path):
+    with open(path, "rb") as f:
+        return f.read()
+
+
+def dataset_to_lmdb(dataset, lmdb_path):
+    env = lmdb.open(lmdb_path, map_size=1099511627776*2, subdir=os.path.isdir(lmdb_path))
+    with env.begin(write=True) as txn:
+        for i in tqdm(range(len(dataset)), ncols=50):
+            txn.put("{}".format(i).encode(), bytearray(dataset[i]))
+        txn.put(b"classes_list", pickle.dumps(dataset.classes_list))
+        txn.put(b"__len__", pickle.dumps(len(dataset)))
+
+
+def transform(save_path, dataset_path):
+    print(save_path, dataset_path)
+    # origin_dataset = CARS("/data/few-shot/STANFORD-CARS/", loader=content_loader)
+    origin_dataset = ImprovedImageFolder(dataset_path, loader=content_loader)
+    dataset_to_lmdb(origin_dataset, save_path)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="transform dataset to lmdb database")
+    parser.add_argument('--save', required=True)
+    parser.add_argument('--dataset', required=True)
+    args = parser.parse_args()
+    transform(args.save, args.dataset)
+
diff --git a/test.py b/test.py
index 30ff7c1..bd468ad 100755
--- a/test.py
+++ b/test.py
@@ -1,104 +1,104 @@
-import torch
-from torch.utils.data import DataLoader
-from torchvision import transforms
-from data import dataset
-
-import argparse
-from ignite.utils import convert_tensor
-import time
-from importlib import import_module
-from tqdm import tqdm
-
-
-def setup_seed(seed):
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    torch.backends.cudnn.deterministic = True
-
-
-def euclidean_dist(x, y):
-    """
-    Compute euclidean distance between two tensors
-    """
-    # x: B x N x D
-    # y: B x M x D
-    n = x.size(-2)
-    m = y.size(-2)
-    d = x.size(-1)
-    if d != y.size(-1):
-        raise Exception
-
-    x = x.unsqueeze(2).expand(x.size(0), n, m, d)  # B x N x M x D
-    y = y.unsqueeze(1).expand(x.size(0), n, m, d)
-
-    return torch.pow(x - y, 2).sum(-1)
-
-
-def evaluate(query, target, support):
-    """
-    :param query: B x NK x D vector
-    :param target: B x NK vector
-    :param support: B x N x K x D vector
-    :return:
-    """
-    prototypes = support.mean(-2)  # B x N x D
-    distance = euclidean_dist(query, prototypes)  # B x NK x N
-    indices = distance.argmin(-1)  # B x NK
-    return torch.eq(target, indices).float().mean()
-
-
-def test(lmdb_path, import_path):
-    origin_dataset = dataset.LMDBDataset(lmdb_path)
-    N = 5
-    K = 5
-    episodic_dataset = dataset.EpisodicDataset(
-        origin_dataset,  # dataset to sample episodes from
-        N,  # N
-        K,  # K
-        100  # number of episodes
-    )
-    print(episodic_dataset)
-
-    data_loader = DataLoader(episodic_dataset, batch_size=16, pin_memory=False)
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-    submit = import_module(f"submit.{import_path}")
-
-    extractor = submit.make_model()
-    extractor.to(device)
-
-    accs = []
-
-    load_st = time.time()
-    with torch.no_grad():
-        for item in data_loader:
-            st = time.time()
-            # print("load", time.time() - load_st)
-            item = convert_tensor(item, device, non_blocking=True)
-            # item["query"]: B x NK x 3 x W x H
-            # item["support"]: B x NK x 3 x W x H
-            # item["target"]: B x NK
-            batch_size = item["target"].size(0)
-            query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1)
-            support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
-            # print("compute", time.time() - st)
-            load_st = time.time()
-
-            accs.append(evaluate(query_batch, item["target"], support_batch))
-    print(torch.tensor(accs).mean().item())
-
-
-if __name__ == '__main__':
-    setup_seed(100)
-    defined_path = ["/data/few-shot/lmdb/mini-imagenet/val.lmdb",
-                    "/data/few-shot/lmdb/CUB_200_2011/data.lmdb",
-                    "/data/few-shot/lmdb/STANFORD-CARS/train.lmdb",
-                    "/data/few-shot/lmdb/Plantae/data.lmdb",
-                    "/data/few-shot/lmdb/Places365/val.lmdb"
-                    ]
-    parser = argparse.ArgumentParser(description="test")
-    parser.add_argument('-i', "--import_path", required=True)
-    args = parser.parse_args()
-    for path in defined_path:
-        print(path)
-        test(path, args.import_path)
+import torch
+from torch.utils.data import DataLoader
+import torchvision
+from data import dataset
+
+import argparse
+from ignite.utils import convert_tensor
+import time
+from importlib import import_module
+from tqdm import tqdm
+
+
+def setup_seed(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+
+
+def euclidean_dist(x, y):
+    """
+    Compute euclidean distance between two tensors
+    """
+    # x: B x N x D
+    # y: B x M x D
+    n = x.size(-2)
+    m = y.size(-2)
+    d = x.size(-1)
+    if d != y.size(-1):
+        raise Exception
+
+    x = x.unsqueeze(2).expand(x.size(0), n, m, d)  # B x N x M x D
+    y = y.unsqueeze(1).expand(x.size(0), n, m, d)
+
+    return torch.pow(x - y, 2).sum(-1)
+
+
+def evaluate(query, target, support):
+    """
+    :param query: B x NK x D vector
+    :param target: B x NK vector
+    :param support: B x N x K x D vector
+    :return:
+    """
+    prototypes = support.mean(-2)  # B x N x D
+    distance = euclidean_dist(query, prototypes)  # B x NK x N
+    indices = distance.argmin(-1)  # B x NK
+    return torch.eq(target, indices).float().mean()
+
+
+def test(lmdb_path, import_path):
+    dt = torchvision.transforms.Compose([
+        torchvision.transforms.Resize((256, 256)),
+        torchvision.transforms.CenterCrop(224),
+        torchvision.transforms.ToTensor(),
+        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+    origin_dataset = dataset.LMDBDataset(lmdb_path, transform=dt)
+    N = 5
+    K = 5
+    episodic_dataset = dataset.EpisodicDataset(
+        origin_dataset,  # dataset to sample episodes from
+        N,  # N
+        K,  # K
+        100  # number of episodes
+    )
+    print(episodic_dataset)
+
+    data_loader = DataLoader(episodic_dataset, batch_size=20, pin_memory=False)
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    submit = import_module(f"submit.{import_path}")
+
+    extractor = submit.make_model()
+    extractor.to(device)
+
+    accs = []
+
+    with torch.no_grad():
+        for item in tqdm(data_loader):
+            item = convert_tensor(item, device, non_blocking=True)
+            # item["query"]: B x NK x 3 x W x H
+            # item["support"]: B x NK x 3 x W x H
+            # item["target"]: B x NK
+            batch_size = item["target"].size(0)
+            query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1)
+            support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
+            accs.append(evaluate(query_batch, item["target"], support_batch))
+    print(torch.tensor(accs).mean().item())
+
+
+if __name__ == '__main__':
+    setup_seed(100)
+    defined_path = [
+        "/data/few-shot/lmdb/dogs/data.lmdb",
+        "/data/few-shot/lmdb/flowers/data.lmdb",
+        "/data/few-shot/lmdb/256-object/data.lmdb",
+        "/data/few-shot/lmdb/dtd/data.lmdb",
+    ]
+    parser = argparse.ArgumentParser(description="test")
+    parser.add_argument('-i', "--import_path", required=True)
+    args = parser.parse_args()
+    for path in defined_path:
+        print(path)
+        test(path, args.import_path)
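
Note: beyond normalizing line endings, this patch changes the LMDB layout. Values are now the raw encoded image bytes returned by content_loader, decoded with PIL at read time, and the torchvision transform moves from lmdbify.py (database-build time) to test.py (read time). Below is a minimal round-trip sketch of that layout; it is not part of the patch, and the path, key, and map_size are placeholder values:

    import lmdb
    import pickle
    from io import BytesIO
    from PIL import Image

    # Write side: store the encoded file bytes directly, as lmdbify.py now does.
    env = lmdb.open("/tmp/demo.lmdb", map_size=1 << 30, subdir=False)
    with env.begin(write=True) as txn:
        with open("example.jpg", "rb") as f:  # placeholder image path
            txn.put(b"0", f.read())
        txn.put(b"__len__", pickle.dumps(1))

    # Read side: decode on access, as LMDBDataset.__getitem__ now does.
    with env.begin(write=False) as txn:
        sample = Image.open(BytesIO(txn.get(b"0")))
    if sample.mode != "RGB":
        sample = sample.convert("RGB")
    print(sample.size)

Assuming an ImageFolder-style directory, the intended pipeline would be `python -m data.lmdbify --save out.lmdb --dataset <image-folder>` followed by `python test.py -i <module>` (test.py evaluates the LMDB paths hard-coded in defined_path), where submit/<module>.py is expected to provide make_model().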