import os
import pickle
from collections import defaultdict

import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS

import lmdb
from tqdm import tqdm

from .transform import transform_pipeline
from .registry import DATASET


def default_transform_way(transform, sample):
    # Default way of applying a transform to a sample: transform the first
    # element (the image) and pass the remaining elements (e.g. the label)
    # through unchanged.
    return [transform(sample[0]), *sample[1:]]

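# For instance (a sketch, assuming a [image, label] sample and a transform
# that turns a PIL image into a tensor):
#
#   sample = [pil_image, 3]
#   default_transform_way(to_tensor, sample)  # -> [tensor(...), 3]

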
class LMDBDataset(Dataset):
    """Dataset backed by an LMDB database produced by :meth:`lmdbify`.

    Samples are stored pickled under their stringified index. Bookkeeping
    entries ("$$len$$", "$$done_pipeline$$", "$$essential_attr$$") record the
    length, the transforms already applied when the database was written, and
    any extra attributes to restore onto this dataset.
    """

    def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True,
                 **lmdb_kwargs):
        self.path = lmdb_path
        self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
                            lock=False, **lmdb_kwargs)

        with self.db.begin(write=False) as txn:
            self._len = pickle.loads(txn.get(b"$$len$$"))
            self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
            if pipeline is None:
                self.not_done_pipeline = []
            else:
                # Only the transforms not already baked into the database
                # still need to run at __getitem__ time.
                self.not_done_pipeline = self._remain_pipeline(pipeline)
            self.transform = transform_pipeline(self.not_done_pipeline)
            self.transform_way = transform_way
            # Restore the attributes the source dataset marked as essential
            # (e.g. ImprovedImageFolder.classes_list).
            essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
            for ea in essential_attr:
                setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8"))))

    def _remain_pipeline(self, pipeline):
        # The requested pipeline must start with the transforms that were
        # already applied when the database was written; return the rest.
        for i, dp in enumerate(self.done_pipeline):
            if pipeline[i] != dp:
                raise ValueError(
                    f"pipeline {self.done_pipeline} saved in this lmdb database does not match pipeline: {pipeline}")
        return pipeline[len(self.done_pipeline):]

    def __repr__(self):
        return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        with self.db.begin(write=False) as txn:
            sample = pickle.loads(txn.get(str(idx).encode()))
        # Apply the remaining transforms to the stored sample.
        sample = self.transform_way(self.transform, sample)
        return sample

    @staticmethod
    def lmdbify(dataset, done_pipeline, lmdb_path):
        # Serialize every sample of `dataset` into an LMDB database, together
        # with the bookkeeping entries that __init__ expects.
        env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
        with env.begin(write=True) as txn:
            for i in tqdm(range(len(dataset)), ncols=0):
                txn.put(str(i).encode(), pickle.dumps(dataset[i]))
            txn.put(b"$$len$$", pickle.dumps(len(dataset)))
            txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
            essential_attr = getattr(dataset, "essential_attr", [])
            txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
            for ea in essential_attr:
                txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))


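# A round-trip sketch: bake the expensive part of a pipeline into LMDB once,
# then serve training from the database. The pipeline entries ("LoadImage",
# "Resize", "ToTensor") are placeholders for whatever transform_pipeline
# accepts in this project.
#
#   folder = ImprovedImageFolder("data/train", pipeline=["LoadImage", "Resize"])
#   LMDBDataset.lmdbify(folder, done_pipeline=["LoadImage", "Resize"],
#                       lmdb_path="data/train.lmdb")
#   ds = LMDBDataset("data/train.lmdb",
#                    pipeline=["LoadImage", "Resize", "ToTensor"])
#   # Only the not-yet-applied tail ("ToTensor") runs in __getitem__.

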
@DATASET.register_module()
class ImprovedImageFolder(ImageFolder):
    def __init__(self, root, pipeline):
        # The identity loader hands the raw file path to the pipeline, which
        # is expected to do the actual image loading.
        super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
        # Map each class index to the indices of its samples, and mark the
        # mapping as essential so LMDBDataset.lmdbify preserves it.
        self.classes_list = defaultdict(list)
        self.essential_attr = ["classes_list"]
        for i in range(len(self)):
            self.classes_list[self.samples[i][-1]].append(i)
        assert len(self.classes_list) == len(self.classes)


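# classes_list is what EpisodicDataset samples from; a sketch of its shape
# for a folder with two classes (sample indices are illustrative):
#
#   folder.classes_list  # -> {0: [0, 1, 2, ...], 1: [57, 58, ...]}

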
class EpisodicDataset(Dataset):
    def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
        # Few-shot episodic sampling over a dataset that exposes classes_list
        # (e.g. ImprovedImageFolder, or an LMDBDataset built from one).
        self.origin = origin_dataset
        self.num_class = num_class  # N (ways per episode)
        assert self.num_class < len(self.origin.classes_list)
        self.num_query = num_query  # Q (query samples per class)
        self.num_support = num_support  # K (support shots per class)
        self.num_episodes = num_episodes

    def _fetch_list_data(self, id_list):
        # Fetch only the data element of each sample, dropping the label.
        return [self.origin[i][0] for i in id_list]

    def __len__(self):
        return self.num_episodes

    def __getitem__(self, _):
        # Episodes are sampled randomly, so the index is ignored.
        random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
        support_set_list = []
        query_set_list = []
        target_list = []
        for tag, c in enumerate(random_classes):
            image_list = self.origin.classes_list[c]

            if len(image_list) >= self.num_query + self.num_support:
                # Enough images in this class: sample without replacement.
                idx_list = torch.randperm(len(image_list))[:self.num_query + self.num_support].tolist()
            else:
                # Too few images: fall back to sampling with replacement.
                idx_list = torch.randint(high=len(image_list), size=(self.num_query + self.num_support,)).tolist()

            support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
            query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
            support_set_list.extend(support)
            query_set_list.extend(query)
            target_list.extend([tag] * self.num_query)
        return {
            "support": torch.stack(support_set_list),
            "query": torch.stack(query_set_list),
            "target": torch.tensor(target_list)
        }

    def __repr__(self):
        return f"<EpisodicDataset NE{self.num_episodes} NC{self.num_class} NS{self.num_support} NQ{self.num_query}>"


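# A 5-way 1-shot sketch with 15 queries per class, assuming `folder` exposes
# classes_list and its pipeline yields equally-sized tensors:
#
#   episodes = EpisodicDataset(folder, num_class=5, num_query=15,
#                              num_support=1, num_episodes=1000)
#   batch = episodes[0]
#   # batch["support"]: (5 * 1, C, H, W), grouped per class
#   # batch["query"]:   (5 * 15, C, H, W)
#   # batch["target"]:  (5 * 15,) episode-local labels in [0, 5)

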
@DATASET.register_module()
class SingleFolderDataset(Dataset):
    def __init__(self, root, pipeline):
        assert os.path.isdir(root)
        self.root = root
        # Collect every image file under root, in a deterministic order.
        samples = []
        for r, _, fns in sorted(os.walk(self.root, followlinks=True)):
            for fn in sorted(fns):
                path = os.path.join(r, fn)
                if has_file_allowed_extension(path, IMG_EXTENSIONS):
                    samples.append(path)
        self.samples = samples
        self.pipeline = transform_pipeline(pipeline)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # The pipeline receives the file path and is expected to load the image.
        return self.pipeline(self.samples[idx])

    def __repr__(self):
        return f"<SingleFolderDataset root={self.root} len={len(self)}>"


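# A sketch, with placeholder path and pipeline names:
#
#   ds = SingleFolderDataset("data/photos", pipeline=["LoadImage", "ToTensor"])
#   img = ds[0]  # pipeline output for the first image path

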
@DATASET.register_module()
class GenerationUnpairedDataset(Dataset):
    def __init__(self, root_a, root_b, random_pair, pipeline):
        # Two unpaired image domains sharing one preprocessing pipeline.
        self.A = SingleFolderDataset(root_a, pipeline)
        self.B = SingleFolderDataset(root_b, pipeline)
        self.random_pair = random_pair

    def __getitem__(self, idx):
        a_idx = idx % len(self.A)
        # Either walk both domains in lockstep (modulo their lengths) or pair
        # each A image with a uniformly random B image.
        b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
        return dict(a=self.A[a_idx], b=self.B[b_idx])

    def __len__(self):
        return max(len(self.A), len(self.B))

    def __repr__(self):
        return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
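
# A sketch for unpaired image-to-image training (CycleGAN-style), with
# placeholder paths and pipeline names:
#
#   ds = GenerationUnpairedDataset("data/horses", "data/zebras",
#                                  random_pair=True,
#                                  pipeline=["LoadImage", "Resize", "ToTensor"])
#   pair = ds[0]  # {"a": tensor from domain A, "b": random tensor from domain B}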