Compare commits

..

1 Commits

Author SHA1 Message Date
57ad9a2572 test v2 2020-07-29 00:03:16 +08:00
5 changed files with 42 additions and 82 deletions

View File

@ -1,20 +1,17 @@
from scipy.io import loadmat from scipy.io import loadmat
import torch import torch
import torchvision
import lmdb import lmdb
import os import os
import pickle import pickle
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader from torchvision.datasets.folder import default_loader
from torchvision.datasets import ImageFolder from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.transforms import functional
from pathlib import Path from pathlib import Path
from collections import defaultdict from collections import defaultdict
class _CARS(Dataset): class CARS(Dataset):
def __init__(self, root, loader=default_loader, transform=None): def __init__(self, root, loader=default_loader, transform=None):
self.root = Path(root) self.root = Path(root)
self.transform = transform self.transform = transform
@ -34,7 +31,7 @@ class _CARS(Dataset):
sample = self.loader(self.root / "cars_train" / file_name) sample = self.loader(self.root / "cars_train" / file_name)
if self.transform is not None: if self.transform is not None:
sample = self.transform(sample) sample = self.transform(sample)
return sample, target return sample
class ImprovedImageFolder(ImageFolder): class ImprovedImageFolder(ImageFolder):
@ -77,23 +74,19 @@ class EpisodicDataset(Dataset):
self.num_set = num_set # K self.num_set = num_set # K
self.num_episodes = num_episodes self.num_episodes = num_episodes
self.t0 = transforms.Compose([ self.t0 = torchvision.transforms.Compose([
transforms.Resize((224, 224)), # torchvision.transforms.Resize((224, 224)),
transforms.ToTensor(), torchvision.transforms.CenterCrop(224),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) torchvision.transforms.ToTensor(),
]) torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
self.transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]) ])
def apply_transform(self, img): def _fetch_list_data(self, id_list):
# img1 = self.transform(img) result = []
# img2 = self.transform(img) for i in id_list:
# return [self.t0(img), self.t0(functional.hflip(img))] img = self.origin[i][0]
return [self.t0(img)] result.extend([self.t0(img)])
return result
def __len__(self): def __len__(self):
return self.num_episodes return self.num_episodes
@ -103,18 +96,20 @@ class EpisodicDataset(Dataset):
support_set_list = [] support_set_list = []
query_set_list = [] query_set_list = []
target_list = [] target_list = []
for i, c in enumerate(random_classes): for tag, c in enumerate(random_classes):
image_list = self.origin.classes_list[c] image_list = self.origin.classes_list[c]
if len(image_list) > self.num_set * 2:
if len(image_list) >= self.num_set * 2:
# have enough images belong to this class
idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist() idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist()
else: else:
idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist() idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist()
support = [self.origin[image_list[idx]][0] for idx in idx_list[:self.num_set]]
query = [self.origin[image_list[idx]][0] for idx in idx_list[:self.num_set]]
support_set_list.extend(sum(map(self.apply_transform, support), list())) support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_set]))
query_set_list.extend(sum(map(self.apply_transform, query), list())) query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_set:]))
target_list.extend([i] * self.num_set) support_set_list.extend(support)
query_set_list.extend(query)
target_list.extend([tag] * self.num_set)
return { return {
"support": torch.stack(support_set_list), "support": torch.stack(support_set_list),
"query": torch.stack(query_set_list), "query": torch.stack(query_set_list),

View File

@ -1,10 +1,10 @@
import os import os
import pickle import pickle
import argparse
from PIL import Image from PIL import Image
import lmdb import lmdb
from data.dataset import ImprovedImageFolder from data.dataset import ImprovedImageFolder
from tqdm import tqdm from tqdm import tqdm
import fire
def content_loader(path): def content_loader(path):
@ -31,9 +31,4 @@ def transform(save_path, dataset_path):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description="transform dataset to lmdb database") fire.Fire(transform)
parser.add_argument('--save', required=True)
parser.add_argument('--dataset', required=True)
args = parser.parse_args()
transform(args.save, args.dataset)

View File

View File

@ -1,19 +0,0 @@
import torch
def euclidean_dist(x, y):
"""
Compute euclidean distance between two tensors
"""
# x: B x N x D
# y: B x M x D
n = x.size(-2)
m = y.size(-2)
d = x.size(-1)
if d != y.size(-1):
raise Exception
x = x.unsqueeze(2).expand(x.size(0), n, m, d) # B x N x M x D
y = y.unsqueeze(1).expand(x.size(0), n, m, d)
return torch.pow(x - y, 2).sum(-1)

45
test.py
View File

@ -48,23 +48,8 @@ def evaluate(query, target, support):
def test(lmdb_path, import_path): def test(lmdb_path, import_path):
dt = torchvision.transforms.Compose([
torchvision.transforms.CenterCrop(224),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
origin_dataset = dataset.LMDBDataset(lmdb_path, transform=None) origin_dataset = dataset.LMDBDataset(lmdb_path, transform=None)
N = 5
K = 5
episodic_dataset = dataset.EpisodicDataset(
origin_dataset, # 抽取数据集
N, # N
K, # K
100 # 任务数目
)
print(episodic_dataset)
data_loader = DataLoader(episodic_dataset, batch_size=8, pin_memory=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
submit = import_module(f"submit.{import_path}") submit = import_module(f"submit.{import_path}")
@ -72,26 +57,30 @@ def test(lmdb_path, import_path):
extractor = submit.make_model() extractor = submit.make_model()
extractor.to(device) extractor.to(device)
accs = [] batch_size = 10
N = 5
K = 5
episodic_dataset = dataset.EpisodicDataset(origin_dataset, N, K, 100)
data_loader = DataLoader(episodic_dataset, batch_size=batch_size, pin_memory=False)
with torch.no_grad(): with torch.no_grad():
accs = []
for item in tqdm(data_loader): for item in tqdm(data_loader):
item = convert_tensor(item, device, non_blocking=True) item = convert_tensor(item, device, non_blocking=True)
# item["query"]: B x ANK x 3 x W x H # item["query"]: B x NKA x 3 x W x H
# item["support"]: B x ANK x 3 x W x H # item["support"]: B x NKA x 3 x W x H
# item["target"]: B x NK # item["target"]: B x NK
batch_size = item["target"].size(0) A = item["query"].size(1) // item["target"].size(1)
image_size = item["query"].shape[-3:] image_size = item["query"].shape[-3:]
A = int(item["query"].size(1) / (N * K))
query_batch = extractor(item["query"].view([-1, *image_size])).view(batch_size, N * K, A, -1) query_batch = extractor(item["query"].view([-1, *image_size])).view(batch_size, N * K, A, -1)
support_batch = extractor(item["support"].view([-1, *image_size])).view(batch_size, N, K, A, -1) support_batch = extractor(item["support"].view([-1, *image_size])).view(batch_size, N, K, A, -1)
query_batch = torch.mean(query_batch, dim=-2)
query_batch = torch.mean(query_batch, -2) support_batch = torch.mean(support_batch, dim=-2)
support_batch = torch.mean(support_batch, -2) assert query_batch.shape[:2] == item["target"].shape[:2]
accs.append(evaluate(query_batch, item["target"], support_batch)) accs.append(evaluate(query_batch, item["target"], support_batch))
print(torch.tensor(accs).mean().item()) r = torch.tensor(accs).mean().item()
print(lmdb_path, r)
return r
if __name__ == '__main__': if __name__ == '__main__':
@ -101,11 +90,11 @@ if __name__ == '__main__':
"/data/few-shot/lmdb256/flowers.lmdb", "/data/few-shot/lmdb256/flowers.lmdb",
"/data/few-shot/lmdb256/256-object.lmdb", "/data/few-shot/lmdb256/256-object.lmdb",
"/data/few-shot/lmdb256/dtd.lmdb", "/data/few-shot/lmdb256/dtd.lmdb",
"/data/few-shot/lmdb256/mini-imagenet-test.lmdb" "/data/few-shot/lmdb256/cars_train.lmdb",
"/data/few-shot/lmdb256/cub.lmdb",
] ]
parser = argparse.ArgumentParser(description="test") parser = argparse.ArgumentParser(description="test")
parser.add_argument('-i', "--import_path", required=True) parser.add_argument('-i', "--import_path", required=True)
args = parser.parse_args() args = parser.parse_args()
for path in defined_path: for path in defined_path:
print(path)
test(path, args.import_path) test(path, args.import_path)