few-shot/data/dataset.py
2020-07-23 22:32:28 +08:00

126 lines
4.8 KiB
Python
Executable File

from scipy.io import loadmat
import torch
import lmdb
import os
import pickle
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.transforms import functional
from pathlib import Path
from collections import defaultdict
class _CARS(Dataset):
    """Car training split described by ``devkit/cars_train_annos.mat``.

    The directory layout (``devkit/`` plus ``cars_train/``) presumably follows the
    Stanford Cars devkit. Images are read from ``<root>/cars_train`` and are named
    by their 1-based position, zero-padded to five digits (``00001.jpg``, ...).
    ``classes_list`` maps each label to the list of dataset indices carrying it.
    """

    def __init__(self, root, loader=default_loader, transform=None):
        self.root = Path(root)
        self.transform = transform
        self.loader = loader
        records = loadmat(self.root / "devkit/cars_train_annos.mat")["annotations"][0]
        # The last two fields of each record are used as (label, file name); labels
        # are shifted down by one (presumably 1-based in the .mat file).
        labels_by_name = {}
        for record in records:
            labels_by_name[record[-1].item()] = record[-2].item() - 1
        self.annotations = labels_by_name
        self.classes_list = defaultdict(list)
        for position in range(len(labels_by_name)):
            self.classes_list[labels_by_name[self._file_name(position)]].append(position)

    @staticmethod
    def _file_name(index):
        """Map a 0-based dataset index to its image file name."""
        return "{:05d}.jpg".format(index + 1)

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, item):
        """Return ``(sample, target)`` for dataset index ``item``."""
        name = self._file_name(item)
        target = self.annotations[name]
        sample = self.loader(self.root / "cars_train" / name)
        return (sample if self.transform is None else self.transform(sample)), target
class ImprovedImageFolder(ImageFolder):
    """``ImageFolder`` that additionally indexes dataset positions by class.

    ``classes_list[label]`` is the list of dataset indices whose target is
    ``label``; episodic samplers draw support/query images from it.
    """

    def __init__(self, root, loader=default_loader, transform=None,
                 target_transform=None, is_valid_file=None):
        """
        Args:
            root: root directory with one sub-folder per class.
            loader: callable that loads an image from its path.
            transform: optional callable applied to each loaded sample.
            target_transform: optional callable applied to each target
                (forwarded to ``ImageFolder``; default keeps old behaviour).
            is_valid_file: optional predicate selecting which files to keep
                (forwarded to ``ImageFolder``; default keeps old behaviour).
        """
        super().__init__(root, transform=transform, target_transform=target_transform,
                         loader=loader, is_valid_file=is_valid_file)
        self.classes_list = defaultdict(list)
        # Each entry of ``self.samples`` ends with its class index.
        for index, sample in enumerate(self.samples):
            self.classes_list[sample[-1]].append(index)
        # Sanity check: every discovered class contributed at least one sample.
        assert len(self.classes_list) == len(self.classes)
class LMDBDataset(Dataset):
    """Dataset backed by an LMDB database of pickled ``(sample, target)`` records.

    Besides one record per index (key ``"{}".format(i).encode()``), the database
    must contain two pickled metadata records: ``b"classes_list"`` (exposed as
    ``self.classes_list``) and ``b"__len__"`` (the dataset length).

    SECURITY: every record is decoded with ``pickle.loads``; only open databases
    you created or otherwise trust.
    """

    def __init__(self, lmdb_path, transform=None):
        """Open ``lmdb_path`` read-only and load the metadata records.

        Args:
            lmdb_path: LMDB directory, or a single database file when the path
                is not a directory.
            transform: optional callable applied to each sample in ``__getitem__``.
        """
        self.lmdb_path = lmdb_path
        self.transform = transform
        self.db = self._open_db()
        self._pid = os.getpid()
        with self.db.begin(write=False) as txn:
            self.classes_list = pickle.loads(txn.get(b"classes_list"))
            self._len = pickle.loads(txn.get(b"__len__"))

    def _open_db(self):
        """Open a read-only LMDB environment for ``self.lmdb_path``."""
        # map_size: 1 TiB maximum map size.
        return lmdb.open(self.lmdb_path, map_size=1099511627776,
                         subdir=os.path.isdir(self.lmdb_path), readonly=True,
                         lock=False, readahead=False, meminit=False)

    def _get_db(self):
        """Return an LMDB environment that is safe to use in the current process.

        LMDB environments must not be used across ``fork()``. DataLoader worker
        processes inherit ``self.db`` from the parent, so reopen it once per process.
        """
        if self.db is None or self._pid != os.getpid():
            self.db = self._open_db()
            self._pid = os.getpid()
        return self.db

    def __getstate__(self):
        # The LMDB environment handle is process-local and not picklable (needed
        # for spawn-based workers); the receiving process reopens it in _get_db().
        state = self.__dict__.copy()
        state["db"] = None
        return state

    def __len__(self):
        return self._len

    def __getitem__(self, i):
        """Return ``(sample, target)`` stored under index ``i``, with ``transform`` applied.

        Raises:
            IndexError: if the database holds no record for ``i``.
        """
        with self._get_db().begin(write=False) as txn:
            raw = txn.get("{}".format(i).encode())
        if raw is None:
            # Raise IndexError rather than letting pickle.loads(None) raise TypeError,
            # so sequence-style iteration over the dataset terminates cleanly.
            raise IndexError("LMDBDataset index {} out of range".format(i))
        sample, target = pickle.loads(raw)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target
