Compare commits

...

2 Commits

Author SHA1 Message Date
3a72dcb5f0 change line ending 2020-07-20 11:02:39 +08:00
7d720c181b 1 2020-07-16 16:07:03 +08:00
4 changed files with 253 additions and 251 deletions

5
.gitignore vendored
View File

@ -1,2 +1,3 @@
*.pth *.pth
.idea/ .idea/
submit/

View File

@ -1,101 +1,108 @@
from scipy.io import loadmat from scipy.io import loadmat
import torch import torch
import lmdb import lmdb
import os import os
import pickle import pickle
from io import BytesIO from PIL import Image
from torch.utils.data import Dataset from io import BytesIO
from torchvision.datasets.folder import default_loader from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder from torchvision.datasets.folder import default_loader
from pathlib import Path from torchvision.datasets import ImageFolder
from collections import defaultdict from pathlib import Path
from collections import defaultdict
class CARS(Dataset):
def __init__(self, root, loader=default_loader, transform=None): class CARS(Dataset):
self.root = Path(root) def __init__(self, root, loader=default_loader, transform=None):
self.transform = transform self.root = Path(root)
self.loader = loader self.transform = transform
self.annotations = loadmat(self.root / "devkit/cars_train_annos.mat")["annotations"][0] self.loader = loader
self.annotations = {d[-1].item(): d[-2].item() - 1 for d in self.annotations} self.annotations = loadmat(self.root / "devkit/cars_train_annos.mat")["annotations"][0]
self.classes_list = defaultdict(list) self.annotations = {d[-1].item(): d[-2].item() - 1 for d in self.annotations}
for i in range(len(self.annotations)): self.classes_list = defaultdict(list)
self.classes_list[self.annotations["{:05d}.jpg".format(i + 1)]].append(i) for i in range(len(self.annotations)):
self.classes_list[self.annotations["{:05d}.jpg".format(i + 1)]].append(i)
def __len__(self):
return len(self.annotations) def __len__(self):
return len(self.annotations)
def __getitem__(self, item):
file_name = "{:05d}.jpg".format(item + 1) def __getitem__(self, item):
target = self.annotations[file_name] file_name = "{:05d}.jpg".format(item + 1)
sample = self.loader(self.root / "cars_train" / file_name) target = self.annotations[file_name]
if self.transform is not None: sample = self.loader(self.root / "cars_train" / file_name)
sample = self.transform(sample) if self.transform is not None:
return sample sample = self.transform(sample)
return sample
class ImprovedImageFolder(ImageFolder):
def __init__(self, root, loader=default_loader, transform=None): class ImprovedImageFolder(ImageFolder):
super().__init__(root, transform, loader=loader) def __init__(self, root, loader=default_loader, transform=None):
self.classes_list = defaultdict(list) super().__init__(root, transform, loader=loader)
for i in range(len(self)): self.classes_list = defaultdict(list)
self.classes_list[self.samples[i][-1]].append(i) for i in range(len(self)):
assert len(self.classes_list) == len(self.classes) self.classes_list[self.samples[i][-1]].append(i)
assert len(self.classes_list) == len(self.classes)
def __getitem__(self, item):
return super().__getitem__(item)[0] def __getitem__(self, item):
return super().__getitem__(item)[0]
class LMDBDataset(Dataset):
def __init__(self, lmdb_path, transform=None): class LMDBDataset(Dataset):
self.db = lmdb.open(lmdb_path, map_size=1099511627776, subdir=os.path.isdir(lmdb_path), readonly=True, def __init__(self, lmdb_path, transform=None):
lock=False, readahead=False, meminit=False) self.db = lmdb.open(lmdb_path, map_size=1099511627776, subdir=os.path.isdir(lmdb_path), readonly=True,
self.transform = transform lock=False, readahead=False, meminit=False)
with self.db.begin(write=False) as txn: self.transform = transform
self.classes_list = pickle.loads(txn.get(b"classes_list")) with self.db.begin(write=False) as txn:
self._len = pickle.loads(txn.get(b"__len__")) self.classes_list = pickle.loads(txn.get(b"classes_list"))
self._len = pickle.loads(txn.get(b"__len__"))
def __len__(self):
return self._len def __len__(self):
return self._len
def __getitem__(self, i):
with self.db.begin(write=False) as txn: def __getitem__(self, i):
sample = torch.load(BytesIO(txn.get("{}".format(i).encode()))) with self.db.begin(write=False) as txn:
if self.transform is not None: sample = Image.open(BytesIO(txn.get("{}".format(i).encode())))
sample = self.transform(sample) if sample.mode != "RGB":
return sample sample = sample.convert("RGB")
if self.transform is not None:
try:
class EpisodicDataset(Dataset): sample = self.transform(sample)
def __init__(self, origin_dataset, num_class, num_set, num_episodes): except RuntimeError as re:
self.origin = origin_dataset print(sample.format, sample.size, sample.mode)
self.num_class = num_class raise re
assert self.num_class < len(self.origin.classes_list) return sample
self.num_set = num_set # K
self.num_episodes = num_episodes
class EpisodicDataset(Dataset):
def __len__(self): def __init__(self, origin_dataset, num_class, num_set, num_episodes):
return self.num_episodes self.origin = origin_dataset
self.num_class = num_class
def __getitem__(self, _): assert self.num_class < len(self.origin.classes_list)
random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist() self.num_set = num_set # K
support_set_list = [] self.num_episodes = num_episodes
query_set_list = []
target_list = [] def __len__(self):
for i, c in enumerate(random_classes): return self.num_episodes
image_list = self.origin.classes_list[c]
if len(image_list) > self.num_set * 2: def __getitem__(self, _):
idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist() random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
else: support_set_list = []
idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist() query_set_list = []
support_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[:self.num_set]]) target_list = []
query_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[self.num_set:]]) for i, c in enumerate(random_classes):
target_list.extend([i] * self.num_set) image_list = self.origin.classes_list[c]
return { if len(image_list) > self.num_set * 2:
"support": torch.stack(support_set_list), idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist()
"query": torch.stack(query_set_list), else:
"target": torch.tensor(target_list) idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist()
} support_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[:self.num_set]])
query_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[self.num_set:]])
def __repr__(self): target_list.extend([i] * self.num_set)
return "<EpisodicDataset N={} K={} NUM={}>".format(self.num_class, self.num_set, self.num_episodes) return {
"support": torch.stack(support_set_list),
"query": torch.stack(query_set_list),
"target": torch.tensor(target_list)
}
def __repr__(self):
return "<EpisodicDataset N={} K={} NUM={}>".format(self.num_class, self.num_set, self.num_episodes)

