import os
import pickle

import lmdb
from torch.utils.data import Dataset
from tqdm import tqdm

from data.transform import transform_pipeline


def default_transform_way(transform, sample):
    """Apply `transform` to the first element of a sample; pass the rest through."""
    return [transform(sample[0]), *sample[1:]]


class LMDBDataset(Dataset):
    """Dataset backed by an LMDB database written by `LMDBDataset.lmdbify`.

    The database stores pickled samples under string-index keys, plus metadata
    keys: b"$$len$$", b"$$done_pipeline$$", b"$$essential_attr$$", and one
    b"$<name>$" key per essential attribute.
    """

    def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way,
                 map_size=2 ** 40, readonly=True, **lmdb_kwargs):
        self.path = lmdb_path
        # `lock=False` lets multiple readers (e.g. DataLoader workers) share the
        # environment; `subdir` is inferred from whether the path is a directory.
        self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path),
                            map_size=map_size, readonly=readonly, lock=False,
                            **lmdb_kwargs)
        with self.db.begin(write=False) as txn:
            self._len = pickle.loads(txn.get(b"$$len$$"))
            self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
            if pipeline is None:
                self.not_done_pipeline = []
            else:
                # Keep only the steps not already baked into the database.
                self.not_done_pipeline = self._remain_pipeline(pipeline)
            self.transform = transform_pipeline(self.not_done_pipeline)
            self.transform_way = transform_way
            # Restore attributes the original dataset flagged as essential.
            essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
            for ea in essential_attr:
                setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode("utf-8"))))

    def _remain_pipeline(self, pipeline):
        # The requested pipeline must start with the steps already applied when
        # the database was built; return the suffix that still has to run.
        for i, dp in enumerate(self.done_pipeline):
            if pipeline[i] != dp:
                raise ValueError(
                    f"pipeline {self.done_pipeline} saved in this LMDB database "
                    f"does not match pipeline: {pipeline}")
        return pipeline[len(self.done_pipeline):]

    def __repr__(self):
        return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        with self.db.begin(write=False) as txn:
            sample = pickle.loads(txn.get(str(idx).encode()))
        # Transform outside the read transaction so it is held as briefly as possible.
        return self.transform_way(self.transform, sample)

    @staticmethod
    def lmdbify(dataset, done_pipeline, lmdb_path):
        """Serialize `dataset` (with `done_pipeline` already applied) into LMDB."""
        env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
        with env.begin(write=True) as txn:
            for i in tqdm(range(len(dataset)), ncols=0):
                txn.put(str(i).encode(), pickle.dumps(dataset[i]))
            txn.put(b"$$len$$", pickle.dumps(len(dataset)))
            txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
            # Datasets may expose an `essential_attr` list naming attributes
            # (e.g. class names) that must survive serialization.
            essential_attr = getattr(dataset, "essential_attr", list())
            txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
            for ea in essential_attr:
                txn.put(f"${ea}$".encode("utf-8"), pickle.dumps(getattr(dataset, ea)))
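

# Usage sketch (hypothetical): `MySlowDataset`, the "decode"/"to_tensor" step
# names, and "cache.lmdb" are illustrative assumptions, not part of this module.
# It also assumes `data.transform.transform_pipeline` accepts the same step
# descriptions that were recorded as `done_pipeline` at lmdbify time.
if __name__ == "__main__":
    # 1) One-off conversion: `MySlowDataset` is assumed to yield (image, label)
    #    samples with the "decode" step already applied to each image.
    # LMDBDataset.lmdbify(MySlowDataset(), done_pipeline=["decode"],
    #                     lmdb_path="cache.lmdb")

    # 2) Fast reads afterwards: the saved prefix ["decode"] is verified against
    #    the requested pipeline, and only the remaining step ("to_tensor") runs
    #    per __getitem__, applied to sample[0] via default_transform_way.
    dataset = LMDBDataset("cache.lmdb", pipeline=["decode", "to_tensor"])
    print(dataset)  # path, length, and the remaining transform
    image, label = dataset[0]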