raycv/data/dataset/lmdb.py
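
"""LMDB-backed dataset: `LMDBDataset.lmdbify` freezes a (partially
transformed) dataset into an LMDB database; `LMDBDataset` serves it back,
running only the transforms that were not baked in at build time."""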

import os
import pickle

import lmdb
from torch.utils.data import Dataset
from tqdm import tqdm

from data.transform import transform_pipeline


def default_transform_way(transform, sample):
    # Transform the first field (the image); pass the rest (e.g. label) through.
    return [transform(sample[0]), *sample[1:]]


class LMDBDataset(Dataset):
    def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way,
                 map_size=2 ** 40, readonly=True, **lmdb_kwargs):
        self.path = lmdb_path
        self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size,
                            readonly=readonly, lock=False, **lmdb_kwargs)
        with self.db.begin(write=False) as txn:
            # Metadata written by `lmdbify` lives under reserved "$$...$$" keys.
            self._len = pickle.loads(txn.get(b"$$len$$"))
            self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
            if pipeline is None:
                self.not_done_pipeline = []
            else:
                # Only the transforms not already baked into the stored
                # samples still need to run at load time.
                self.not_done_pipeline = self._remain_pipeline(pipeline)
            self.transform = transform_pipeline(self.not_done_pipeline)
            self.transform_way = transform_way
            # Restore extra attributes (e.g. class names) that were saved
            # alongside the samples.
            essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
            for ea in essential_attr:
                setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8"))))

    def _remain_pipeline(self, pipeline):
        # The requested pipeline must start with exactly the transforms that
        # were applied when the database was built; return the remainder.
        done = self.done_pipeline
        if len(pipeline) < len(done) or any(pipeline[i] != dp for i, dp in enumerate(done)):
            raise ValueError(
                f"pipeline {done} saved in this lmdb database does not match pipeline: {pipeline}")
        return pipeline[len(done):]

    def __repr__(self):
        return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        with self.db.begin(write=False) as txn:
            sample = pickle.loads(txn.get("{}".format(idx).encode()))
        sample = self.transform_way(self.transform, sample)
        return sample

    @staticmethod
    def lmdbify(dataset, done_pipeline, lmdb_path):
        # Serialize every sample of `dataset` into an LMDB database keyed by
        # its integer index, together with the metadata `__init__` expects:
        # the length, the pipeline already applied, and any attributes the
        # dataset lists in `essential_attr`.
        env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
        with env.begin(write=True) as txn:
            for i in tqdm(range(len(dataset)), ncols=0):
                txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
            txn.put(b"$$len$$", pickle.dumps(len(dataset)))
            txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
            essential_attr = getattr(dataset, "essential_attr", list())
            txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
            for ea in essential_attr:
                txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))
        # Release the environment so the database can be reopened cleanly.
        env.close()