import os
import pickle

import lmdb
import torch
from torch.utils.data import Dataset
from torchvision.datasets.folder import IMG_EXTENSIONS, has_file_allowed_extension

from .registry import DATASET
from .transform import transform_pipeline


class LMDBDataset(Dataset):
    """Dataset whose samples are pickled values stored in an LMDB environment.

    The database must hold the pickled sample count under the key
    ``b"__len__"`` and each pickled sample under the encoded string form of
    its index (``b"0"``, ``b"1"``, ...).

    NOTE(review): values are deserialized with ``pickle.loads``, which can
    execute arbitrary code; only open databases from trusted sources.
    NOTE(review): the environment is opened in ``__init__``; if this dataset
    is used with forking DataLoader workers, confirm that sharing the handle
    across ``fork`` is acceptable for the lmdb version in use.
    """

    def __init__(self, lmdb_path, output_transform=None, map_size=2 ** 40, readonly=True, **lmdb_kwargs):
        """Open the LMDB environment and read the stored sample count.

        Args:
            lmdb_path: Path to the LMDB environment. A directory is opened with
                ``subdir=True``; any other path is opened as a single file.
            output_transform: Optional callable applied to each loaded sample.
            map_size: Maximum memory-map size forwarded to ``lmdb.open``.
            readonly: Whether to open the environment read-only.
            **lmdb_kwargs: Extra keyword arguments forwarded to ``lmdb.open``.

        Raises:
            KeyError: if the database has no ``b"__len__"`` entry.
        """
        self.db = lmdb.open(
            lmdb_path,
            subdir=os.path.isdir(lmdb_path),
            map_size=map_size,
            readonly=readonly,
            **lmdb_kwargs,
        )
        self.output_transform = output_transform
        with self.db.begin(write=False) as txn:
            raw_len = txn.get(b"__len__")
        if raw_len is None:
            # Release the environment instead of leaking it on a bad database.
            self.db.close()
            raise KeyError(f"LMDB database at {lmdb_path!r} has no b'__len__' entry")
        self._len = pickle.loads(raw_len)

    def __len__(self):
        """Return the sample count stored under ``b"__len__"``."""
        return self._len

    def __getitem__(self, idx):
        """Load sample ``idx``, applying ``output_transform`` when set.

        Raises:
            IndexError: if no entry exists for ``idx``. Raising IndexError
                (rather than failing inside ``pickle.loads(None)``) lets the
                sequence-iteration protocol terminate cleanly.
        """
        with self.db.begin(write=False) as txn:
            raw = txn.get("{}".format(idx).encode())
        if raw is None:
            raise IndexError(f"index {idx} not found in LMDB dataset of length {self._len}")
        sample = pickle.loads(raw)
        if self.output_transform is not None:
            sample = self.output_transform(sample)
        return sample


@DATASET.register_module()
class SingleFolderDataset(Dataset):
    """Dataset over every image file found recursively under one directory.

    Files are kept when ``has_file_allowed_extension`` accepts them against
    torchvision's ``IMG_EXTENSIONS``. Paths are collected in sorted order, so
    indices are deterministic for a fixed directory tree.
    """

    def __init__(self, root, pipeline):
        """Collect image paths under ``root`` and build the transform pipeline.

        Args:
            root: Directory to scan; symlinked subdirectories are followed.
            pipeline: Pipeline config passed to ``transform_pipeline``.
                Presumably it turns a file path into a loaded, transformed
                sample; confirm against ``transform_pipeline``.
        """
        assert os.path.isdir(root), f"{root!r} is not a directory"
        self.root = root
        samples = []
        for r, _, fns in sorted(os.walk(self.root, followlinks=True)):
            for fn in sorted(fns):
                path = os.path.join(r, fn)
                if has_file_allowed_extension(path, IMG_EXTENSIONS):
                    samples.append(path)
        self.samples = samples
        self.pipeline = transform_pipeline(pipeline)

    def __len__(self):
        """Return the number of image files found."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Apply the pipeline to the path of the ``idx``-th image."""
        return self.pipeline(self.samples[idx])

    def __repr__(self):
        return f"<SingleFolderDataset root={self.root} len={len(self)}>"


@DATASET.register_module()
class GenerationUnpairedDataset(Dataset):
    """Unpaired two-domain dataset yielding ``dict(a=..., b=...)`` samples.

    Both domains are ``SingleFolderDataset`` instances that share one pipeline
    config. The dataset length is that of the larger domain, and indices into
    the smaller domain wrap around with a modulo.
    """

    def __init__(self, root_a, root_b, random_pair, pipeline):
        """
        Args:
            root_a: Image directory for domain A.
            root_b: Image directory for domain B.
            random_pair: If true, the B sample is drawn uniformly at random
                with ``torch.randint``, so it follows torch's RNG state.
                Otherwise B is indexed with ``idx % len(B)``.
            pipeline: Pipeline config applied to both domains.
        """
        self.A = SingleFolderDataset(root_a, pipeline)
        self.B = SingleFolderDataset(root_b, pipeline)
        self.random_pair = random_pair

    def __getitem__(self, idx):
        """Return ``dict(a=A[idx % len(A)], b=B[...])`` for index ``idx``."""
        a_idx = idx % len(self.A)
        if self.random_pair:
            b_idx = torch.randint(len(self.B), (1,)).item()
        else:
            b_idx = idx % len(self.B)
        return dict(a=self.A[a_idx], b=self.B[b_idx])

    def __len__(self):
        """Return the size of the larger domain."""
        return max(len(self.A), len(self.B))

    def __repr__(self):
        return (
            f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>"
            f"\nPipeline:\n{self.A.pipeline}"
        )