import os
import pickle
from collections import defaultdict

import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader
import lmdb
from tqdm import tqdm

from .transform import transform_pipeline
from .registry import DATASET


def default_transform_way(transform, sample):
    # A sample is stored as [data, *labels]; only the data part goes through the transform.
    return [transform(sample[0]), *sample[1:]]


class LMDBDataset(Dataset):
    """Dataset backed by an LMDB database produced by `LMDBDataset.lmdbify`."""

    def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way,
                 map_size=2 ** 40, readonly=True, **lmdb_kwargs):
        self.path = lmdb_path
        self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size,
                            readonly=readonly, lock=False, **lmdb_kwargs)
        with self.db.begin(write=False) as txn:
            # Metadata lives under reserved "$$...$$" keys; per-sample records use the index as key.
            self._len = pickle.loads(txn.get(b"$$len$$"))
            self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
            if pipeline is None:
                self.not_done_pipeline = []
            else:
                self.not_done_pipeline = self._remain_pipeline(pipeline)
            self.transform = transform_pipeline(self.not_done_pipeline)
            self.transform_way = transform_way
            essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
            for ea in essential_attr:
                setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8"))))

    def _remain_pipeline(self, pipeline):
        # The requested pipeline must start with the transforms already baked into the database;
        # only the remaining transforms are applied at read time.
        for i, dp in enumerate(self.done_pipeline):
            if pipeline[i] != dp:
                raise ValueError(
                    f"pipeline {self.done_pipeline} saved in this lmdb database does not match "
                    f"the given pipeline: {pipeline}")
        return pipeline[len(self.done_pipeline):]

    def __repr__(self):
        return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        with self.db.begin(write=False) as txn:
            sample = pickle.loads(txn.get("{}".format(idx).encode()))
        sample = self.transform_way(self.transform, sample)
        return sample

    @staticmethod
    def lmdbify(dataset, done_pipeline, lmdb_path):
        """Serialize every sample of `dataset` (with `done_pipeline` already applied) into LMDB."""
        env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
        with env.begin(write=True) as txn:
            for i in tqdm(range(len(dataset)), ncols=0):
                txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
            txn.put(b"$$len$$", pickle.dumps(len(dataset)))
            txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
            # Extra attributes the dataset wants to survive the round trip (e.g. classes_list).
            essential_attr = getattr(dataset, "essential_attr", list())
            txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
            for ea in essential_attr:
                txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))


@DATASET.register_module()
class ImprovedImageFolder(ImageFolder):
    """ImageFolder that additionally records, per class index, the sample indices belonging to it."""

    def __init__(self, root, pipeline):
        # The loader is the identity: the pipeline itself is responsible for reading each file.
        super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
        self.classes_list = defaultdict(list)
        self.essential_attr = ["classes_list"]
        for i in range(len(self)):
            self.classes_list[self.samples[i][-1]].append(i)
        assert len(self.classes_list) == len(self.classes)
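# Hedged usage sketch (illustrative, not part of the original module): how an ImprovedImageFolder
# could be cached into an LMDB file with `lmdbify` and reopened lazily. The default paths and the
# pipeline configs are assumptions; the exact config format is whatever `transform_pipeline`
# in `.transform` expects.
def _lmdb_roundtrip_example(image_root="data/train", lmdb_path="data/train.lmdb",
                            done_pipeline=(), extra_pipeline=()):
    done_pipeline = list(done_pipeline)    # transforms baked into the stored records
    extra_pipeline = list(extra_pipeline)  # transforms re-applied on every read
    folder = ImprovedImageFolder(image_root, done_pipeline)
    LMDBDataset.lmdbify(folder, done_pipeline, lmdb_path)
    # Passing done_pipeline + extra_pipeline makes _remain_pipeline apply only the extra part.
    return LMDBDataset(lmdb_path, pipeline=done_pipeline + extra_pipeline)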
class EpisodicDataset(Dataset):
    """Wraps a dataset exposing `classes_list` and yields N-way / K-shot episodes."""

    def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
        self.origin = origin_dataset
        self.num_class = num_class  # N: classes sampled per episode
        assert self.num_class < len(self.origin.classes_list)
        self.num_query = num_query  # query images per class
        self.num_support = num_support  # K: support images per class
        self.num_episodes = num_episodes

    def _fetch_list_data(self, id_list):
        return [self.origin[i][0] for i in id_list]

    def __len__(self):
        return self.num_episodes

    def __getitem__(self, _):
        random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
        support_set = []
        query_set = []
        target_set = []
        for tag, c in enumerate(random_classes):
            image_list = self.origin.classes_list[c]
            if len(image_list) >= self.num_query + self.num_support:
                # Enough images in this class: sample without replacement.
                idx_list = torch.randperm(len(image_list))[:self.num_query + self.num_support].tolist()
            else:
                # Too few images: fall back to sampling with replacement.
                idx_list = torch.randint(high=len(image_list), size=(self.num_query + self.num_support,)).tolist()
            support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
            query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
            support_set.extend(support)
            query_set.extend(query)
            target_set.extend([tag] * self.num_query)
        return {
            "support": torch.stack(support_set),
            "query": torch.stack(query_set),
            "target": torch.tensor(target_set)
        }

    def __repr__(self):
        return (f"EpisodicDataset: {self.num_class}-way, {self.num_support}-shot, "
                f"{self.num_query} queries per class, {self.num_episodes} episodes\n{self.origin!r}")


@DATASET.register_module()
class SingleFolderDataset(Dataset):
    """Flat (unlabeled) image dataset collected by recursively walking a directory."""

    def __init__(self, root, pipeline):
        assert os.path.isdir(root)
        self.root = root
        samples = []
        for r, _, fns in sorted(os.walk(self.root, followlinks=True)):
            for fn in sorted(fns):
                path = os.path.join(r, fn)
                if has_file_allowed_extension(path, IMG_EXTENSIONS):
                    samples.append(path)
        self.samples = samples
        self.pipeline = transform_pipeline(pipeline)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.pipeline(self.samples[idx])

    def __repr__(self):
        return f"SingleFolderDataset: {self.root}\nlength: {len(self)}\n{self.pipeline}"


@DATASET.register_module()
class GenerationUnpairedDataset(Dataset):
    """Unpaired two-domain dataset built from two separate image folders."""

    def __init__(self, root_a, root_b, random_pair, pipeline):
        self.A = SingleFolderDataset(root_a, pipeline)
        self.B = SingleFolderDataset(root_b, pipeline)
        self.random_pair = random_pair

    def __getitem__(self, idx):
        a_idx = idx % len(self.A)
        # Either pair deterministically by index or draw a random partner from domain B.
        b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
        return dict(a=self.A[a_idx], b=self.B[b_idx])

    def __len__(self):
        return max(len(self.A), len(self.B))

    def __repr__(self):
        return f"GenerationUnpairedDataset: A={self.A.root}, B={self.B.root}\nPipeline:\n{self.A.pipeline}"
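# Hedged usage sketch (illustrative, not part of the original module): wiring an ImprovedImageFolder
# into an EpisodicDataset for N-way / K-shot sampling. "data/train" and the empty pipeline are
# assumptions; a real pipeline (whatever `transform_pipeline` accepts) must read each image path and
# produce a tensor before episodes can be iterated, since EpisodicDataset stacks samples with torch.stack.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    pipeline = []  # placeholder config; replace with a tensor-producing pipeline in practice
    base = ImprovedImageFolder("data/train", pipeline)
    episodes = EpisodicDataset(base, num_class=5, num_query=15, num_support=5, num_episodes=100)
    loader = DataLoader(episodes, batch_size=1, num_workers=0)  # each item is already a full episode
    print(episodes)
    # With a tensor-producing pipeline, each batch from `loader` is a dict:
    #   batch["support"]: [1, 5 * 5, ...], batch["query"]: [1, 5 * 15, ...], batch["target"]: [1, 5 * 15]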