Ray Wong 2020-07-23 22:32:28 +08:00
parent 3a72dcb5f0
commit ead93c1b0e
5 changed files with 78 additions and 34 deletions

data/dataset.py

@@ -8,11 +8,13 @@ from io import BytesIO
 from torch.utils.data import Dataset
 from torchvision.datasets.folder import default_loader
 from torchvision.datasets import ImageFolder
+from torchvision import transforms
+from torchvision.transforms import functional
 from pathlib import Path
 from collections import defaultdict
 
 
-class CARS(Dataset):
+class _CARS(Dataset):
     def __init__(self, root, loader=default_loader, transform=None):
         self.root = Path(root)
         self.transform = transform
@@ -32,7 +34,7 @@ class CARS(Dataset):
         sample = self.loader(self.root / "cars_train" / file_name)
         if self.transform is not None:
             sample = self.transform(sample)
-        return sample
+        return sample, target
 
 
 class ImprovedImageFolder(ImageFolder):
@@ -44,7 +46,7 @@ class ImprovedImageFolder(ImageFolder):
         assert len(self.classes_list) == len(self.classes)
 
     def __getitem__(self, item):
-        return super().__getitem__(item)[0]
+        return super().__getitem__(item)
 
 
 class LMDBDataset(Dataset):
@@ -61,16 +63,10 @@ class LMDBDataset(Dataset):
 
     def __getitem__(self, i):
         with self.db.begin(write=False) as txn:
-            sample = Image.open(BytesIO(txn.get("{}".format(i).encode())))
-        if sample.mode != "RGB":
-            sample = sample.convert("RGB")
+            sample, target = pickle.loads(txn.get("{}".format(i).encode()))
         if self.transform is not None:
-            try:
-                sample = self.transform(sample)
-            except RuntimeError as re:
-                print(sample.format, sample.size, sample.mode)
-                raise re
-        return sample
+            sample = self.transform(sample)
+        return sample, target
 
 
 class EpisodicDataset(Dataset):
@@ -81,6 +77,24 @@ class EpisodicDataset(Dataset):
         self.num_set = num_set  # K
         self.num_episodes = num_episodes
+
+        self.t0 = transforms.Compose([
+            transforms.Resize((224, 224)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+        self.transform = transforms.Compose([
+            transforms.Resize((224, 224)),
+            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+
+    def apply_transform(self, img):
+        # img1 = self.transform(img)
+        # img2 = self.transform(img)
+        # return [self.t0(img), self.t0(functional.hflip(img))]
+        return [self.t0(img)]
 
     def __len__(self):
         return self.num_episodes
@@ -95,8 +109,11 @@ class EpisodicDataset(Dataset):
                 idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist()
             else:
                 idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist()
-            support_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[:self.num_set]])
-            query_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[self.num_set:]])
+            support = [self.origin[image_list[idx]][0] for idx in idx_list[:self.num_set]]
+            query = [self.origin[image_list[idx]][0] for idx in idx_list[self.num_set:]]  # disjoint second half
+            support_set_list.extend(sum(map(self.apply_transform, support), list()))
+            query_set_list.extend(sum(map(self.apply_transform, query), list()))
             target_list.extend([i] * self.num_set)
         return {
             "support": torch.stack(support_set_list),


@@ -1,30 +1,31 @@
 import os
 import pickle
-from io import BytesIO
 import argparse
+from PIL import Image
 import lmdb
-from data.dataset import CARS, ImprovedImageFolder
+from data.dataset import ImprovedImageFolder
 from tqdm import tqdm
 
 
 def content_loader(path):
-    with open(path, "rb") as f:
-        return f.read()
+    im = Image.open(path)
+    im = im.resize((256, 256))
+    if im.mode != "RGB":
+        im = im.convert("RGB")
+    return im
 
 
 def dataset_to_lmdb(dataset, lmdb_path):
     env = lmdb.open(lmdb_path, map_size=1099511627776*2, subdir=os.path.isdir(lmdb_path))
     with env.begin(write=True) as txn:
         for i in tqdm(range(len(dataset)), ncols=50):
-            txn.put("{}".format(i).encode(), bytearray(dataset[i]))
+            txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
         txn.put(b"classes_list", pickle.dumps(dataset.classes_list))
         txn.put(b"__len__", pickle.dumps(len(dataset)))
 
 
 def transform(save_path, dataset_path):
     print(save_path, dataset_path)
-    # origin_dataset = CARS("/data/few-shot/STANFORD-CARS/", loader=content_loader)
     origin_dataset = ImprovedImageFolder(dataset_path, loader=content_loader)
     dataset_to_lmdb(origin_dataset, save_path)
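
Each LMDB record is now a pickled (image, label) pair holding a ready-made 256x256 RGB PIL image, rather than raw file bytes. A round-trip read sketch mirroring LMDBDataset.__getitem__ (the path and the readonly/subdir/lock flags are placeholders and assumptions, not part of this commit):

import pickle
import lmdb

env = lmdb.open("/path/to/data.lmdb", readonly=True, subdir=False, lock=False)
with env.begin(write=False) as txn:
    # same key scheme as the writer: "{}".format(i).encode()
    sample, target = pickle.loads(txn.get(b"0"))
    length = pickle.loads(txn.get(b"__len__"))
print(sample.size, sample.mode, target, length)  # e.g. (256, 256) RGB 0 <dataset length>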

loss/__init__.py (new executable file, empty)

loss/prototypical.py (new executable file, +19)

@@ -0,0 +1,19 @@
+import torch
+
+
+def euclidean_dist(x, y):
+    """
+    Compute the squared Euclidean distance between two batches of vectors.
+    """
+    # x: B x N x D
+    # y: B x M x D
+    n = x.size(-2)
+    m = y.size(-2)
+    d = x.size(-1)
+
+    if d != y.size(-1):
+        raise ValueError("x and y must have the same feature dimension")
+
+    x = x.unsqueeze(2).expand(x.size(0), n, m, d)  # B x N x M x D
+    y = y.unsqueeze(1).expand(x.size(0), n, m, d)  # B x N x M x D
+    return torch.pow(x - y, 2).sum(-1)  # B x N x M
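
euclidean_dist broadcasts x (B x N x D) against y (B x M x D) and returns squared distances of shape B x N x M. For context, the prototypical-network scoring step the file name suggests (a sketch, not code from this commit) averages each class's K support embeddings into a prototype and classifies queries by nearest prototype:

import torch
from loss.prototypical import euclidean_dist

B, N, K, D = 2, 5, 5, 512
support = torch.randn(B, N, K, D)  # per-class support embeddings
query = torch.randn(B, N * K, D)   # query embeddings

prototypes = support.mean(dim=2)          # B x N x D, one mean embedding per class
dist = euclidean_dist(query, prototypes)  # B x NK x N
pred = dist.argmin(dim=-1)                # nearest prototype = predicted class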

test.py (31 changed lines)

@@ -49,12 +49,11 @@ def evaluate(query, target, support):
 def test(lmdb_path, import_path):
     dt = torchvision.transforms.Compose([
-        torchvision.transforms.Resize((256, 256)),
         torchvision.transforms.CenterCrop(224),
         torchvision.transforms.ToTensor(),
         torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     ])
 
-    origin_dataset = dataset.LMDBDataset(lmdb_path, transform=dt)
+    origin_dataset = dataset.LMDBDataset(lmdb_path, transform=None)
     N = 5
     K = 5
     episodic_dataset = dataset.EpisodicDataset(
@@ -65,8 +64,8 @@ def test(lmdb_path, import_path):
     )
     print(episodic_dataset)
 
-    data_loader = DataLoader(episodic_dataset, batch_size=20, pin_memory=False)
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    data_loader = DataLoader(episodic_dataset, batch_size=8, pin_memory=False)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     submit = import_module(f"submit.{import_path}")
@ -78,12 +77,19 @@ def test(lmdb_path, import_path):
with torch.no_grad(): with torch.no_grad():
for item in tqdm(data_loader): for item in tqdm(data_loader):
item = convert_tensor(item, device, non_blocking=True) item = convert_tensor(item, device, non_blocking=True)
# item["query"]: B x NK x 3 x W x H # item["query"]: B x ANK x 3 x W x H
# item["support"]: B x NK x 3 x W x H # item["support"]: B x ANK x 3 x W x H
# item["target"]: B x NK # item["target"]: B x NK
batch_size = item["target"].size(0) batch_size = item["target"].size(0)
query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1) image_size = item["query"].shape[-3:]
support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1) A = int(item["query"].size(1) / (N * K))
query_batch = extractor(item["query"].view([-1, *image_size])).view(batch_size, N * K, A, -1)
support_batch = extractor(item["support"].view([-1, *image_size])).view(batch_size, N, K, A, -1)
query_batch = torch.mean(query_batch, -2)
support_batch = torch.mean(support_batch, -2)
accs.append(evaluate(query_batch, item["target"], support_batch)) accs.append(evaluate(query_batch, item["target"], support_batch))
print(torch.tensor(accs).mean().item()) print(torch.tensor(accs).mean().item())
@@ -91,10 +97,11 @@ def test(lmdb_path, import_path):
 if __name__ == '__main__':
     setup_seed(100)
     defined_path = [
-        "/data/few-shot/lmdb/dogs/data.lmdb",
-        "/data/few-shot/lmdb/flowers/data.lmdb",
-        "/data/few-shot/lmdb/256-object/data.lmdb",
-        "/data/few-shot/lmdb/dtd/data.lmdb",
+        "/data/few-shot/lmdb256/dogs.lmdb",
+        "/data/few-shot/lmdb256/flowers.lmdb",
+        "/data/few-shot/lmdb256/256-object.lmdb",
+        "/data/few-shot/lmdb256/dtd.lmdb",
+        "/data/few-shot/lmdb256/mini-imagenet-test.lmdb"
     ]
     parser = argparse.ArgumentParser(description="test")
     parser.add_argument('-i', "--import_path", required=True)
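
The reshaping in the evaluation loop groups the A views of each image before averaging them into a single embedding per image. A standalone check of that logic, with a dummy extractor standing in for the model loaded from submit.<import_path>:

import torch

B, N, K, A, D = 2, 5, 5, 1, 64

def extractor(x):
    # stand-in: returns one D-dimensional feature per input image
    return torch.randn(x.size(0), D)

query = torch.randn(B, A * N * K, 3, 224, 224)
image_size = query.shape[-3:]
views = query.size(1) // (N * K)  # recovers A from the episode layout

feats = extractor(query.view([-1, *image_size]))  # (B*A*N*K) x D
feats = feats.view(B, N * K, views, -1).mean(-2)  # average the A views per image
assert feats.shape == (B, N * K, D)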