raycv/data/dataset.py

import os
import pickle
from collections import defaultdict
from pathlib import Path
import lmdb
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
from torchvision.transforms import functional as F
from tqdm import tqdm
from .registry import DATASET
from .transform import transform_pipeline
from .util import dlib_landmark


def default_transform_way(transform, sample):
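    """Apply ``transform`` to the first element of ``sample`` and pass the rest through unchanged."""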
return [transform(sample[0]), *sample[1:]]


class LMDBDataset(Dataset):
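    """Dataset backed by an LMDB database written by ``LMDBDataset.lmdbify``.

    The database stores pickled samples plus bookkeeping keys (``$$len$$``,
    ``$$done_pipeline$$`` and ``$$essential_attr$$``). ``pipeline`` may repeat the
    steps already applied when the database was built; only the remaining steps
    are applied at read time through ``transform_way``.
    """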
def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True,
**lmdb_kwargs):
self.path = lmdb_path
self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
lock=False, **lmdb_kwargs)
with self.db.begin(write=False) as txn:
self._len = pickle.loads(txn.get(b"$$len$$"))
self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
if pipeline is None:
self.not_done_pipeline = []
else:
self.not_done_pipeline = self._remain_pipeline(pipeline)
self.transform = transform_pipeline(self.not_done_pipeline)
self.transform_way = transform_way
essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
for ea in essential_attr:
setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8"))))
    def _remain_pipeline(self, pipeline):
        for i, dp in enumerate(self.done_pipeline):
            if i >= len(pipeline) or pipeline[i] != dp:
                raise ValueError(
                    f"pipeline {self.done_pipeline} saved in this lmdb database does not match pipeline: {pipeline}")
        return pipeline[len(self.done_pipeline):]
def __repr__(self):
return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"
def __len__(self):
return self._len
def __getitem__(self, idx):
with self.db.begin(write=False) as txn:
sample = pickle.loads(txn.get("{}".format(idx).encode()))
sample = self.transform_way(self.transform, sample)
return sample
@staticmethod
def lmdbify(dataset, done_pipeline, lmdb_path):
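        """Serialize every sample of ``dataset`` into an LMDB database at ``lmdb_path``.

        ``done_pipeline`` records the transforms already applied to ``dataset``, so a
        later ``LMDBDataset`` can skip them and run only the remaining steps. A rough
        round-trip sketch (names and pipeline split are illustrative, not part of this API):

            LMDBDataset.lmdbify(folder_ds, done_pipeline=pipeline[:2], lmdb_path="train.lmdb")
            ds = LMDBDataset("train.lmdb", pipeline=pipeline)
        """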
env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
with env.begin(write=True) as txn:
for i in tqdm(range(len(dataset)), ncols=0):
txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
txn.put(b"$$len$$", pickle.dumps(len(dataset)))
txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
essential_attr = getattr(dataset, "essential_attr", list())
txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
for ea in essential_attr:
txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))


@DATASET.register_module()
class ImprovedImageFolder(ImageFolder):
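    """``ImageFolder`` variant whose loader is the identity, so the pipeline receives raw
    paths and is expected to load the images itself. It also exposes ``classes_list``, a
    mapping from class index to the sample indices belonging to that class, and marks it
    as an essential attribute so ``LMDBDataset.lmdbify`` preserves it.
    """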
def __init__(self, root, pipeline):
super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
self.classes_list = defaultdict(list)
self.essential_attr = ["classes_list"]
for i in range(len(self)):
self.classes_list[self.samples[i][-1]].append(i)
assert len(self.classes_list) == len(self.classes)


class EpisodicDataset(Dataset):
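    """Few-shot episodic wrapper around a dataset that exposes ``classes_list``.

    Each item is one episode: a dict with ``support`` (num_class * num_support stacked
    images), ``query`` (num_class * num_query stacked images) and ``target`` (the
    episode-local class tag of each query sample). For example (illustrative values),
    ``EpisodicDataset(ImprovedImageFolder(root, pipeline), 5, 15, 5, 100)`` yields 100
    5-way episodes with 5 support and 15 query images per class.
    """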
def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
self.origin = origin_dataset
self.num_class = num_class
assert self.num_class < len(self.origin.classes_list)
        self.num_query = num_query  # Q: query samples per class
        self.num_support = num_support  # K: support samples (shots) per class
self.num_episodes = num_episodes
def _fetch_list_data(self, id_list):
return [self.origin[i][0] for i in id_list]
def __len__(self):
return self.num_episodes
def __getitem__(self, _):
random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
support_set = []
query_set = []
target_set = []
for tag, c in enumerate(random_classes):
image_list = self.origin.classes_list[c]
            if len(image_list) >= self.num_query + self.num_support:
                # enough images in this class: sample indices without replacement
                idx_list = torch.randperm(len(image_list))[:self.num_query + self.num_support].tolist()
            else:
                # too few images: sample indices with replacement
                idx_list = torch.randint(high=len(image_list), size=(self.num_query + self.num_support,)).tolist()
support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
support_set.extend(support)
query_set.extend(query)
target_set.extend([tag] * self.num_query)
return {
"support": torch.stack(support_set),
"query": torch.stack(query_set),
"target": torch.tensor(target_set)
}
def __repr__(self):
return f"<EpisodicDataset NE{self.num_episodes} NC{self.num_class} NS{self.num_support} NQ{self.num_query}>"


@DATASET.register_module()
class SingleFolderDataset(Dataset):
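    """Flat dataset over every image file found (recursively) under ``root``. The pipeline
    receives the file path and is expected to load the image; with ``with_path=True`` each
    item is a ``(sample, path)`` tuple.
    """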
def __init__(self, root, pipeline, with_path=False):
assert os.path.isdir(root)
self.root = root
samples = []
for r, _, fns in sorted(os.walk(self.root, followlinks=True)):
for fn in sorted(fns):
path = os.path.join(r, fn)
if has_file_allowed_extension(path, IMG_EXTENSIONS):
samples.append(path)
self.samples = samples
self.pipeline = transform_pipeline(pipeline)
self.with_path = with_path
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
if not self.with_path:
return self.pipeline(self.samples[idx])
else:
return self.pipeline(self.samples[idx]), self.samples[idx]
def __repr__(self):
return f"<SingleFolderDataset root={self.root} len={len(self)}>"


@DATASET.register_module()
class GenerationUnpairedDataset(Dataset):
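    """Unpaired two-domain dataset: each item pairs one sample from folder A with one from
    folder B, either index-aligned (modulo the folder lengths) or, with ``random_pair``,
    with a randomly chosen B sample.
    """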
def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False):
self.A = SingleFolderDataset(root_a, pipeline, with_path)
self.B = SingleFolderDataset(root_b, pipeline, with_path)
self.random_pair = random_pair
def __getitem__(self, idx):
a_idx = idx % len(self.A)
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
return dict(a=self.A[a_idx], b=self.B[b_idx])
def __len__(self):
return max(len(self.A), len(self.B))
def __repr__(self):
return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"


def normalize_tensor(tensor):
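    """Shift and scale ``tensor`` so its values span [0, 1] (assumes the tensor is not constant)."""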
tensor = tensor.float()
tensor -= tensor.min()
tensor /= tensor.max()
return tensor


@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
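    """``GenerationUnpairedDataset`` variant that additionally returns ``edge_a``, an edge
    representation of the A-side image built from precomputed HED/Canny edge maps and, for
    the ``landmark_*`` edge types, dlib facial-landmark channels (see ``get_edge``).
    """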
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256),
with_path=False):
assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
self.edge_type = edge_type
self.size = size
self.edges_path = Path(edges_path)
self.landmarks_path = Path(landmarks_path)
assert self.edges_path.exists()
assert self.landmarks_path.exists()
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
self.random_pair = random_pair
self.with_path = with_path
def get_edge(self, origin_path):
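        """Load the precomputed edge map ``edges_path/<class dir>/<stem>.<edge_type>.png`` for
        ``origin_path``. For ``landmark_*`` edge types, concatenate extra channels derived from
        the matching landmark file: the part-edge map, a normalized distance field and
        normalized part labels.
        """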
op = Path(origin_path)
if self.edge_type.startswith("landmark_"):
edge_type = self.edge_type.lstrip("landmark_")
use_landmark = True
else:
edge_type = self.edge_type
use_landmark = False
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size, Image.BILINEAR))
if not use_landmark:
return origin_edge
else:
landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt"
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)
dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size)))
part_labels = normalize_tensor(torch.from_numpy(part_labels))
part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
# edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face
# edges = part_edge + edges
return torch.cat([origin_edge, part_edge, dist_tensor, part_labels])
def __getitem__(self, idx):
a_idx = idx % len(self.A)
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
        if self.with_path:
            # A and B were built with with_path=True, so each sample is an (image, path) tuple;
            # in this branch "edge_a" carries the path of a rather than a precomputed edge map
            output = {"a": self.A[a_idx], "b": self.B[b_idx]}
            output["edge_a"] = output["a"][1]
            return output
output = dict()
output["a"], path_a = self.A[a_idx]
output["b"], path_b = self.B[b_idx]
output["edge_a"] = self.get_edge(path_a)
return output
def __len__(self):
return max(len(self.A), len(self.B))
def __repr__(self):
return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"