Compare commits

...

2 Commits

SHA1         Message              Date
3a72dcb5f0   change line ending   2020-07-20 11:02:39 +08:00
7d720c181b   1                    2020-07-16 16:07:03 +08:00
4 changed files with 253 additions and 251 deletions

.gitignore (vendored)

@@ -1,2 +1,3 @@
 *.pth
 .idea/
+submit/

data/dataset.py

@@ -3,6 +3,7 @@ import torch
 import lmdb
 import os
 import pickle
+from PIL import Image
 from io import BytesIO
 from torch.utils.data import Dataset
 from torchvision.datasets.folder import default_loader

@@ -60,9 +61,15 @@ class LMDBDataset(Dataset):
     def __getitem__(self, i):
         with self.db.begin(write=False) as txn:
-            sample = torch.load(BytesIO(txn.get("{}".format(i).encode())))
+            sample = Image.open(BytesIO(txn.get("{}".format(i).encode())))
+            if sample.mode != "RGB":
+                sample = sample.convert("RGB")
             if self.transform is not None:
-                sample = self.transform(sample)
+                try:
+                    sample = self.transform(sample)
+                except RuntimeError as re:
+                    print(sample.format, sample.size, sample.mode)
+                    raise re
         return sample
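This change swaps the LMDB payload from pickled tensors (read back with torch.load) to raw encoded image bytes decoded lazily with PIL, with a forced RGB conversion for grayscale or palette images. A minimal sketch of the new read path outside the Dataset class, assuming the integer-string key scheme from the diff (the database path and subdir flag are hypothetical):

```python
import lmdb
from io import BytesIO
from PIL import Image

# Hypothetical LMDB produced by dataset_to_lmdb(); keys are b"0", b"1", ...
env = lmdb.open("/data/few-shot/lmdb/dogs/data.lmdb",
                readonly=True, lock=False, subdir=False)
with env.begin(write=False) as txn:
    raw = txn.get(b"0")            # value holds the original image file's bytes
img = Image.open(BytesIO(raw))     # decode on demand, as __getitem__ now does
if img.mode != "RGB":              # e.g. palette PNGs or grayscale JPEGs
    img = img.convert("RGB")
print(img.format, img.size, img.mode)
```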

@@ -3,34 +3,29 @@ import pickle
 from io import BytesIO
 import argparse
-import torch
 import lmdb
 from data.dataset import CARS, ImprovedImageFolder
-import torchvision
 from tqdm import tqdm


+def content_loader(path):
+    with open(path, "rb") as f:
+        return f.read()
+
+
 def dataset_to_lmdb(dataset, lmdb_path):
-    env = lmdb.open(lmdb_path, map_size=1099511627776, subdir=os.path.isdir(lmdb_path))
+    env = lmdb.open(lmdb_path, map_size=1099511627776*2, subdir=os.path.isdir(lmdb_path))
     with env.begin(write=True) as txn:
         for i in tqdm(range(len(dataset)), ncols=50):
-            buffer = BytesIO()
-            torch.save(dataset[i], buffer)
-            txn.put("{}".format(i).encode(), buffer.getvalue())
+            txn.put("{}".format(i).encode(), bytearray(dataset[i]))
         txn.put(b"classes_list", pickle.dumps(dataset.classes_list))
         txn.put(b"__len__", pickle.dumps(len(dataset)))


 def transform(save_path, dataset_path):
     print(save_path, dataset_path)
-    dt = torchvision.transforms.Compose([
-        torchvision.transforms.Resize((256, 256)),
-        torchvision.transforms.CenterCrop(224),
-        torchvision.transforms.ToTensor(),
-        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-    ])
-    # origin_dataset = CARS("/data/few-shot/STANFORD-CARS/", transform=dt)
-    origin_dataset = ImprovedImageFolder(dataset_path, transform=dt)
+    # origin_dataset = CARS("/data/few-shot/STANFORD-CARS/", loader=content_loader)
+    origin_dataset = ImprovedImageFolder(dataset_path, loader=content_loader)
     dataset_to_lmdb(origin_dataset, save_path)
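Consistent with the new read path, the writer now stores each image's original file bytes (via content_loader) instead of torch.save-serializing pre-transformed tensors, and the preprocessing pipeline moves to load time; map_size is also doubled. A hypothetical invocation of the converter above (both paths are made up for illustration, and the call assumes ImprovedImageFolder indexes a class-subfolder directory as shown in the diff):

```python
# Pack an ImageFolder-style directory of class subfolders into one LMDB.
# Note: transform() here is the converter defined in this file, not a
# torchvision transform; it wires content_loader into ImprovedImageFolder
# and streams every record into dataset_to_lmdb().
transform(save_path="/data/few-shot/lmdb/dogs/data.lmdb",
          dataset_path="/data/images/dogs/")
```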

test.py

@@ -1,6 +1,6 @@
 import torch
 from torch.utils.data import DataLoader
-from torchvision import transforms
+import torchvision
 from data import dataset
 import argparse

@@ -48,7 +48,13 @@ def evaluate(query, target, support):
 def test(lmdb_path, import_path):
-    origin_dataset = dataset.LMDBDataset(lmdb_path)
+    dt = torchvision.transforms.Compose([
+        torchvision.transforms.Resize((256, 256)),
+        torchvision.transforms.CenterCrop(224),
+        torchvision.transforms.ToTensor(),
+        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+    origin_dataset = dataset.LMDBDataset(lmdb_path, transform=dt)
     N = 5
     K = 5
     episodic_dataset = dataset.EpisodicDataset(

@@ -59,7 +65,7 @@ def test(lmdb_path, import_path):
     )
     print(episodic_dataset)
-    data_loader = DataLoader(episodic_dataset, batch_size=16, pin_memory=False)
+    data_loader = DataLoader(episodic_dataset, batch_size=20, pin_memory=False)
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     submit = import_module(f"submit.{import_path}")

@@ -69,11 +75,8 @@ def test(lmdb_path, import_path):
     accs = []
-    load_st = time.time()
     with torch.no_grad():
-        for item in data_loader:
-            st = time.time()
-            print("load", time.time() - load_st)
+        for item in tqdm(data_loader):
             item = convert_tensor(item, device, non_blocking=True)
             # item["query"]: B x NK x 3 x W x H
             # item["support"]: B x NK x 3 x W x H

@@ -81,21 +84,17 @@ def test(lmdb_path, import_path):
             batch_size = item["target"].size(0)
             query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1)
             support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
-            print("compute", time.time() - st)
-            load_st = time.time()
-            accs.append(evaluate(query_batch, item["target"], support_batch))
+            accs.append(evaluate(query_batch, item["target"], support_batch))
     print(torch.tensor(accs).mean().item())
-    print("time: ", time.time() - st)

 if __name__ == '__main__':
     setup_seed(100)
-    defined_path = ["/data/few-shot/lmdb/mini-imagenet/val.lmdb",
-                    "/data/few-shot/lmdb/CUB_200_2011/data.lmdb",
-                    "/data/few-shot/lmdb/STANFORD-CARS/train.lmdb",
-                    # "/data/few-shot/lmdb/Plantae/data.lmdb",
-                    # "/data/few-shot/lmdb/Places365/val.lmdb"
-                    ]
+    defined_path = [
+        "/data/few-shot/lmdb/dogs/data.lmdb",
+        "/data/few-shot/lmdb/flowers/data.lmdb",
+        "/data/few-shot/lmdb/256-object/data.lmdb",
+        "/data/few-shot/lmdb/dtd/data.lmdb",
+    ]
     parser = argparse.ArgumentParser(description="test")
     parser.add_argument('-i', "--import_path", required=True)
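Since each batch now stacks whole episodes, the reshape around the extractor is the subtle part of test(). A worked shape example using the values from this diff (the 512-dim feature size is an assumption for illustration):

```python
import torch

B, N, K = 20, 5, 5                       # batch_size, ways, shots from the diff
query = torch.randn(B, N * K, 3, 224, 224)

# The extractor consumes a flat batch of images ...
flat = query.view([-1, *query.shape[-3:]])
assert flat.shape == (B * N * K, 3, 224, 224)        # 500 crops per step

# ... and its features are folded back per episode, mirroring query_batch:
feats = torch.randn(B * N * K, 512)                  # hypothetical feature dim
query_batch = feats.view(B, N * K, -1)
assert query_batch.shape == (B, N * K, 512)
```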