63 lines
2.1 KiB
Python
63 lines
2.1 KiB
Python
import os
|
|
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
|
|
|
|
from data.registry import DATASET
|
|
from data.transform import transform_pipeline
|
|
|
|
|
|
@DATASET.register_module()
class SingleFolderDataset(Dataset):
    """Dataset over every image file found beneath a single root directory.

    The directory tree is walked recursively (symlinks are followed) and each
    file whose name passes ``has_file_allowed_extension(..., IMG_EXTENSIONS)``
    is recorded. Directories and the file names inside them are visited in
    sorted order.

    Args:
        root: Directory to scan for image files.
        pipeline: Configuration handed to ``transform_pipeline``; the object
            it returns is called with a sample's file path in
            ``__getitem__`` (presumably it loads and transforms the image --
            confirm against ``data.transform``).
        with_path: When true, every item also carries the source file path
            under the ``"path"`` key.

    Raises:
        NotADirectoryError: If ``root`` is not an existing directory.
    """

    def __init__(self, root, pipeline, with_path=False):
        # Explicit check instead of `assert`: assertions are stripped under
        # `python -O`, and os.walk() on a missing path yields nothing, which
        # would silently produce an empty dataset.
        if not os.path.isdir(root):
            raise NotADirectoryError(f"dataset root is not a directory: {root!r}")
        self.root = root

        samples = []
        for dirpath, _, filenames in sorted(os.walk(self.root, followlinks=True)):
            for filename in sorted(filenames):
                path = os.path.join(dirpath, filename)
                if has_file_allowed_extension(path, IMG_EXTENSIONS):
                    samples.append(path)
        self.samples = samples

        self.pipeline = transform_pipeline(pipeline)
        self.with_path = with_path

    def __len__(self):
        """Return the number of image files that were collected."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{"img": pipeline(path)}``, plus ``"path"`` if ``with_path``."""
        path = self.samples[idx]
        output = dict(img=self.pipeline(path))
        if self.with_path:
            output["path"] = path
        return output

    def __repr__(self):
        return f"<SingleFolderDataset root={self.root} len={len(self)} with_path={self.with_path}>"
|
|
|
|
|
|
@DATASET.register_module()
class GenerationUnpairedDataset(Dataset):
    """Unpaired two-domain dataset backed by two image folders.

    Each item combines one sample from folder A with one sample from folder B.
    The dataset length is the larger of the two folder sizes; the smaller
    folder is cycled through by taking the index modulo its own length.

    Args:
        root_a: Root directory of domain A images.
        root_b: Root directory of domain B images.
        random_pair: When true, the B sample is drawn uniformly at random with
            ``torch.randint`` on every access instead of being derived from
            the index.
        pipeline: Transform configuration, forwarded unchanged to both
            underlying ``SingleFolderDataset`` instances.
        with_path: When true, items also contain ``"a_path"`` and ``"b_path"``.
    """

    def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False):
        self.A = SingleFolderDataset(root_a, pipeline, with_path)
        self.B = SingleFolderDataset(root_b, pipeline, with_path)
        self.with_path = with_path
        self.random_pair = random_pair

    def __getitem__(self, idx):
        """Return ``{"a": img_a, "b": img_b}`` (plus paths if ``with_path``).

        Raises:
            IndexError: If ``idx >= len(self)``.
            ValueError: If either underlying folder contains no images.
        """
        # Without an upper bound the modulo below wraps every index, so the
        # legacy iteration protocol (`for item in dataset`, `list(dataset)`)
        # would never see an IndexError and would loop forever.
        if idx >= len(self):
            raise IndexError(
                f"index {idx} is out of range for dataset of length {len(self)}"
            )

        len_a, len_b = len(self.A), len(self.B)
        if len_a == 0 or len_b == 0:
            # Report the misconfiguration clearly instead of surfacing a bare
            # ZeroDivisionError / torch.randint error from the lines below.
            raise ValueError(
                "cannot pair samples from an empty folder: "
                f"A has {len_a} image(s) in {self.A.root!r}, "
                f"B has {len_b} image(s) in {self.B.root!r}"
            )

        # Negative indices keep wrapping modulo each folder's own length.
        a_idx = idx % len_a
        if self.random_pair:
            b_idx = torch.randint(len_b, (1,)).item()
        else:
            b_idx = idx % len_b

        output_a = self.A[a_idx]
        output_b = self.B[b_idx]
        output = dict(a=output_a["img"], b=output_b["img"])
        if self.with_path:
            output["a_path"] = output_a["path"]
            output["b_path"] = output_b["path"]
        return output

    def __len__(self):
        """Return the size of the larger of the two folders."""
        return max(len(self.A), len(self.B))

    def __repr__(self):
        return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|