This commit is contained in:
Ray Wong 2020-07-29 00:03:16 +08:00
parent 3a72dcb5f0
commit 57ad9a2572
3 changed files with 66 additions and 62 deletions

View File

@ -1,10 +1,9 @@
from scipy.io import loadmat from scipy.io import loadmat
import torch import torch
import torchvision
import lmdb import lmdb
import os import os
import pickle import pickle
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader from torchvision.datasets.folder import default_loader
from torchvision.datasets import ImageFolder from torchvision.datasets import ImageFolder
@ -44,7 +43,7 @@ class ImprovedImageFolder(ImageFolder):
assert len(self.classes_list) == len(self.classes) assert len(self.classes_list) == len(self.classes)
def __getitem__(self, item): def __getitem__(self, item):
return super().__getitem__(item)[0] return super().__getitem__(item)
class LMDBDataset(Dataset): class LMDBDataset(Dataset):
@ -61,16 +60,10 @@ class LMDBDataset(Dataset):
def __getitem__(self, i): def __getitem__(self, i):
with self.db.begin(write=False) as txn: with self.db.begin(write=False) as txn:
sample = Image.open(BytesIO(txn.get("{}".format(i).encode()))) sample, target = pickle.loads(txn.get("{}".format(i).encode()))
if sample.mode != "RGB":
sample = sample.convert("RGB")
if self.transform is not None: if self.transform is not None:
try:
sample = self.transform(sample) sample = self.transform(sample)
except RuntimeError as re: return sample, target
print(sample.format, sample.size, sample.mode)
raise re
return sample
class EpisodicDataset(Dataset): class EpisodicDataset(Dataset):
@ -81,6 +74,20 @@ class EpisodicDataset(Dataset):
self.num_set = num_set # K self.num_set = num_set # K
self.num_episodes = num_episodes self.num_episodes = num_episodes
self.t0 = torchvision.transforms.Compose([
# torchvision.transforms.Resize((224, 224)),
torchvision.transforms.CenterCrop(224),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def _fetch_list_data(self, id_list):
result = []
for i in id_list:
img = self.origin[i][0]
result.extend([self.t0(img)])
return result
def __len__(self): def __len__(self):
return self.num_episodes return self.num_episodes
@ -89,15 +96,20 @@ class EpisodicDataset(Dataset):
support_set_list = [] support_set_list = []
query_set_list = [] query_set_list = []
target_list = [] target_list = []
for i, c in enumerate(random_classes): for tag, c in enumerate(random_classes):
image_list = self.origin.classes_list[c] image_list = self.origin.classes_list[c]
if len(image_list) > self.num_set * 2:
if len(image_list) >= self.num_set * 2:
# have enough images belong to this class
idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist() idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist()
else: else:
idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist() idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist()
support_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[:self.num_set]])
query_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[self.num_set:]]) support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_set]))
target_list.extend([i] * self.num_set) query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_set:]))
support_set_list.extend(support)
query_set_list.extend(query)
target_list.extend([tag] * self.num_set)
return { return {
"support": torch.stack(support_set_list), "support": torch.stack(support_set_list),
"query": torch.stack(query_set_list), "query": torch.stack(query_set_list),

View File

@ -1,38 +1,34 @@
import os import os
import pickle import pickle
from io import BytesIO from PIL import Image
import argparse
import lmdb import lmdb
from data.dataset import CARS, ImprovedImageFolder from data.dataset import ImprovedImageFolder
from tqdm import tqdm from tqdm import tqdm
import fire
def content_loader(path): def content_loader(path):
with open(path, "rb") as f: im = Image.open(path)
return f.read() im = im.resize((256, 256))
if im.mode != "RGB":
im = im.convert("RGB")
return im
def dataset_to_lmdb(dataset, lmdb_path): def dataset_to_lmdb(dataset, lmdb_path):
env = lmdb.open(lmdb_path, map_size=1099511627776*2, subdir=os.path.isdir(lmdb_path)) env = lmdb.open(lmdb_path, map_size=1099511627776*2, subdir=os.path.isdir(lmdb_path))
with env.begin(write=True) as txn: with env.begin(write=True) as txn:
for i in tqdm(range(len(dataset)), ncols=50): for i in tqdm(range(len(dataset)), ncols=50):
txn.put("{}".format(i).encode(), bytearray(dataset[i])) txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
txn.put(b"classes_list", pickle.dumps(dataset.classes_list)) txn.put(b"classes_list", pickle.dumps(dataset.classes_list))
txn.put(b"__len__", pickle.dumps(len(dataset))) txn.put(b"__len__", pickle.dumps(len(dataset)))
def transform(save_path, dataset_path): def transform(save_path, dataset_path):
print(save_path, dataset_path) print(save_path, dataset_path)
# origin_dataset = CARS("/data/few-shot/STANFORD-CARS/", loader=content_loader)
origin_dataset = ImprovedImageFolder(dataset_path, loader=content_loader) origin_dataset = ImprovedImageFolder(dataset_path, loader=content_loader)
dataset_to_lmdb(origin_dataset, save_path) dataset_to_lmdb(origin_dataset, save_path)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description="transform dataset to lmdb database") fire.Fire(transform)
parser.add_argument('--save', required=True)
parser.add_argument('--dataset', required=True)
args = parser.parse_args()
transform(args.save, args.dataset)

56
test.py
View File

@ -48,57 +48,53 @@ def evaluate(query, target, support):
def test(lmdb_path, import_path): def test(lmdb_path, import_path):
dt = torchvision.transforms.Compose([ origin_dataset = dataset.LMDBDataset(lmdb_path, transform=None)
torchvision.transforms.Resize((256, 256)),
torchvision.transforms.CenterCrop(224),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
origin_dataset = dataset.LMDBDataset(lmdb_path, transform=dt)
N = 5
K = 5
episodic_dataset = dataset.EpisodicDataset(
origin_dataset, # 抽取数据集
N, # N
K, # K
100 # 任务数目
)
print(episodic_dataset)
data_loader = DataLoader(episodic_dataset, batch_size=20, pin_memory=False) device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
submit = import_module(f"submit.{import_path}") submit = import_module(f"submit.{import_path}")
extractor = submit.make_model() extractor = submit.make_model()
extractor.to(device) extractor.to(device)
accs = [] batch_size = 10
N = 5
K = 5
episodic_dataset = dataset.EpisodicDataset(origin_dataset, N, K, 100)
data_loader = DataLoader(episodic_dataset, batch_size=batch_size, pin_memory=False)
with torch.no_grad(): with torch.no_grad():
accs = []
for item in tqdm(data_loader): for item in tqdm(data_loader):
item = convert_tensor(item, device, non_blocking=True) item = convert_tensor(item, device, non_blocking=True)
# item["query"]: B x NK x 3 x W x H # item["query"]: B x NKA x 3 x W x H
# item["support"]: B x NK x 3 x W x H # item["support"]: B x NKA x 3 x W x H
# item["target"]: B x NK # item["target"]: B x NK
batch_size = item["target"].size(0) A = item["query"].size(1) // item["target"].size(1)
query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1) image_size = item["query"].shape[-3:]
support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1) query_batch = extractor(item["query"].view([-1, *image_size])).view(batch_size, N * K, A, -1)
support_batch = extractor(item["support"].view([-1, *image_size])).view(batch_size, N, K, A, -1)
query_batch = torch.mean(query_batch, dim=-2)
support_batch = torch.mean(support_batch, dim=-2)
assert query_batch.shape[:2] == item["target"].shape[:2]
accs.append(evaluate(query_batch, item["target"], support_batch)) accs.append(evaluate(query_batch, item["target"], support_batch))
print(torch.tensor(accs).mean().item()) r = torch.tensor(accs).mean().item()
print(lmdb_path, r)
return r
if __name__ == '__main__': if __name__ == '__main__':
setup_seed(100) setup_seed(100)
defined_path = [ defined_path = [
"/data/few-shot/lmdb/dogs/data.lmdb", "/data/few-shot/lmdb256/dogs.lmdb",
"/data/few-shot/lmdb/flowers/data.lmdb", "/data/few-shot/lmdb256/flowers.lmdb",
"/data/few-shot/lmdb/256-object/data.lmdb", "/data/few-shot/lmdb256/256-object.lmdb",
"/data/few-shot/lmdb/dtd/data.lmdb", "/data/few-shot/lmdb256/dtd.lmdb",
"/data/few-shot/lmdb256/cars_train.lmdb",
"/data/few-shot/lmdb256/cub.lmdb",
] ]
parser = argparse.ArgumentParser(description="test") parser = argparse.ArgumentParser(description="test")
parser.add_argument('-i', "--import_path", required=True) parser.add_argument('-i', "--import_path", required=True)
args = parser.parse_args() args = parser.parse_args()
for path in defined_path: for path in defined_path:
print(path)
test(path, args.import_path) test(path, args.import_path)