View File

@ -1,43 +1,38 @@
import os import os
import pickle import pickle
from io import BytesIO from io import BytesIO
import argparse import argparse
import torch import lmdb
import lmdb from data.dataset import CARS, ImprovedImageFolder
from data.dataset import CARS, ImprovedImageFolder from tqdm import tqdm
import torchvision
from tqdm import tqdm
def content_loader(path):
with open(path, "rb") as f:
def dataset_to_lmdb(dataset, lmdb_path): return f.read()
env = lmdb.open(lmdb_path, map_size=1099511627776, subdir=os.path.isdir(lmdb_path))
with env.begin(write=True) as txn:
for i in tqdm(range(len(dataset)), ncols=50): def dataset_to_lmdb(dataset, lmdb_path):
buffer = BytesIO() env = lmdb.open(lmdb_path, map_size=1099511627776*2, subdir=os.path.isdir(lmdb_path))
torch.save(dataset[i], buffer) with env.begin(write=True) as txn:
txn.put("{}".format(i).encode(), buffer.getvalue()) for i in tqdm(range(len(dataset)), ncols=50):
txn.put(b"classes_list", pickle.dumps(dataset.classes_list)) txn.put("{}".format(i).encode(), bytearray(dataset[i]))
txn.put(b"__len__", pickle.dumps(len(dataset))) txn.put(b"classes_list", pickle.dumps(dataset.classes_list))
txn.put(b"__len__", pickle.dumps(len(dataset)))
def transform(save_path, dataset_path):
print(save_path, dataset_path) def transform(save_path, dataset_path):
dt = torchvision.transforms.Compose([ print(save_path, dataset_path)
torchvision.transforms.Resize((256, 256)), # origin_dataset = CARS("/data/few-shot/STANFORD-CARS/", loader=content_loader)
torchvision.transforms.CenterCrop(224), origin_dataset = ImprovedImageFolder(dataset_path, loader=content_loader)
torchvision.transforms.ToTensor(), dataset_to_lmdb(origin_dataset, save_path)
torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# origin_dataset = CARS("/data/few-shot/STANFORD-CARS/", transform=dt) if __name__ == '__main__':
origin_dataset = ImprovedImageFolder(dataset_path, transform=dt) parser = argparse.ArgumentParser(description="transform dataset to lmdb database")
dataset_to_lmdb(origin_dataset, save_path) parser.add_argument('--save', required=True)
parser.add_argument('--dataset', required=True)
args = parser.parse_args()
if __name__ == '__main__': transform(args.save, args.dataset)
parser = argparse.ArgumentParser(description="transform dataset to lmdb database")
parser.add_argument('--save', required=True)
parser.add_argument('--dataset', required=True)
args = parser.parse_args()
transform(args.save, args.dataset)

209
test.py
View File

@ -1,105 +1,104 @@
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import transforms import torchvision
from data import dataset from data import dataset
import argparse import argparse
from ignite.utils import convert_tensor from ignite.utils import convert_tensor
import time import time
from importlib import import_module from importlib import import_module
from tqdm import tqdm from tqdm import tqdm
def setup_seed(seed): def setup_seed(seed):
torch.manual_seed(seed) torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
def euclidean_dist(x, y): def euclidean_dist(x, y):
""" """
Compute euclidean distance between two tensors Compute euclidean distance between two tensors
""" """
# x: B x N x D # x: B x N x D
# y: B x M x D # y: B x M x D
n = x.size(-2) n = x.size(-2)
m = y.size(-2) m = y.size(-2)
d = x.size(-1) d = x.size(-1)
if d != y.size(-1): if d != y.size(-1):
raise Exception raise Exception
x = x.unsqueeze(2).expand(x.size(0), n, m, d) # B x N x M x D x = x.unsqueeze(2).expand(x.size(0), n, m, d) # B x N x M x D
y = y.unsqueeze(1).expand(x.size(0), n, m, d) y = y.unsqueeze(1).expand(x.size(0), n, m, d)
return torch.pow(x - y, 2).sum(-1) return torch.pow(x - y, 2).sum(-1)
def evaluate(query, target, support): def evaluate(query, target, support):
""" """
:param query: B x NK x D vector :param query: B x NK x D vector
:param target: B x NK vector :param target: B x NK vector
:param support: B x N x K x D vector :param support: B x N x K x D vector
:return: :return:
""" """
prototypes = support.mean(-2) # B x N x D prototypes = support.mean(-2) # B x N x D
distance = euclidean_dist(query, prototypes) # B x NK x N distance = euclidean_dist(query, prototypes) # B x NK x N
indices = distance.argmin(-1) # B x NK indices = distance.argmin(-1) # B x NK
return torch.eq(target, indices).float().mean() return torch.eq(target, indices).float().mean()
def test(lmdb_path, import_path): def test(lmdb_path, import_path):
origin_dataset = dataset.LMDBDataset(lmdb_path) dt = torchvision.transforms.Compose([
N = 5 torchvision.transforms.Resize((256, 256)),
K = 5 torchvision.transforms.CenterCrop(224),
episodic_dataset = dataset.EpisodicDataset( torchvision.transforms.ToTensor(),
origin_dataset, # 抽取数据集 torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
N, # N ])
K, # K origin_dataset = dataset.LMDBDataset(lmdb_path, transform=dt)
100 # 任务数目 N = 5
) K = 5
print(episodic_dataset) episodic_dataset = dataset.EpisodicDataset(
origin_dataset, # 抽取数据集
data_loader = DataLoader(episodic_dataset, batch_size=16, pin_memory=False) N, # N
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") K, # K
100 # 任务数目
submit = import_module(f"submit.{import_path}") )
print(episodic_dataset)
extractor = submit.make_model()
extractor.to(device) data_loader = DataLoader(episodic_dataset, batch_size=20, pin_memory=False)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
accs = []
submit = import_module(f"submit.{import_path}")
load_st = time.time()
with torch.no_grad(): extractor = submit.make_model()
for item in data_loader: extractor.to(device)
st = time.time()
print("load", time.time() - load_st) accs = []
item = convert_tensor(item, device, non_blocking=True)
# item["query"]: B x NK x 3 x W x H with torch.no_grad():
# item["support"]: B x NK x 3 x W x H for item in tqdm(data_loader):
# item["target"]: B x NK item = convert_tensor(item, device, non_blocking=True)
batch_size = item["target"].size(0) # item["query"]: B x NK x 3 x W x H
query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1) # item["support"]: B x NK x 3 x W x H
support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1) # item["target"]: B x NK
print("compute", time.time() - st) batch_size = item["target"].size(0)
load_st = time.time() query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1)
support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
accs.append(evaluate(query_batch, item["target"], support_batch)) accs.append(evaluate(query_batch, item["target"], support_batch))
print(torch.tensor(accs).mean().item()) print(torch.tensor(accs).mean().item())
print("time: ", time.time() - st)
if __name__ == '__main__':
if __name__ == '__main__': setup_seed(100)
setup_seed(100) defined_path = [
defined_path = ["/data/few-shot/lmdb/mini-imagenet/val.lmdb", "/data/few-shot/lmdb/dogs/data.lmdb",
"/data/few-shot/lmdb/CUB_200_2011/data.lmdb", "/data/few-shot/lmdb/flowers/data.lmdb",
"/data/few-shot/lmdb/STANFORD-CARS/train.lmdb", "/data/few-shot/lmdb/256-object/data.lmdb",
# "/data/few-shot/lmdb/Plantae/data.lmdb", "/data/few-shot/lmdb/dtd/data.lmdb",
# "/data/few-shot/lmdb/Places365/val.lmdb" ]
] parser = argparse.ArgumentParser(description="test")
parser = argparse.ArgumentParser(description="test") parser.add_argument('-i', "--import_path", required=True)
parser.add_argument('-i', "--import_path", required=True) args = parser.parse_args()
args = parser.parse_args() for path in defined_path:
for path in defined_path: print(path)
print(path) test(path, args.import_path)
test(path, args.import_path)