class EpisodicDataset(Dataset):
    """Sample N-way K-shot episodes from a class-indexed dataset.

    ``origin_dataset`` must expose ``classes_list`` (class id -> list of dataset
    indices) and ``__getitem__`` returning ``(image, target)``. Class ids are
    assumed to be ``0 .. len(classes_list) - 1`` because they are drawn with
    ``torch.randperm``. Every item is a freshly sampled random episode; the
    requested index is ignored.
    """

    def __init__(self, origin_dataset, num_class, num_set, num_episodes):
        """
        Args:
            origin_dataset: dataset with ``classes_list``; its samples must be
                accepted by the transforms below (presumably PIL images — confirm).
            num_class: N, number of classes per episode.
            num_set: K, number of support (and of query) images per class.
            num_episodes: value reported by ``__len__`` (episodes per epoch).

        Raises:
            ValueError: if ``num_class`` exceeds the number of available classes.
        """
        self.origin = origin_dataset
        self.num_class = num_class  # N
        if self.num_class > len(self.origin.classes_list):
            raise ValueError("num_class={} exceeds the {} available classes".format(
                self.num_class, len(self.origin.classes_list)))
        self.num_set = num_set  # K
        self.num_episodes = num_episodes
        # Deterministic preprocessing; mean/std are the usual ImageNet statistics.
        self.t0 = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        # Colour-jittered variant; kept as an attribute but not used by apply_transform.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def apply_transform(self, img):
        """Return the list of tensor views produced for one image (currently just ``t0``)."""
        return [self.t0(img)]

    def __len__(self):
        return self.num_episodes

    def _sample_positions(self, num_images):
        """Draw 2*K positions into a class's image list: the first K are support, the rest query."""
        needed = self.num_set * 2
        if num_images >= needed:
            # Enough images: sample without replacement so support and query are disjoint.
            return torch.randperm(num_images)[:needed].tolist()
        # Too few images: fall back to sampling with replacement (duplicates possible).
        return torch.randint(high=num_images, size=(needed,)).tolist()

    def _load_views(self, image_list, positions):
        """Load ``origin[image_list[p]]`` for each position and flatten their transformed views."""
        return [view
                for p in positions
                for view in self.apply_transform(self.origin[image_list[p]][0])]

    def __getitem__(self, _):
        """Sample one episode.

        Returns:
            dict with ``"support"`` and ``"query"`` (stacked image tensors, grouped
            by episode class) and ``"target"`` (episode-local labels 0..N-1, one
            per support view; the query views follow the same label order).
        """
        random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
        support_set_list = []
        query_set_list = []
        target_list = []
        for label, c in enumerate(random_classes):
            image_list = self.origin.classes_list[c]
            positions = self._sample_positions(len(image_list))
            support = self._load_views(image_list, positions[:self.num_set])
            # Query uses the second half of the sampled positions, not the support half.
            query = self._load_views(image_list, positions[self.num_set:])
            support_set_list.extend(support)
            query_set_list.extend(query)
            target_list.extend([label] * len(support))
        return {
            "support": torch.stack(support_set_list),
            "query": torch.stack(query_set_list),
            "target": torch.tensor(target_list)
        }

    def __repr__(self):
        return "<EpisodicDataset N={} K={} NUM={}>".format(self.num_class, self.num_set, self.num_episodes)