Compare commits

...

1 Commits

Author SHA1 Message Date
ead93c1b0e test 2020-07-23 22:32:28 +08:00
5 changed files with 78 additions and 34 deletions

View File

@ -8,11 +8,13 @@ from io import BytesIO
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.transforms import functional
from pathlib import Path
from collections import defaultdict
class CARS(Dataset):
class _CARS(Dataset):
def __init__(self, root, loader=default_loader, transform=None):
self.root = Path(root)
self.transform = transform
@ -32,7 +34,7 @@ class CARS(Dataset):
sample = self.loader(self.root / "cars_train" / file_name)
if self.transform is not None:
sample = self.transform(sample)
return sample
return sample, target
class ImprovedImageFolder(ImageFolder):
@ -44,7 +46,7 @@ class ImprovedImageFolder(ImageFolder):
assert len(self.classes_list) == len(self.classes)
def __getitem__(self, item):
return super().__getitem__(item)[0]
return super().__getitem__(item)
class LMDBDataset(Dataset):
@ -61,16 +63,10 @@ class LMDBDataset(Dataset):
def __getitem__(self, i):
with self.db.begin(write=False) as txn:
sample = Image.open(BytesIO(txn.get("{}".format(i).encode())))
if sample.mode != "RGB":
sample = sample.convert("RGB")
sample, target = pickle.loads(txn.get("{}".format(i).encode()))
if self.transform is not None:
try:
sample = self.transform(sample)
except RuntimeError as re:
print(sample.format, sample.size, sample.mode)
raise re
return sample
sample = self.transform(sample)
return sample, target
class EpisodicDataset(Dataset):
@ -81,6 +77,24 @@ class EpisodicDataset(Dataset):
self.num_set = num_set # K
self.num_episodes = num_episodes
self.t0 = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
self.transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def apply_transform(self, img):
# img1 = self.transform(img)
# img2 = self.transform(img)
# return [self.t0(img), self.t0(functional.hflip(img))]
return [self.t0(img)]
def __len__(self):
return self.num_episodes
@ -95,8 +109,11 @@ class EpisodicDataset(Dataset):
idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist()
else:
idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist()
support_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[:self.num_set]])
query_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[self.num_set:]])
support = [self.origin[image_list[idx]][0] for idx in idx_list[:self.num_set]]
query = [self.origin[image_list[idx]][0] for idx in idx_list[:self.num_set]]
support_set_list.extend(sum(map(self.apply_transform, support), list()))
query_set_list.extend(sum(map(self.apply_transform, query), list()))
target_list.extend([i] * self.num_set)
return {
"support": torch.stack(support_set_list),

View File

@ -1,30 +1,31 @@
import os
import pickle
from io import BytesIO
import argparse
from PIL import Image
import lmdb
from data.dataset import CARS, ImprovedImageFolder
from data.dataset import ImprovedImageFolder
from tqdm import tqdm
def content_loader(path):
with open(path, "rb") as f:
return f.read()
im = Image.open(path)
im = im.resize((256, 256))
if im.mode != "RGB":
im = im.convert("RGB")
return im
def dataset_to_lmdb(dataset, lmdb_path):
env = lmdb.open(lmdb_path, map_size=1099511627776*2, subdir=os.path.isdir(lmdb_path))
with env.begin(write=True) as txn:
for i in tqdm(range(len(dataset)), ncols=50):
txn.put("{}".format(i).encode(), bytearray(dataset[i]))
txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
txn.put(b"classes_list", pickle.dumps(dataset.classes_list))
txn.put(b"__len__", pickle.dumps(len(dataset)))
def transform(save_path, dataset_path):
print(save_path, dataset_path)
# origin_dataset = CARS("/data/few-shot/STANFORD-CARS/", loader=content_loader)
origin_dataset = ImprovedImageFolder(dataset_path, loader=content_loader)
dataset_to_lmdb(origin_dataset, save_path)

0
loss/__init__.py Executable file
View File

19
loss/prototypical.py Executable file
View File

@ -0,0 +1,19 @@
import torch
def euclidean_dist(x, y):
    """Return pairwise *squared* Euclidean distances between two point sets.

    No square root is taken: the result is the sum of squared coordinate
    differences, which is what the original implementation computed.

    Args:
        x: tensor of shape ``(..., N, D)`` (typically ``B x N x D``).
        y: tensor of shape ``(..., M, D)`` (typically ``B x M x D``). Its
            leading (batch) dimensions must be broadcastable with those of x.

    Returns:
        Tensor of shape ``(..., N, M)`` where entry ``[..., i, j]`` equals
        ``((x[..., i, :] - y[..., j, :]) ** 2).sum()``.

    Raises:
        ValueError: if the feature dimension ``D`` of x and y differ.
    """
    if x.size(-1) != y.size(-1):
        raise ValueError(
            "feature dimensions differ: x has {} but y has {}".format(
                x.size(-1), y.size(-1)
            )
        )
    # Insert singleton axes so broadcasting forms every (i, j) pair:
    #   x: (..., N, 1, D)   y: (..., 1, M, D)   ->   diff: (..., N, M, D)
    # This avoids hard-coding a 3-D layout, so unbatched inputs and
    # broadcastable batch dimensions work as well.
    diff = x.unsqueeze(-2) - y.unsqueeze(-3)
    return diff.pow(2).sum(-1)

33
test.py
View File

@ -49,12 +49,11 @@ def evaluate(query, target, support):
def test(lmdb_path, import_path):
dt = torchvision.transforms.Compose([
torchvision.transforms.Resize((256, 256)),
torchvision.transforms.CenterCrop(224),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
origin_dataset = dataset.LMDBDataset(lmdb_path, transform=dt)
origin_dataset = dataset.LMDBDataset(lmdb_path, transform=None)
N = 5
K = 5
episodic_dataset = dataset.EpisodicDataset(
@ -65,8 +64,8 @@ def test(lmdb_path, import_path):
)
print(episodic_dataset)
data_loader = DataLoader(episodic_dataset, batch_size=20, pin_memory=False)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_loader = DataLoader(episodic_dataset, batch_size=8, pin_memory=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
submit = import_module(f"submit.{import_path}")
@ -78,12 +77,19 @@ def test(lmdb_path, import_path):
with torch.no_grad():
for item in tqdm(data_loader):
item = convert_tensor(item, device, non_blocking=True)
# item["query"]: B x NK x 3 x W x H
# item["support"]: B x NK x 3 x W x H
# item["query"]: B x ANK x 3 x W x H
# item["support"]: B x ANK x 3 x W x H
# item["target"]: B x NK
batch_size = item["target"].size(0)
query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1)
support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
image_size = item["query"].shape[-3:]
A = int(item["query"].size(1) / (N * K))
query_batch = extractor(item["query"].view([-1, *image_size])).view(batch_size, N * K, A, -1)
support_batch = extractor(item["support"].view([-1, *image_size])).view(batch_size, N, K, A, -1)
query_batch = torch.mean(query_batch, -2)
support_batch = torch.mean(support_batch, -2)
accs.append(evaluate(query_batch, item["target"], support_batch))
print(torch.tensor(accs).mean().item())
@ -91,11 +97,12 @@ def test(lmdb_path, import_path):
if __name__ == '__main__':
setup_seed(100)
defined_path = [
"/data/few-shot/lmdb/dogs/data.lmdb",
"/data/few-shot/lmdb/flowers/data.lmdb",
"/data/few-shot/lmdb/256-object/data.lmdb",
"/data/few-shot/lmdb/dtd/data.lmdb",
]
"/data/few-shot/lmdb256/dogs.lmdb",
"/data/few-shot/lmdb256/flowers.lmdb",
"/data/few-shot/lmdb256/256-object.lmdb",
"/data/few-shot/lmdb256/dtd.lmdb",
"/data/few-shot/lmdb256/mini-imagenet-test.lmdb"
]
parser = argparse.ArgumentParser(description="test")
parser.add_argument('-i', "--import_path", required=True)
args = parser.parse_args()