from collections import defaultdict
from itertools import permutations, combinations
from pathlib import Path

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import functional as F

from data.registry import DATASET
from data.transform import transform_pipeline
from data.util import dlib_landmark

# NOTE: SingleFolderDataset is assumed to be defined elsewhere in this package;
# with with_path=True it yields (img, path) tuples.


def normalize_tensor(tensor):
    """Min-max scale a tensor into the [0, 1] range."""
    tensor = tensor.float()
    tensor -= tensor.min()
    # Clamp the divisor so a constant (all-equal) tensor does not divide by zero.
    tensor /= tensor.max().clamp(min=1e-8)
    return tensor


@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
    """Unpaired A/B dataset that attaches edge maps (and, for domain A with
    landmark edge types, facial-landmark channels) to each sample."""

    def __init__(self, root_a, root_b, random_pair, pipeline, edge_type,
                 edges_path, landmarks_path, size=(256, 256), with_path=False):
        assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
        self.edge_type = edge_type
        self.size = size
        self.edges_path = Path(edges_path)
        self.landmarks_path = Path(landmarks_path)
        assert self.edges_path.exists()
        assert self.landmarks_path.exists()
        self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
        self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
        self.random_pair = random_pair
        self.with_path = with_path

    def get_edge(self, origin_path):
        op = Path(origin_path)
        if self.edge_type.startswith("landmark_"):
            # Slice off the prefix; str.lstrip strips a character *set*, not a
            # prefix, and can silently mangle edge-type names.
            edge_type = self.edge_type[len("landmark_"):]
            use_landmark = op.parent.name.endswith("A")
        else:
            edge_type = self.edge_type
            use_landmark = False
        edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
        origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size, Image.BILINEAR))
        if not use_landmark:
            return origin_edge
        else:
            landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt"
            key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)
            dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size)))
            part_labels = normalize_tensor(torch.from_numpy(part_labels))
            part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
            # edges = origin_edge * (part_labels.sum(0) == 0)  # remove edges within face
            # edges = part_edge + edges
            return torch.cat([origin_edge, part_edge, dist_tensor, part_labels])

    def __getitem__(self, idx):
        a_idx = idx % len(self.A)
        b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
        output = dict(a={}, b={})
        output["a"]["img"], output["a"]["path"] = self.A[a_idx]
        output["b"]["img"], output["b"]["path"] = self.B[b_idx]
        for p in "ab":
            output[p]["edge"] = self.get_edge(output[p]["path"])
            if not self.with_path:
                # Paths are only needed internally to look up the edge files;
                # drop them from the sample unless explicitly requested.
                del output[p]["path"]
        return output

    def __len__(self):
        return max(len(self.A), len(self.B))

    def __repr__(self):
        return f"\nPipeline:\n{self.A.pipeline}"


@DATASET.register_module()
class PoseFacesWithSingleAnime(Dataset):
    """Pairs a tuple of face images (with landmark-derived pose maps) from one
    person with a randomly drawn anime image."""

    def __init__(self, root_face, root_anime, landmark_path, num_face,
                 face_pipeline, anime_pipeline, img_size, with_order=True):
        self.num_face = num_face
        self.landmark_path = Path(landmark_path)
        self.with_order = with_order
        self.root_face = Path(root_face)
        self.root_anime = Path(root_anime)
        self.img_size = img_size
        self.face_samples = self.iter_folders()
        self.face_pipeline = transform_pipeline(face_pipeline)
        self.B = SingleFolderDataset(root_anime, anime_pipeline, with_path=True)

    def iter_folders(self):
        # Group images by person: the first 7 characters of the stem are the person id.
        pics_per_person = defaultdict(list)
        for p in self.root_face.glob("*.jpg"):
            pics_per_person[p.stem[:7]].append(p.stem)
        data = []
        for p in pics_per_person:
            if len(pics_per_person[p]) >= self.num_face:
                if self.with_order:
                    # Keep each tuple in the images' original order.
                    data.extend(combinations(pics_per_person[p], self.num_face))
                else:
                    # Enumerate every ordering of each tuple.
                    data.extend(permutations(pics_per_person[p], self.num_face))
        return data

    def read_pose(self, pose_txt):
        key_points, part_labels, part_edge = dlib_landmark.read_keypoints(pose_txt, size=self.img_size)
        dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.img_size)))
        part_labels = normalize_tensor(torch.from_numpy(part_labels))
        part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
        return torch.cat([part_labels, part_edge, dist_tensor])

    def __len__(self):
        return len(self.face_samples)

    def __getitem__(self, idx):
        output = dict()
        output["anime_img"], output["anime_path"] = self.B[torch.randint(len(self.B), (1,)).item()]
        for i, f in enumerate(self.face_samples[idx]):
            output[f"face_{i}"] = self.face_pipeline(self.root_face / f"{f}.jpg")
            output[f"pose_{i}"] = self.read_pose(self.landmark_path / self.root_face.name / f"{f}.txt")
            output[f"stem_{i}"] = f
        return output
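

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the module API).
# Every path below is a hypothetical placeholder, and `pipeline` is an assumed
# example of whatever config format SingleFolderDataset / transform_pipeline
# accept in this codebase; real values come from the training configs that
# register these datasets.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    pipeline = [dict(type="Resize", size=(256, 256)), dict(type="ToTensor")]  # assumed format

    unpaired = GenerationUnpairedDatasetWithEdge(
        root_a="data/trainA",             # placeholder: photo domain
        root_b="data/trainB",             # placeholder: anime domain
        random_pair=True,
        pipeline=pipeline,
        edge_type="landmark_hed",
        edges_path="data/edges",          # expects <folder>/<stem>.hed.png
        landmarks_path="data/landmarks",  # expects <folder>/<stem>.txt for "*A" folders
    )
    sample = unpaired[0]
    print(sample["a"]["img"].shape, sample["a"]["edge"].shape)

    posed = PoseFacesWithSingleAnime(
        root_face="data/faces",           # placeholder: <person-id>*.jpg files
        root_anime="data/trainB",
        landmark_path="data/landmarks",   # expects <root_face.name>/<stem>.txt
        num_face=2,
        face_pipeline=pipeline,
        anime_pipeline=pipeline,
        img_size=(256, 256),
    )
    loader = DataLoader(posed, batch_size=4, shuffle=True)
    print(next(iter(loader)).keys())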