From ead93c1b0ec6b53e0179a8f6fbbb6cf64401710d Mon Sep 17 00:00:00 2001
From: Ray Wong
Date: Thu, 23 Jul 2020 22:32:28 +0800
Subject: [PATCH] test

---
 data/dataset.py      | 45 ++++++++++++++++++++++++++++++--------------
 data/lmdbify.py      | 15 ++++++++-------
 loss/__init__.py     |  0
 loss/prototypical.py | 19 +++++++++++++++++++
 test.py              | 33 +++++++++++++++++++-------------
 5 files changed, 78 insertions(+), 34 deletions(-)
 create mode 100755 loss/__init__.py
 create mode 100755 loss/prototypical.py

diff --git a/data/dataset.py b/data/dataset.py
index f08608b..b4236df 100755
--- a/data/dataset.py
+++ b/data/dataset.py
@@ -8,11 +8,13 @@ from io import BytesIO
 from torch.utils.data import Dataset
 from torchvision.datasets.folder import default_loader
 from torchvision.datasets import ImageFolder
+from torchvision import transforms
+from torchvision.transforms import functional
 from pathlib import Path
 from collections import defaultdict
 
 
-class CARS(Dataset):
+class _CARS(Dataset):
     def __init__(self, root, loader=default_loader, transform=None):
         self.root = Path(root)
         self.transform = transform
@@ -32,7 +34,7 @@ class CARS(Dataset):
         sample = self.loader(self.root / "cars_train" / file_name)
         if self.transform is not None:
             sample = self.transform(sample)
-        return sample
+        return sample, target
 
 
 class ImprovedImageFolder(ImageFolder):
@@ -44,7 +46,7 @@ class ImprovedImageFolder(ImageFolder):
         assert len(self.classes_list) == len(self.classes)
 
     def __getitem__(self, item):
-        return super().__getitem__(item)[0]
+        return super().__getitem__(item)
 
 
 class LMDBDataset(Dataset):
@@ -61,16 +63,10 @@ class LMDBDataset(Dataset):
 
     def __getitem__(self, i):
         with self.db.begin(write=False) as txn:
-            sample = Image.open(BytesIO(txn.get("{}".format(i).encode())))
-            if sample.mode != "RGB":
-                sample = sample.convert("RGB")
+            sample, target = pickle.loads(txn.get("{}".format(i).encode()))
         if self.transform is not None:
-            try:
-                sample = self.transform(sample)
-            except RuntimeError as re:
-                print(sample.format, sample.size, sample.mode)
-                raise re
-        return sample
+            sample = self.transform(sample)
+        return sample, target
 
 
 class EpisodicDataset(Dataset):
@@ -81,6 +77,24 @@ class EpisodicDataset(Dataset):
         self.num_set = num_set # K
         self.num_episodes = num_episodes
 
