# Source: raycv/data/dataset.py (74 lines, 2.4 KiB, Python)

import os
import pickle
import torch
from torch.utils.data import Dataset
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
import lmdb
from .transform import transform_pipeline
from .registry import DATASET
class LMDBDataset(Dataset):
    """Map-style dataset that reads pickled samples from an LMDB environment.

    Database layout, as read by this class:

    * key ``b"__len__"`` holds the pickled number of samples;
    * key ``str(idx).encode()`` holds the pickled sample for index ``idx``.

    NOTE(security): values are decoded with ``pickle.loads``, which can execute
    arbitrary code while unpickling. Only open databases from trusted sources.
    """

    def __init__(self, lmdb_path, output_transform=None, map_size=2 ** 40, readonly=True, **lmdb_kwargs):
        """Open the environment and read the stored sample count.

        Args:
            lmdb_path: Path of the LMDB environment; ``subdir`` is derived from
                whether this path is a directory.
            output_transform: Optional callable applied to every unpickled
                sample before it is returned.
            map_size: Forwarded to ``lmdb.open``.
            readonly: Forwarded to ``lmdb.open``.
            **lmdb_kwargs: Extra keyword arguments forwarded to ``lmdb.open``.

        Raises:
            KeyError: If the database has no ``b"__len__"`` entry.
        """
        self.db = lmdb.open(
            lmdb_path,
            subdir=os.path.isdir(lmdb_path),
            map_size=map_size,
            readonly=readonly,
            **lmdb_kwargs,
        )
        self.output_transform = output_transform
        with self.db.begin(write=False) as txn:
            raw_len = txn.get(b"__len__")
        if raw_len is None:
            # A missing key comes back as None; fail with a clear message
            # instead of letting pickle raise an unrelated TypeError, and do
            # not leave the environment open behind a half-built object.
            self.db.close()
            raise KeyError(f"LMDB database {lmdb_path!r} has no b'__len__' entry")
        self._len = pickle.loads(raw_len)

    def __len__(self):
        """Return the sample count stored under ``b"__len__"``."""
        return self._len

    def __getitem__(self, idx):
        """Return the sample stored under key ``str(idx)``.

        Raises:
            IndexError: If no entry exists for ``idx``. Raising ``IndexError``
                (rather than failing inside ``pickle``) lets plain iteration
                such as ``for sample in dataset`` terminate normally.
        """
        with self.db.begin(write=False) as txn:
            raw = txn.get("{}".format(idx).encode())
        if raw is None:
            raise IndexError(f"no sample stored for index {idx!r}")
        sample = pickle.loads(raw)
        if self.output_transform is not None:
            sample = self.output_transform(sample)
        return sample
@DATASET.register_module()
class SingleFolderDataset(Dataset):
    """Dataset over the image files found beneath one root directory.

    The tree under ``root`` is walked recursively (following symlinks) and
    every file whose path passes ``has_file_allowed_extension`` for
    ``IMG_EXTENSIONS`` is kept. Directories and the file names inside each
    directory are visited in sorted order. Indexing hands the stored file
    path to the callable built by ``transform_pipeline(pipeline)``.
    """

    def __init__(self, root, pipeline):
        """Collect image paths under ``root`` and build the pipeline.

        Args:
            root: Existing directory to scan.
            pipeline: Configuration passed to ``transform_pipeline``.
        """
        assert os.path.isdir(root)
        self.root = root
        self.samples = list(self._image_paths(self.root))
        self.pipeline = transform_pipeline(pipeline)

    @staticmethod
    def _image_paths(root):
        """Yield image file paths under ``root`` in sorted walk order."""
        for directory, _, filenames in sorted(os.walk(root, followlinks=True)):
            for filename in sorted(filenames):
                candidate = os.path.join(directory, filename)
                if has_file_allowed_extension(candidate, IMG_EXTENSIONS):
                    yield candidate

    def __len__(self):
        """Return the number of collected image paths."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Run the pipeline on the ``idx``-th image path and return the result."""
        image_path = self.samples[idx]
        return self.pipeline(image_path)

    def __repr__(self):
        return f"<SingleFolderDataset root={self.root} len={len(self)}>"
@DATASET.register_module()
class GenerationUnpairedDataset(Dataset):
    """Dataset yielding one sample from each of two image folders.

    ``A`` and ``B`` are ``SingleFolderDataset`` instances built over
    ``root_a`` and ``root_b`` with the same pipeline configuration. The
    dataset length is the larger of the two; indices wrap around the
    shorter side via modulo.
    """

    def __init__(self, root_a, root_b, random_pair, pipeline):
        """Build both folder datasets.

        Args:
            root_a: Root directory for side ``A``.
            root_b: Root directory for side ``B``.
            random_pair: When truthy, the ``B`` sample is picked with
                ``torch.randint`` instead of being derived from the index.
            pipeline: Pipeline configuration shared by both sides.
        """
        self.A = SingleFolderDataset(root_a, pipeline)
        self.B = SingleFolderDataset(root_b, pipeline)
        self.random_pair = random_pair

    def __getitem__(self, idx):
        """Return ``{"a": sample_from_A, "b": sample_from_B}`` for ``idx``."""
        a_idx = idx % len(self.A)
        if self.random_pair:
            # Draw an index into B independently of idx.
            b_idx = torch.randint(len(self.B), (1,)).item()
        else:
            b_idx = idx % len(self.B)
        return {"a": self.A[a_idx], "b": self.B[b_idx]}

    def __len__(self):
        """Return the size of the larger of the two folders."""
        size_a, size_b = len(self.A), len(self.B)
        return max(size_a, size_b)

    def __repr__(self):
        return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"