diff --git a/data/dataset.py b/data/dataset.py
index f08608b..bb5adaa 100755
--- a/data/dataset.py
+++ b/data/dataset.py
@@ -1,10 +1,9 @@
 from scipy.io import loadmat
 import torch
+import torchvision
 import lmdb
 import os
 import pickle
-from PIL import Image
-from io import BytesIO
 from torch.utils.data import Dataset
 from torchvision.datasets.folder import default_loader
 from torchvision.datasets import ImageFolder
@@ -44,7 +43,7 @@ class ImprovedImageFolder(ImageFolder):
         assert len(self.classes_list) == len(self.classes)
 
     def __getitem__(self, item):
-        return super().__getitem__(item)[0]
+        return super().__getitem__(item)
 
 
 class LMDBDataset(Dataset):
@@ -61,16 +60,10 @@ class LMDBDataset(Dataset):
 
     def __getitem__(self, i):
         with self.db.begin(write=False) as txn:
-            sample = Image.open(BytesIO(txn.get("{}".format(i).encode())))
-            if sample.mode != "RGB":
-                sample = sample.convert("RGB")
+            sample, target = pickle.loads(txn.get("{}".format(i).encode()))
         if self.transform is not None:
-            try:
-                sample = self.transform(sample)
-            except RuntimeError as re:
-                print(sample.format, sample.size, sample.mode)
-                raise re
-        return sample
+            sample = self.transform(sample)
+        return sample, target
 
 
 class EpisodicDataset(Dataset):
@@ -81,6 +74,20 @@ class EpisodicDataset(Dataset):
         self.num_set = num_set  # K
         self.num_episodes = num_episodes
 
+        self.t0 = torchvision.transforms.Compose([
+            # torchvision.transforms.Resize((224, 224)),
+            torchvision.transforms.CenterCrop(224),
+            torchvision.transforms.ToTensor(),
+            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+
+    def _fetch_list_data(self, id_list):
+        result = []
+        for i in id_list:
+            img = self.origin[i][0]
+            result.extend([self.t0(img)])
+        return result
+
     def __len__(self):
         return self.num_episodes
 
@@ -89,15 +96,20 @@
         support_set_list = []
         query_set_list = []
         target_list = []
-        for i, c in enumerate(random_classes):
+        for tag, c in enumerate(random_classes):
             image_list = self.origin.classes_list[c]
-            if len(image_list) > self.num_set * 2:
+
+            if len(image_list) >= self.num_set * 2:
+                # enough images in this class: sample without replacement
                 idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist()
             else:
                 idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist()
-            support_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[:self.num_set]])
-            query_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[self.num_set:]])
-            target_list.extend([i] * self.num_set)
+
+            support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_set]))
+            query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_set:]))
+            support_set_list.extend(support)
+            query_set_list.extend(query)
+            target_list.extend([tag] * self.num_set)
         return {
             "support": torch.stack(support_set_list),
             "query": torch.stack(query_set_list),
diff --git a/data/lmdbify.py b/data/lmdbify.py
index 97f1c9e..5e0972a 100755
--- a/data/lmdbify.py
+++ b/data/lmdbify.py
@@ -1,38 +1,34 @@
 import os
 import pickle
-from io import BytesIO
-import argparse
-
+from PIL import Image
 import lmdb
-from data.dataset import CARS, ImprovedImageFolder
+from data.dataset import ImprovedImageFolder
 from tqdm import tqdm
+import fire
 
 
 def content_loader(path):
-    with open(path, "rb") as f:
-        return f.read()
+    im = Image.open(path)
+    im = im.resize((256, 256))
+    if im.mode != "RGB":
+        im = im.convert("RGB")
+    return im
 
 
 def dataset_to_lmdb(dataset, lmdb_path):
     env = lmdb.open(lmdb_path, map_size=1099511627776*2, subdir=os.path.isdir(lmdb_path))
     with env.begin(write=True) as txn:
         for i in tqdm(range(len(dataset)), ncols=50):
-            txn.put("{}".format(i).encode(), bytearray(dataset[i]))
+            txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
         txn.put(b"classes_list", pickle.dumps(dataset.classes_list))
         txn.put(b"__len__", pickle.dumps(len(dataset)))
 
 
 def transform(save_path, dataset_path):
     print(save_path, dataset_path)
-    # origin_dataset = CARS("/data/few-shot/STANFORD-CARS/", loader=content_loader)
     origin_dataset = ImprovedImageFolder(dataset_path, loader=content_loader)
     dataset_to_lmdb(origin_dataset, save_path)
 
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description="transform dataset to lmdb database")
-    parser.add_argument('--save', required=True)
-    parser.add_argument('--dataset', required=True)
-    args = parser.parse_args()
-    transform(args.save, args.dataset)
-
+    fire.Fire(transform)
diff --git a/test.py b/test.py
index bd468ad..6036bfc 100755
--- a/test.py
+++ b/test.py
@@ -48,57 +48,53 @@ def evaluate(query, target, support):
 
 
 def test(lmdb_path, import_path):