+        self.t0 = transforms.Compose([
+            transforms.Resize((224, 224)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+        self.transform = transforms.Compose([
+            transforms.Resize((224, 224)),
+            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+
+    def apply_transform(self, img):
+        # img1 = self.transform(img)
+        # img2 = self.transform(img)
+        # return [self.t0(img), self.t0(functional.hflip(img))]
+        return [self.t0(img)]
+
     def __len__(self):
         return self.num_episodes
 
@@ -95,8 +109,11 @@ class EpisodicDataset(Dataset):
                 idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist()
             else:
                 idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist()
-            support_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[:self.num_set]])
-            query_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[self.num_set:]])
+            support = [self.origin[image_list[idx]][0] for idx in idx_list[:self.num_set]]
+            query = [self.origin[image_list[idx]][0] for idx in idx_list[self.num_set:]]
+
+            support_set_list.extend(sum(map(self.apply_transform, support), list()))
+            query_set_list.extend(sum(map(self.apply_transform, query), list()))
             target_list.extend([i] * self.num_set)
         return {
             "support": torch.stack(support_set_list),
diff --git a/data/lmdbify.py b/data/lmdbify.py
index 97f1c9e..1d3b77f 100755
--- a/data/lmdbify.py
+++ b/data/lmdbify.py
@@ -1,30 +1,31 @@
 import os
 import pickle
-from io import BytesIO
 import argparse
-
+from PIL import Image
 import lmdb
 
-from data.dataset import CARS, ImprovedImageFolder
+from data.dataset import ImprovedImageFolder
 from tqdm import tqdm
 
 
 def content_loader(path):
-    with open(path, "rb") as f:
-        return f.read()
+    im = Image.open(path)
+    im = im.resize((256, 256))
+    if im.mode != "RGB":
+        im = im.convert("RGB")
+    return im
 
 
 def dataset_to_lmdb(dataset, lmdb_path):
     env = lmdb.open(lmdb_path, map_size=1099511627776*2, subdir=os.path.isdir(lmdb_path))
     with env.begin(write=True) as txn:
         for i in tqdm(range(len(dataset)), ncols=50):
-            txn.put("{}".format(i).encode(), bytearray(dataset[i]))
+            txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
         txn.put(b"classes_list", pickle.dumps(dataset.classes_list))
         txn.put(b"__len__", pickle.dumps(len(dataset)))
 
 
 def transform(save_path, dataset_path):
     print(save_path, dataset_path)
-    # origin_dataset = CARS("/data/few-shot/STANFORD-CARS/", loader=content_loader)
     origin_dataset = ImprovedImageFolder(dataset_path, loader=content_loader)
     dataset_to_lmdb(origin_dataset, save_path)
diff --git a/loss/__init__.py b/loss/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/loss/prototypical.py b/loss/prototypical.py
new file mode 100755
index 0000000..223c78a
--- /dev/null
+++ b/loss/prototypical.py
@@ -0,0 +1,19 @@
+import torch
+
+
+def euclidean_dist(x, y):
+    """
+    Compute squared Euclidean distance between two batched tensors
+    """
+    # x: B x N x D
+    # y: B x M x D
+    n = x.size(-2)
+    m = y.size(-2)
+    d = x.size(-1)
+    if d != y.size(-1):
+        raise ValueError("x and y must have the same feature dimension")
+
+    x = x.unsqueeze(2).expand(x.size(0), n, m, d)  # B x N x M x D
+    y = y.unsqueeze(1).expand(x.size(0), n, m, d)
+
+    return torch.pow(x - y, 2).sum(-1)
\ No newline at end of file
diff --git a/test.py b/test.py
index bd468ad..32df2b2 100755
--- a/test.py
+++ b/test.py
@@ -49,12 +49,11 @@ def evaluate(query, target, support):
 
 def test(lmdb_path, import_path):
     dt = torchvision.transforms.Compose([
-        torchvision.transforms.Resize((256, 256)),
         torchvision.transforms.CenterCrop(224),
         torchvision.transforms.ToTensor(),
         torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     ])
-    origin_dataset = dataset.LMDBDataset(lmdb_path, transform=dt)
+    origin_dataset = dataset.LMDBDataset(lmdb_path, transform=None)
     N = 5
     K = 5
     episodic_dataset = dataset.EpisodicDataset(
@@ -65,8 +64,8 @@ def test(lmdb_path, import_path):
     )
     print(episodic_dataset)
 
-    data_loader = DataLoader(episodic_dataset, batch_size=20, pin_memory=False)
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    data_loader = DataLoader(episodic_dataset, batch_size=8, pin_memory=False)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     submit = import_module(f"submit.{import_path}")
 
@@ -78,12 +77,19 @@ def test(lmdb_path, import_path):
     with torch.no_grad():
         for item in tqdm(data_loader):
             item = convert_tensor(item, device, non_blocking=True)
-            # item["query"]: B x NK x 3 x W x H
-            # item["support"]: B x NK x 3 x W x H
+            # item["query"]: B x ANK x 3 x W x H
+            # item["support"]: B x ANK x 3 x W x H
             # item["target"]: B x NK
             batch_size = item["target"].size(0)
-            query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1)
*item["query"].shape[-3:]])).view(batch_size, N * K, -1) - support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1) + image_size = item["query"].shape[-3:] + A = int(item["query"].size(1) / (N * K)) + + query_batch = extractor(item["query"].view([-1, *image_size])).view(batch_size, N * K, A, -1) + support_batch = extractor(item["support"].view([-1, *image_size])).view(batch_size, N, K, A, -1) + + query_batch = torch.mean(query_batch, -2) + support_batch = torch.mean(support_batch, -2) + accs.append(evaluate(query_batch, item["target"], support_batch)) print(torch.tensor(accs).mean().item()) @@ -91,11 +97,12 @@ def test(lmdb_path, import_path): if __name__ == '__main__': setup_seed(100) defined_path = [ - "/data/few-shot/lmdb/dogs/data.lmdb", - "/data/few-shot/lmdb/flowers/data.lmdb", - "/data/few-shot/lmdb/256-object/data.lmdb", - "/data/few-shot/lmdb/dtd/data.lmdb", - ] + "/data/few-shot/lmdb256/dogs.lmdb", + "/data/few-shot/lmdb256/flowers.lmdb", + "/data/few-shot/lmdb256/256-object.lmdb", + "/data/few-shot/lmdb256/dtd.lmdb", + "/data/few-shot/lmdb256/mini-imagenet-test.lmdb" + ] parser = argparse.ArgumentParser(description="test") parser.add_argument('-i', "--import_path", required=True) args = parser.parse_args()