Compare commits

...

1 Commit

Author SHA1 Message Date
57ad9a2572 test v2 2020-07-29 00:03:16 +08:00
3 changed files with 66 additions and 62 deletions


@@ -1,10 +1,9 @@
from scipy.io import loadmat
import torch
import torchvision
import lmdb
import os
import pickle
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets import ImageFolder
@@ -44,7 +43,7 @@ class ImprovedImageFolder(ImageFolder):
assert len(self.classes_list) == len(self.classes)
def __getitem__(self, item):
return super().__getitem__(item)[0]
return super().__getitem__(item)
class LMDBDataset(Dataset):
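
For reference, a minimal sketch of what ImprovedImageFolder could look like: an ImageFolder subclass that keeps, per class, the list of sample indices, and (after the change above) returns the full (image, label) pair. Only the assert and __getitem__ are visible in this diff, so the constructor below is an assumption:

from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import default_loader

class ImprovedImageFolderSketch(ImageFolder):
    # Hypothetical reconstruction; the real __init__ is not shown in this diff.
    def __init__(self, root, loader=default_loader):
        super().__init__(root, loader=loader)
        # self.samples is ImageFolder's list of (path, class_index) pairs
        self.classes_list = [[] for _ in self.classes]
        for idx, (_, target) in enumerate(self.samples):
            self.classes_list[target].append(idx)
        assert len(self.classes_list) == len(self.classes)

    def __getitem__(self, item):
        # return the full (sample, target) pair so labels survive the LMDB export
        return super().__getitem__(item)
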
@@ -61,16 +60,10 @@ class LMDBDataset(Dataset):
def __getitem__(self, i):
with self.db.begin(write=False) as txn:
sample = Image.open(BytesIO(txn.get("{}".format(i).encode())))
if sample.mode != "RGB":
sample = sample.convert("RGB")
sample, target = pickle.loads(txn.get("{}".format(i).encode()))
if self.transform is not None:
try:
sample = self.transform(sample)
except RuntimeError as re:
print(sample.format, sample.size, sample.mode)
raise re
return sample
sample = self.transform(sample)
return sample, target
class EpisodicDataset(Dataset):
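
A minimal sketch of the read path implied by the new __getitem__: every record under the key str(i) is assumed to be a pickled (PIL image, label) pair, with classes_list and __len__ stored under their own keys, matching the conversion script further down. The __init__ here (read-only environment, metadata loading) is an assumption, since only __getitem__ appears in the diff:

import os
import pickle
import lmdb
from torch.utils.data import Dataset

class LMDBDatasetSketch(Dataset):
    def __init__(self, lmdb_path, transform=None):
        # read-only environment; subdir mirrors how the writer opens the path
        self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path),
                            readonly=True, lock=False)
        self.transform = transform
        with self.db.begin(write=False) as txn:
            self.classes_list = pickle.loads(txn.get(b"classes_list"))
            self.length = pickle.loads(txn.get(b"__len__"))

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        with self.db.begin(write=False) as txn:
            # each record is a pickled (PIL image, integer label) pair
            sample, target = pickle.loads(txn.get("{}".format(i).encode()))
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target
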
@@ -81,6 +74,20 @@ class EpisodicDataset(Dataset):
self.num_set = num_set # K
self.num_episodes = num_episodes
self.t0 = torchvision.transforms.Compose([
# torchvision.transforms.Resize((224, 224)),
torchvision.transforms.CenterCrop(224),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def _fetch_list_data(self, id_list):
result = []
for i in id_list:
img = self.origin[i][0]
result.extend([self.t0(img)])
return result
def __len__(self):
return self.num_episodes
@@ -89,15 +96,20 @@ class EpisodicDataset(Dataset):
support_set_list = []
query_set_list = []
target_list = []
for i, c in enumerate(random_classes):
for tag, c in enumerate(random_classes):
image_list = self.origin.classes_list[c]
if len(image_list) > self.num_set * 2:
if len(image_list) >= self.num_set * 2:
# enough images belong to this class
idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist()
else:
idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist()
support_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[:self.num_set]])
query_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[self.num_set:]])
target_list.extend([i] * self.num_set)
support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_set]))
query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_set:]))
support_set_list.extend(support)
query_set_list.extend(query)
target_list.extend([tag] * self.num_set)
return {
"support": torch.stack(support_set_list),
"query": torch.stack(query_set_list),

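Putting the EpisodicDataset changes together, a sketch of how one N-way K-shot episode is built: draw N classes, draw 2K indices per class (with replacement when a class has fewer than 2K images), decode the stored 256x256 images through the fixed t0 pipeline, and stack support/query tensors. How random_classes is drawn is not visible in this diff, so the randperm below is an assumption:

import torch
import torchvision

# Fixed per-image pipeline from the diff: images were already resized to
# 256x256 when the LMDB file was built, so only a center crop remains here.
t0 = torchvision.transforms.Compose([
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
])

def sample_episode(origin, num_classes, num_set):
    # origin: dataset returning (PIL image, label) pairs and exposing
    # classes_list, the per-class lists of sample indices
    random_classes = torch.randperm(len(origin.classes_list))[:num_classes].tolist()
    support, query, target = [], [], []
    for tag, c in enumerate(random_classes):
        image_list = origin.classes_list[c]
        if len(image_list) >= num_set * 2:
            # enough distinct images: draw 2K of them without replacement
            idx_list = torch.randperm(len(image_list))[:num_set * 2].tolist()
        else:
            # otherwise fall back to sampling with replacement
            idx_list = torch.randint(high=len(image_list), size=(num_set * 2,)).tolist()
        support.extend(t0(origin[image_list[idx]][0]) for idx in idx_list[:num_set])
        query.extend(t0(origin[image_list[idx]][0]) for idx in idx_list[num_set:])
        target.extend([tag] * num_set)
    return {
        "support": torch.stack(support),   # (N*K) x 3 x 224 x 224
        "query": torch.stack(query),       # (N*K) x 3 x 224 x 224
        "target": torch.tensor(target),    # (N*K)
    }
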

@@ -1,38 +1,34 @@
import os
import pickle
from io import BytesIO
import argparse
from PIL import Image
import lmdb
from data.dataset import CARS, ImprovedImageFolder
from data.dataset import ImprovedImageFolder
from tqdm import tqdm
import fire
def content_loader(path):
with open(path, "rb") as f:
return f.read()
im = Image.open(path)
im = im.resize((256, 256))
if im.mode != "RGB":
im = im.convert("RGB")
return im
def dataset_to_lmdb(dataset, lmdb_path):
env = lmdb.open(lmdb_path, map_size=1099511627776*2, subdir=os.path.isdir(lmdb_path))
with env.begin(write=True) as txn:
for i in tqdm(range(len(dataset)), ncols=50):
txn.put("{}".format(i).encode(), bytearray(dataset[i]))
txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
txn.put(b"classes_list", pickle.dumps(dataset.classes_list))
txn.put(b"__len__", pickle.dumps(len(dataset)))
def transform(save_path, dataset_path):
print(save_path, dataset_path)
# origin_dataset = CARS("/data/few-shot/STANFORD-CARS/", loader=content_loader)
origin_dataset = ImprovedImageFolder(dataset_path, loader=content_loader)
dataset_to_lmdb(origin_dataset, save_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="transform dataset to lmdb database")
parser.add_argument('--save', required=True)
parser.add_argument('--dataset', required=True)
args = parser.parse_args()
transform(args.save, args.dataset)
fire.Fire(transform)

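Switching the entry point from argparse to fire.Fire(transform) means the CLI is generated from transform's signature, so both positional and flag-style invocations work. A minimal sketch of the pattern; the script name is not shown in this diff, so to_lmdb.py is a placeholder:

# to_lmdb.py (placeholder name): minimal python-fire entry point.
# Both of the following would call transform(save_path, dataset_path):
#   python to_lmdb.py /data/out.lmdb /data/images
#   python to_lmdb.py --save_path=/data/out.lmdb --dataset_path=/data/images
import fire

def transform(save_path, dataset_path):
    print(save_path, dataset_path)
    # build ImprovedImageFolder with content_loader and call dataset_to_lmdb here

if __name__ == '__main__':
    fire.Fire(transform)
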
test.py

@@ -48,57 +48,53 @@ def evaluate(query, target, support):
def test(lmdb_path, import_path):
dt = torchvision.transforms.Compose([
torchvision.transforms.Resize((256, 256)),
torchvision.transforms.CenterCrop(224),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
origin_dataset = dataset.LMDBDataset(lmdb_path, transform=dt)
N = 5
K = 5
episodic_dataset = dataset.EpisodicDataset(
origin_dataset, # source dataset to sample episodes from
N, # N
K, # K
100 # number of episodes
)
print(episodic_dataset)
origin_dataset = dataset.LMDBDataset(lmdb_path, transform=None)
data_loader = DataLoader(episodic_dataset, batch_size=20, pin_memory=False)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
submit = import_module(f"submit.{import_path}")
extractor = submit.make_model()
extractor.to(device)
accs = []
batch_size = 10
N = 5
K = 5
episodic_dataset = dataset.EpisodicDataset(origin_dataset, N, K, 100)
data_loader = DataLoader(episodic_dataset, batch_size=batch_size, pin_memory=False)
with torch.no_grad():
accs = []
for item in tqdm(data_loader):
item = convert_tensor(item, device, non_blocking=True)
# item["query"]: B x NK x 3 x W x H
# item["support"]: B x NK x 3 x W x H
# item["query"]: B x NKA x 3 x W x H
# item["support"]: B x NKA x 3 x W x H
# item["target"]: B x NK
batch_size = item["target"].size(0)
query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1)
support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
A = item["query"].size(1) // item["target"].size(1)
image_size = item["query"].shape[-3:]
query_batch = extractor(item["query"].view([-1, *image_size])).view(batch_size, N * K, A, -1)
support_batch = extractor(item["support"].view([-1, *image_size])).view(batch_size, N, K, A, -1)
query_batch = torch.mean(query_batch, dim=-2)
support_batch = torch.mean(support_batch, dim=-2)
assert query_batch.shape[:2] == item["target"].shape[:2]
accs.append(evaluate(query_batch, item["target"], support_batch))
print(torch.tensor(accs).mean().item())
r = torch.tensor(accs).mean().item()
print(lmdb_path, r)
return r
if __name__ == '__main__':
setup_seed(100)
defined_path = [
"/data/few-shot/lmdb/dogs/data.lmdb",
"/data/few-shot/lmdb/flowers/data.lmdb",
"/data/few-shot/lmdb/256-object/data.lmdb",
"/data/few-shot/lmdb/dtd/data.lmdb",
]
"/data/few-shot/lmdb256/dogs.lmdb",
"/data/few-shot/lmdb256/flowers.lmdb",
"/data/few-shot/lmdb256/256-object.lmdb",
"/data/few-shot/lmdb256/dtd.lmdb",
"/data/few-shot/lmdb256/cars_train.lmdb",
"/data/few-shot/lmdb256/cub.lmdb",
]
parser = argparse.ArgumentParser(description="test")
parser.add_argument('-i', "--import_path", required=True)
args = parser.parse_args()
for path in defined_path:
print(path)
test(path, args.import_path)
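
In the new test loop the flattened features carry an extra augmentation axis A = item["query"].size(1) // item["target"].size(1), which is averaged away before scoring. evaluate(query, target, support) itself lies outside this diff; the sketch below pairs the averaging step with a common prototype-style classifier (class prototypes are the mean of the support embeddings, queries go to the nearest prototype) purely as an illustration, not as the repository's actual metric:

import torch

def average_over_augmentations(features, batch_size, n, k, a):
    # features: (B*N*K*A) x D straight out of the extractor, as in test()
    return features.view(batch_size, n * k, a, -1).mean(dim=-2)   # B x (N*K) x D

def prototype_accuracy(query, target, support):
    # query:   B x (N*K) x D   averaged query embeddings
    # target:  B x (N*K)       class tags in [0, N)
    # support: B x N x K x D   averaged support embeddings
    prototypes = support.mean(dim=2)           # B x N x D
    dists = torch.cdist(query, prototypes)     # B x (N*K) x N
    pred = dists.argmin(dim=-1)                # B x (N*K)
    return (pred == target).float().mean().item()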