This commit is contained in:
Ray Wong 2020-07-07 19:18:17 +08:00
parent a89f9226e8
commit 07c63abb30
3 changed files with 85 additions and 76 deletions

View File

@ -1,5 +1,9 @@
from scipy.io import loadmat from scipy.io import loadmat
import torch import torch
import lmdb
import os
import pickle
from io import BytesIO
from torch.utils.data import Dataset from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader from torchvision.datasets.folder import default_loader
from torchvision.datasets import ImageFolder from torchvision.datasets import ImageFolder
@ -22,7 +26,7 @@ class CARS(Dataset):
return len(self.annotations) return len(self.annotations)
def __getitem__(self, item): def __getitem__(self, item):
file_name = "{:05d}.jpg".format(item+1) file_name = "{:05d}.jpg".format(item + 1)
target = self.annotations[file_name] target = self.annotations[file_name]
sample = self.loader(self.root / "cars_train" / file_name) sample = self.loader(self.root / "cars_train" / file_name)
if self.transform is not None: if self.transform is not None:
@ -41,6 +45,22 @@ class ImprovedImageFolder(ImageFolder):
return super().__getitem__(item)[0] return super().__getitem__(item)[0]
class LMDBDataset(Dataset):
    """Read-only dataset backed by an LMDB environment.

    Sample ``i`` is read from the key ``"{}".format(i).encode()`` and decoded
    with ``torch.load``. Two pickled metadata records are read on
    construction: ``b"classes_list"`` and ``b"__len__"``.

    NOTE(security): the metadata is decoded with ``pickle.loads`` and the
    samples with ``torch.load``; both can execute arbitrary code, so only
    open databases that come from a trusted source.
    """

    def __init__(self, lmdb_path):
        """Open ``lmdb_path`` read-only and load the metadata records.

        Args:
            lmdb_path: Path of the LMDB database. It is opened as an LMDB
                directory when it is an existing directory, otherwise as a
                single database file.

        Raises:
            KeyError: If a metadata record is missing from the database.
        """
        self.db = lmdb.open(
            lmdb_path,
            subdir=os.path.isdir(lmdb_path),
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        try:
            with self.db.begin(write=False) as txn:
                self.classes_list = self._load_meta(txn, b"classes_list")
                self._len = self._load_meta(txn, b"__len__")
        except Exception:
            # Do not leave the environment open when construction fails.
            self.db.close()
            raise

    @staticmethod
    def _load_meta(txn, key):
        """Return the unpickled metadata record stored under ``key``."""
        raw = txn.get(key)
        if raw is None:
            raise KeyError("LMDB database has no {!r} record".format(key))
        return pickle.loads(raw)

    def __len__(self):
        """Return the sample count recorded under ``b"__len__"``."""
        return self._len

    def __getitem__(self, i):
        """Load and return the sample stored for index ``i``.

        Raises:
            IndexError: If no record exists for ``i``. Without this check the
                missing value would reach ``torch.load`` as an empty buffer
                and fail with an unrelated deserialisation error.
        """
        with self.db.begin(write=False) as txn:
            raw = txn.get("{}".format(i).encode())
            if raw is None:
                raise IndexError(
                    "index {} is not stored in the LMDB database".format(i)
                )
            return torch.load(BytesIO(raw))
class EpisodicDataset(Dataset): class EpisodicDataset(Dataset):
def __init__(self, origin_dataset, num_class, num_set, num_episodes): def __init__(self, origin_dataset, num_class, num_set, num_episodes):
self.origin = origin_dataset self.origin = origin_dataset
@ -60,12 +80,12 @@ class EpisodicDataset(Dataset):
for i, c in enumerate(random_classes): for i, c in enumerate(random_classes):
image_list = self.origin.classes_list[c] image_list = self.origin.classes_list[c]
if len(image_list) > self.num_set * 2: if len(image_list) > self.num_set * 2:
idx_list = torch.randperm(len(image_list))[:self.num_set*2].tolist() idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist()
else: else:
idx_list = torch.randint(high=len(image_list), size=(self.num_set*2,)).tolist() idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist()
support_set_list.extend([self.origin[idx] for idx in idx_list[:self.num_set]]) support_set_list.extend([self.origin[idx] for idx in idx_list[:self.num_set]])
query_set_list.extend([self.origin[idx] for idx in idx_list[self.num_set:]]) query_set_list.extend([self.origin[idx] for idx in idx_list[self.num_set:]])
target_list.extend([i]*self.num_set) target_list.extend([i] * self.num_set)
return { return {
"support": torch.stack(support_set_list), "support": torch.stack(support_set_list),
"query": torch.stack(query_set_list), "query": torch.stack(query_set_list),

35
data/lmdbify.py Executable file
View File

@ -0,0 +1,35 @@
import torch
import lmdb
import os
import pickle
from io import BytesIO
from data.dataset import CARS, ImprovedImageFolder
import torchvision
from tqdm import tqdm
def dataset_to_lmdb(dataset, lmdb_path, map_size=1099511627776 * 2):
    """Serialise every sample of ``dataset`` into an LMDB database.

    Sample ``i`` is written with ``torch.save`` under the key
    ``"{}".format(i).encode()``. ``dataset.classes_list`` and
    ``len(dataset)`` are pickled under ``b"classes_list"`` and ``b"__len__"``.
    Everything is written inside a single write transaction.

    Args:
        dataset: Object providing ``__len__``, integer ``__getitem__`` and a
            ``classes_list`` attribute.
        lmdb_path: Destination path. It is opened as an LMDB directory when
            it is an existing directory, otherwise as a single database file.
        map_size: Value passed to ``lmdb.open`` as ``map_size`` (bytes).
            Defaults to 2 TiB, the value that used to be hard-coded.
    """
    num_samples = len(dataset)
    env = lmdb.open(lmdb_path, map_size=map_size, subdir=os.path.isdir(lmdb_path))
    try:
        with env.begin(write=True) as txn:
            for i in tqdm(range(num_samples)):
                buffer = BytesIO()
                torch.save(dataset[i], buffer)
                txn.put("{}".format(i).encode(), buffer.getvalue())
            txn.put(b"classes_list", pickle.dumps(dataset.classes_list))
            txn.put(b"__len__", pickle.dumps(num_samples))
    finally:
        # Release the environment even if serialising a sample fails.
        env.close()
def main():
    """Convert the CUB-200-2011 image folder into an LMDB database.

    Each image is resized to a square of ``int(224 * 1.15)`` pixels,
    centre-cropped to 224, converted to a tensor and normalised with the
    mean/std constants below before being handed to ``dataset_to_lmdb``.
    """
    transforms = torchvision.transforms
    resize_edge = int(224 * 1.15)
    pipeline = [
        transforms.Resize([resize_edge, resize_edge]),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    images = ImprovedImageFolder(
        "/data/few-shot/CUB_200_2011/CUB_200_2011/images",
        transform=transforms.Compose(pipeline),
    )
    dataset_to_lmdb(images, "/data/few-shot/lmdb/CUB_200_2011/data.lmdb")


# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()

98
test.py
View File

@ -1,10 +1,11 @@
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import torchvision
from data import dataset from data import dataset
import torch.nn as nn
from ignite.utils import convert_tensor from ignite.utils import convert_tensor
import time import time
from tqdm import tqdm
def setup_seed(seed): def setup_seed(seed):
@ -44,67 +45,11 @@ def evaluate(query, target, support):
return torch.eq(target, indices).float().mean() return torch.eq(target, indices).float().mean()
class Flatten(nn.Module): def test(lmdb_path):
def __init__(self): origin_dataset = dataset.LMDBDataset(lmdb_path)
super(Flatten, self).__init__()
def forward(self, x):
return x.view(x.size(0), -1)
def make_extractor(device=None):
    """Build an inference-only feature extractor from a pretrained ResNet-50.

    The model's ``fc`` layer is replaced by ``torch.nn.Identity`` so the
    returned callable yields the input of that layer instead of class scores.

    Args:
        device: Device to place the model on. When omitted, CUDA is used if
            it is available and the CPU otherwise, so the extractor also
            works on hosts without a GPU.

    Returns:
        A function ``extract(images)`` that runs the model in eval mode
        under ``torch.no_grad()``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    resnet50 = torchvision.models.resnet50(pretrained=True)
    resnet50.fc = torch.nn.Identity()
    resnet50.to(device)
    resnet50.eval()

    def extract(images):
        with torch.no_grad():
            return resnet50(images)

    return extract
# def make_extractor():
# model = resnet18()
# model.to(torch.device("cuda"))
# model.eval()
#
# def extract(images):
# with torch.no_grad():
# return model(images)
# return extract
def resnet18(model_path="ResNet18Official.pth"):
    """Construct a ResNet-18 without its last child and load weights from disk.

    The torchvision ResNet-18 is rebuilt as a ``torch.nn.Sequential`` of all
    its children except the final one, followed by ``Flatten``.

    Args:
        model_path: Path of the checkpoint handed to ``torch.load``; tensors
            are mapped to the CPU while loading.

    Returns:
        The assembled model, switched to evaluation mode.
    """
    import warnings

    model_w_fc = torchvision.models.resnet18(pretrained=False)
    seq = list(model_w_fc.children())[:-1]
    seq.append(Flatten())
    model = torch.nn.Sequential(*seq)

    state_dict = torch.load(model_path, map_location='cpu')
    # strict=False tolerates key mismatches; report them so a checkpoint that
    # does not line up with this Sequential layout is not ignored silently.
    outcome = model.load_state_dict(state_dict, strict=False)
    if outcome.missing_keys or outcome.unexpected_keys:
        warnings.warn(
            "resnet18: checkpoint {!r} left {} parameter(s) uninitialised and "
            "contained {} unused key(s)".format(
                model_path,
                len(outcome.missing_keys),
                len(outcome.unexpected_keys),
            )
        )
    model.eval()
    return model
def test():
data_transform = torchvision.transforms.Compose([
torchvision.transforms.Resize([int(224*1.15), int(224*1.15)]),
torchvision.transforms.CenterCrop(224),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
origin_dataset = dataset.CARS("/data/few-shot/STANFORD-CARS/", transform=data_transform)
#origin_dataset = dataset.ImprovedImageFolder("/data/few-shot/mini_imagenet_full_size/train", transform=data_transform)
N = torch.randint(5, 10, (1,)).tolist()[0] N = torch.randint(5, 10, (1,)).tolist()[0]
K = torch.randint(1, 10, (1,)).tolist()[0] K = torch.randint(1, 10, (1,)).tolist()[0]
batch_size = 2
episodic_dataset = dataset.EpisodicDataset( episodic_dataset = dataset.EpisodicDataset(
origin_dataset, # 抽取数据集 origin_dataset, # 抽取数据集
N, # N N, # N
@ -113,25 +58,34 @@ def test():
) )
print(episodic_dataset) print(episodic_dataset)
data_loader = DataLoader(episodic_dataset, batch_size=batch_size, pin_memory=True) data_loader = DataLoader(episodic_dataset, batch_size=4, pin_memory=False)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
extractor = make_extractor()
from submit import make_model
extractor = make_model()
extractor.to(device)
accs = [] accs = []
st = time.time() st = time.time()
for item in data_loader: with torch.no_grad():
item = convert_tensor(item, device, non_blocking=True) for item in tqdm(data_loader, nrows=80):
# item["query"]: B x NK x 3 x W x H item = convert_tensor(item, device, non_blocking=True)
# item["support"]: B x NK x 3 x W x H # item["query"]: B x NK x 3 x W x H
# item["target"]: B x NK # item["support"]: B x NK x 3 x W x H
query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N*K, -1) # item["target"]: B x NK
support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1) batch_size = item["target"].size(0)
query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1)
support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
accs.append(evaluate(query_batch, item["target"], support_batch)) accs.append(evaluate(query_batch, item["target"], support_batch))
print("time: ", time.time()-st)
st = time.time()
print(torch.tensor(accs).mean().item()) print(torch.tensor(accs).mean().item())
print("time: ", time.time() - st)
if __name__ == '__main__': if __name__ == '__main__':
setup_seed(100) setup_seed(100)
test() for path in ["/data/few-shot/lmdb/CUB_200_2011/data.lmdb",
"/data/few-shot/lmdb/mini-imagenet/train.lmdb",
"/data/few-shot/lmdb/STANFORD-CARS/train.lmdb"]:
print(path)
test(path)