import os

import torch
from torch.utils.data import Dataset
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS

from data.registry import DATASET
from data.transform import transform_pipeline


@DATASET.register_module()
class SingleFolderDataset(Dataset):
    """Dataset over every image file found recursively under one folder.

    Files are collected once at construction time, in a deterministic order
    (directories sorted, then file names sorted within each directory), so a
    given index always maps to the same file for an unchanged folder.
    """

    def __init__(self, root, pipeline, with_path=False):
        """Scan ``root`` for images and build the transform pipeline.

        Args:
            root: Directory to scan. Symbolic links to directories are
                followed.
            pipeline: Configuration handed to ``transform_pipeline``; the
                resulting callable is applied to each sample's file path.
                NOTE(review): presumably the pipeline is responsible for
                loading the image from that path -- confirm against
                ``data.transform``.
            with_path: When true, each item also carries the file path under
                the ``"path"`` key.

        Raises:
            NotADirectoryError: If ``root`` is not an existing directory.
        """
        # Raise explicitly instead of using ``assert``: assertions are removed
        # under ``python -O``, which would let a bad root silently produce an
        # empty dataset.
        if not os.path.isdir(root):
            raise NotADirectoryError(f"dataset root is not a directory: {root!r}")
        self.root = root

        samples = []
        # Sorting the walk output (and the file names) makes the sample order
        # independent of the order in which the OS lists directory entries.
        for dirpath, _, filenames in sorted(os.walk(self.root, followlinks=True)):
            for filename in sorted(filenames):
                path = os.path.join(dirpath, filename)
                if has_file_allowed_extension(path, IMG_EXTENSIONS):
                    samples.append(path)
        self.samples = samples

        self.pipeline = transform_pipeline(pipeline)
        self.with_path = with_path

    def __len__(self):
        """Return the number of image files found under ``root``."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{"img": ...}`` for sample ``idx``, plus ``"path"`` if enabled."""
        path = self.samples[idx]
        output = dict(img=self.pipeline(path))
        if self.with_path:
            output["path"] = path
        return output

    def __repr__(self):
        """Return a short summary: class name, root folder and sample count."""
        return f"<{type(self).__name__} root={self.root} len={len(self)}>"


@DATASET.register_module()
class GenerationUnpairedDataset(Dataset):
    """Unpaired two-domain dataset built from two image folders, A and B.

    Each item combines one image from folder A with one image from folder B.
    The dataset length is the size of the larger folder; indices wrap around
    the smaller one.

    Both folders are expected to contain at least one image: the index
    arithmetic in ``__getitem__`` takes ``idx`` modulo the folder size, which
    fails for an empty folder.
    """

    def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False):
        """Build one ``SingleFolderDataset`` per domain.

        Args:
            root_a: Directory holding the domain-A images.
            root_b: Directory holding the domain-B images.
            random_pair: When true, the B image of every item is drawn at
                random with ``torch.randint``; otherwise it is ``idx`` modulo
                the size of B.
            pipeline: Pipeline configuration shared by both folders.
            with_path: When true, items also carry ``"a_path"`` and
                ``"b_path"``.
        """
        self.A = SingleFolderDataset(root_a, pipeline, with_path)
        self.B = SingleFolderDataset(root_b, pipeline, with_path)
        self.with_path = with_path
        self.random_pair = random_pair

    def __getitem__(self, idx):
        """Return ``{"a": ..., "b": ...}`` and, if enabled, both file paths."""
        # Wrap around so the shorter folder is reused when idx exceeds its size.
        a_idx = idx % len(self.A)
        if self.random_pair:
            b_idx = torch.randint(len(self.B), (1,)).item()
        else:
            b_idx = idx % len(self.B)

        output_a = self.A[a_idx]
        output_b = self.B[b_idx]
        output = dict(a=output_a["img"], b=output_b["img"])
        if self.with_path:
            output["a_path"] = output_a["path"]
            output["b_path"] = output_b["path"]
        return output

    def __len__(self):
        """Return the size of the larger of the two folders."""
        return max(len(self.A), len(self.B))

    def __repr__(self):
        """Return a summary of both folders followed by the shared pipeline."""
        return (
            f"<{type(self).__name__}:\n\tA: {self.A}\n\tB: {self.B}>"
            f"\nPipeline:\n{self.A.pipeline}"
        )