import os
import pickle
from collections import defaultdict
from itertools import permutations, combinations
from pathlib import Path

import lmdb
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
from torchvision.transforms import functional as F
from tqdm import tqdm

from .registry import DATASET
from .transform import transform_pipeline
from .util import dlib_landmark


def default_transform_way(transform, sample):
    # apply the transform to the data item only, pass the rest (e.g. label) through
    return [transform(sample[0]), *sample[1:]]


class LMDBDataset(Dataset):
    def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way,
                 map_size=2 ** 40, readonly=True, **lmdb_kwargs):
        self.path = lmdb_path
        self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size,
                            readonly=readonly, lock=False, **lmdb_kwargs)
        with self.db.begin(write=False) as txn:
            self._len = pickle.loads(txn.get(b"$$len$$"))
            self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
            if pipeline is None:
                self.not_done_pipeline = []
            else:
                self.not_done_pipeline = self._remain_pipeline(pipeline)
            self.transform = transform_pipeline(self.not_done_pipeline)
            self.transform_way = transform_way
            essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
            for ea in essential_attr:
                setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8"))))

    def _remain_pipeline(self, pipeline):
        # the requested pipeline must start with the transforms already baked into the database;
        # only the remaining suffix is applied at read time
        for i, dp in enumerate(self.done_pipeline):
            if pipeline[i] != dp:
                raise ValueError(
                    f"pipeline {self.done_pipeline} saved in this lmdb database does not match pipeline: {pipeline}")
        return pipeline[len(self.done_pipeline):]

    def __repr__(self):
        return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        with self.db.begin(write=False) as txn:
            sample = pickle.loads(txn.get("{}".format(idx).encode()))
        sample = self.transform_way(self.transform, sample)
        return sample

    @staticmethod
    def lmdbify(dataset, done_pipeline, lmdb_path):
        env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
        with env.begin(write=True) as txn:
            for i in tqdm(range(len(dataset)), ncols=0):
                txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
            txn.put(b"$$len$$", pickle.dumps(len(dataset)))
            txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
            essential_attr = getattr(dataset, "essential_attr", list())
            txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
            for ea in essential_attr:
                txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))


@DATASET.register_module()
class ImprovedImageFolder(ImageFolder):
    def __init__(self, root, pipeline):
        # loader is the identity: the pipeline itself is expected to load from the path
        super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
        self.classes_list = defaultdict(list)  # class index -> list of sample indices
        self.essential_attr = ["classes_list"]
        for i in range(len(self)):
            self.classes_list[self.samples[i][-1]].append(i)
        assert len(self.classes_list) == len(self.classes)
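
# Hedged usage sketch: the paths and pipeline specs below are hypothetical
# placeholders (the real spec format is defined by `.transform.transform_pipeline`).
# It shows the intended round trip: decode once with `lmdbify`, then serve the
# cached samples from LMDB, where `_remain_pipeline` skips the prefix of the
# pipeline that matches the stored `done_pipeline`.
def _lmdb_roundtrip_sketch():
    decode_pipeline = [dict(type="LoadImage"), dict(type="Resize", size=(256, 256))]  # hypothetical
    aug_pipeline = [dict(type="RandomHorizontalFlip")]  # hypothetical
    folder = ImprovedImageFolder("/data/faces", pipeline=decode_pipeline)  # hypothetical path
    LMDBDataset.lmdbify(folder, done_pipeline=decode_pipeline, lmdb_path="/data/faces.lmdb")
    # pass the full pipeline back in; the decode prefix is skipped because it
    # matches the `done_pipeline` stored in the database
    return LMDBDataset("/data/faces.lmdb", pipeline=decode_pipeline + aug_pipeline)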
class EpisodicDataset(Dataset):
    def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
        self.origin = origin_dataset
        self.num_class = num_class  # N: classes per episode (N-way)
        assert self.num_class < len(self.origin.classes_list)
        self.num_query = num_query  # Q: query images per class
        self.num_support = num_support  # K: support images per class (K-shot)
        self.num_episodes = num_episodes

    def _fetch_list_data(self, id_list):
        return [self.origin[i][0] for i in id_list]

    def __len__(self):
        return self.num_episodes

    def __getitem__(self, _):
        random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
        support_set = []
        query_set = []
        target_set = []
        for tag, c in enumerate(random_classes):
            image_list = self.origin.classes_list[c]
            if len(image_list) >= self.num_query + self.num_support:
                # enough images belong to this class: sample without replacement
                idx_list = torch.randperm(len(image_list))[:self.num_query + self.num_support].tolist()
            else:
                # not enough images: sample with replacement
                idx_list = torch.randint(high=len(image_list), size=(self.num_query + self.num_support,)).tolist()
            support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
            query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
            support_set.extend(support)
            query_set.extend(query)
            target_set.extend([tag] * self.num_query)
        return {
            "support": torch.stack(support_set),
            "query": torch.stack(query_set),
            "target": torch.tensor(target_set),
        }

    def __repr__(self):
        return f"<EpisodicDataset: {self.num_class}-way {self.num_support}-shot, {self.num_episodes} episodes>"


@DATASET.register_module()
class SingleFolderDataset(Dataset):
    def __init__(self, root, pipeline, with_path=False):
        assert os.path.isdir(root)
        self.root = root
        samples = []
        for r, _, fns in sorted(os.walk(self.root, followlinks=True)):
            for fn in sorted(fns):
                path = os.path.join(r, fn)
                if has_file_allowed_extension(path, IMG_EXTENSIONS):
                    samples.append(path)
        self.samples = samples
        self.pipeline = transform_pipeline(pipeline)
        self.with_path = with_path

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if not self.with_path:
            return self.pipeline(self.samples[idx])
        return self.pipeline(self.samples[idx]), self.samples[idx]

    def __repr__(self):
        return f"<SingleFolderDataset: {self.root}, length: {len(self)}>"


@DATASET.register_module()
class GenerationUnpairedDataset(Dataset):
    def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False):
        self.A = SingleFolderDataset(root_a, pipeline, with_path)
        self.B = SingleFolderDataset(root_b, pipeline, with_path)
        self.random_pair = random_pair

    def __getitem__(self, idx):
        a_idx = idx % len(self.A)
        b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
        return dict(a=self.A[a_idx], b=self.B[b_idx])

    def __len__(self):
        return max(len(self.A), len(self.B))

    def __repr__(self):
        return f"<GenerationUnpairedDataset>\nPipeline:\n{self.A.pipeline}"


def normalize_tensor(tensor):
    # min-max normalize to [0, 1]
    tensor = tensor.float()
    tensor -= tensor.min()
    tensor /= tensor.max()
    return tensor
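
# Hedged usage sketch: the roots and pipeline spec are hypothetical. With
# random_pair=True every A image is paired with a uniformly sampled B image,
# the usual setting for unpaired image-to-image training.
def _unpaired_dataset_sketch():
    return GenerationUnpairedDataset(
        root_a="/data/photos", root_b="/data/anime",  # hypothetical roots
        random_pair=True,
        pipeline=[dict(type="LoadImage"), dict(type="Resize", size=(256, 256))],  # hypothetical
    )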
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
    def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path,
                 landmarks_path, size=(256, 256), with_path=False):
        assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
        self.edge_type = edge_type
        self.size = size
        self.edges_path = Path(edges_path)
        self.landmarks_path = Path(landmarks_path)
        assert self.edges_path.exists()
        assert self.landmarks_path.exists()
        self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
        self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
        self.random_pair = random_pair
        self.with_path = with_path

    def get_edge(self, origin_path):
        op = Path(origin_path)
        if self.edge_type.startswith("landmark_"):
            # slice off the prefix (str.lstrip would strip characters, not the prefix)
            edge_type = self.edge_type[len("landmark_"):]
            use_landmark = op.parent.name.endswith("A")
        else:
            edge_type = self.edge_type
            use_landmark = False
        edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
        origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size, Image.BILINEAR))
        if not use_landmark:
            return origin_edge
        landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt"
        key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)
        dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size)))
        part_labels = normalize_tensor(torch.from_numpy(part_labels))
        part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
        # edges = origin_edge * (part_labels.sum(0) == 0)  # remove edges within face
        # edges = part_edge + edges
        return torch.cat([origin_edge, part_edge, dist_tensor, part_labels])

    def __getitem__(self, idx):
        a_idx = idx % len(self.A)
        b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
        output = dict(a={}, b={})
        output["a"]["img"], output["a"]["path"] = self.A[a_idx]
        output["b"]["img"], output["b"]["path"] = self.B[b_idx]
        for p in "ab":
            output[p]["edge"] = self.get_edge(output[p]["path"])
        return output

    def __len__(self):
        return max(len(self.A), len(self.B))

    def __repr__(self):
        return f"<GenerationUnpairedDatasetWithEdge: {self.edge_type}>\nPipeline:\n{self.A.pipeline}"


@DATASET.register_module()
class PoseFacesWithSingleAnime(Dataset):
    def __init__(self, root_face, root_anime, landmark_path, num_face, face_pipeline,
                 anime_pipeline, img_size, with_order=True):
        self.num_face = num_face
        self.landmark_path = Path(landmark_path)
        self.with_order = with_order
        self.root_face = Path(root_face)
        self.root_anime = Path(root_anime)
        self.img_size = img_size
        self.face_samples = self.iter_folders()
        self.face_pipeline = transform_pipeline(face_pipeline)
        self.B = SingleFolderDataset(root_anime, anime_pipeline, with_path=True)

    def iter_folders(self):
        # group pictures by person id (first 7 characters of the file stem)
        pics_per_person = defaultdict(list)
        for p in self.root_face.glob("*.jpg"):
            pics_per_person[p.stem[:7]].append(p.stem)
        data = []
        for p in pics_per_person:
            if len(pics_per_person[p]) >= self.num_face:
                if self.with_order:
                    # order matters: enumerate every ordering of num_face pictures
                    data.extend(list(permutations(pics_per_person[p], self.num_face)))
                else:
                    data.extend(list(combinations(pics_per_person[p], self.num_face)))
        return data

    def read_pose(self, pose_txt):
        key_points, part_labels, part_edge = dlib_landmark.read_keypoints(pose_txt, size=self.img_size)
        dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.img_size)))
        part_labels = normalize_tensor(torch.from_numpy(part_labels))
        part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
        return torch.cat([part_labels, part_edge, dist_tensor])

    def __len__(self):
        return len(self.face_samples)

    def __getitem__(self, idx):
        output = dict()
        output["anime_img"], output["anime_path"] = self.B[torch.randint(len(self.B), (1,)).item()]
        for i, f in enumerate(self.face_samples[idx]):
            output[f"face_{i}"] = self.face_pipeline(self.root_face / f"{f}.jpg")
            output[f"pose_{i}"] = self.read_pose(self.landmark_path / self.root_face.name / f"{f}.txt")
            output[f"stem_{i}"] = f
        return output
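
# Hedged usage sketch: every path and pipeline spec below is a hypothetical
# placeholder. It shows the expected output layout: with a "landmark_*" edge
# type, "edge" for domain-A images stacks the raw edge map with the landmark
# part edges, distance transform, and part labels along the channel axis.
def _edge_dataset_sketch():
    dataset = GenerationUnpairedDatasetWithEdge(
        root_a="/data/trainA", root_b="/data/trainB",  # hypothetical roots
        random_pair=True,
        pipeline=[dict(type="LoadImage"), dict(type="Resize", size=(256, 256))],  # hypothetical
        edge_type="landmark_hed",
        edges_path="/data/edges", landmarks_path="/data/landmarks",  # hypothetical
    )
    sample = dataset[0]
    return sample["a"]["img"], sample["a"]["edge"]  # extra landmark channels only for domain A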