-    dt = torchvision.transforms.Compose([
-        torchvision.transforms.Resize((256, 256)),
-        torchvision.transforms.CenterCrop(224),
-        torchvision.transforms.ToTensor(),
-        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-    ])
-    origin_dataset = dataset.LMDBDataset(lmdb_path, transform=dt)
-    N = 5
-    K = 5
-    episodic_dataset = dataset.EpisodicDataset(
-        origin_dataset,  # dataset to sample episodes from
-        N,  # N
-        K,  # K
-        100  # number of episodes
-    )
-    print(episodic_dataset)
+    origin_dataset = dataset.LMDBDataset(lmdb_path, transform=None)
 
-    data_loader = DataLoader(episodic_dataset, batch_size=20, pin_memory=False)
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     submit = import_module(f"submit.{import_path}")
     extractor = submit.make_model()
     extractor.to(device)
 
-    accs = []
+    batch_size = 10
+    N = 5
+    K = 5
+    episodic_dataset = dataset.EpisodicDataset(origin_dataset, N, K, 100)
+    data_loader = DataLoader(episodic_dataset, batch_size=batch_size, pin_memory=False)
     with torch.no_grad():
+        accs = []
         for item in tqdm(data_loader):
             item = convert_tensor(item, device, non_blocking=True)
-            # item["query"]: B x NK x 3 x W x H
-            # item["support"]: B x NK x 3 x W x H
+            # item["query"]: B x NKA x 3 x W x H
+            # item["support"]: B x NKA x 3 x W x H
             # item["target"]: B x NK
-            batch_size = item["target"].size(0)
-            query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1)
-            support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
+            A = item["query"].size(1) // item["target"].size(1)
+            image_size = item["query"].shape[-3:]
+            query_batch = extractor(item["query"].view([-1, *image_size])).view(batch_size, N * K, A, -1)
+            support_batch = extractor(item["support"].view([-1, *image_size])).view(batch_size, N, K, A, -1)
+            query_batch = torch.mean(query_batch, dim=-2)
+            support_batch = torch.mean(support_batch, dim=-2)
+            assert query_batch.shape[:2] == item["target"].shape[:2]
            accs.append(evaluate(query_batch, item["target"], support_batch))
-    print(torch.tensor(accs).mean().item())
+    r = torch.tensor(accs).mean().item()
+    print(lmdb_path, r)
+    return r
 
 
 if __name__ == '__main__':
     setup_seed(100)
     defined_path = [
-        "/data/few-shot/lmdb/dogs/data.lmdb",
-        "/data/few-shot/lmdb/flowers/data.lmdb",
-        "/data/few-shot/lmdb/256-object/data.lmdb",
-        "/data/few-shot/lmdb/dtd/data.lmdb",
-    ]
+        "/data/few-shot/lmdb256/dogs.lmdb",
+        "/data/few-shot/lmdb256/flowers.lmdb",
+        "/data/few-shot/lmdb256/256-object.lmdb",
+        "/data/few-shot/lmdb256/dtd.lmdb",
+        "/data/few-shot/lmdb256/cars_train.lmdb",
+        "/data/few-shot/lmdb256/cub.lmdb",
+    ]
    parser = argparse.ArgumentParser(description="test")
    parser.add_argument('-i', "--import_path", required=True)
    args = parser.parse_args()
    for path in defined_path:
-        print(path)
        test(path, args.import_path)