From 6ea13df465b508bbad7de077c35ce1720c40c343 Mon Sep 17 00:00:00 2001 From: Ray Wong Date: Sat, 10 Oct 2020 10:43:00 +0800 Subject: [PATCH] temp commit --- data/dataset.py | 287 ------------------------------ data/dataset/__init__.py | 3 + data/dataset/few-shot.py | 63 +++++++ data/dataset/image_translation.py | 62 +++++++ data/dataset/lmdb.py | 65 +++++++ data/dataset/pose_transfer.py | 122 +++++++++++++ data/transform.py | 2 +- model/GAN/__init__.py | 3 + model/__init__.py | 8 +- util/misc.py | 20 +++ 10 files changed, 340 insertions(+), 295 deletions(-) delete mode 100644 data/dataset.py create mode 100644 data/dataset/__init__.py create mode 100644 data/dataset/few-shot.py create mode 100644 data/dataset/image_translation.py create mode 100644 data/dataset/lmdb.py create mode 100644 data/dataset/pose_transfer.py diff --git a/data/dataset.py b/data/dataset.py deleted file mode 100644 index 9e53028..0000000 --- a/data/dataset.py +++ /dev/null @@ -1,287 +0,0 @@ -import os -import pickle -from collections import defaultdict -from itertools import permutations, combinations -from pathlib import Path - -import lmdb -import torch -from PIL import Image -from torch.utils.data import Dataset -from torchvision.datasets import ImageFolder -from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS -from torchvision.transforms import functional as F -from tqdm import tqdm - -from .registry import DATASET -from .transform import transform_pipeline -from .util import dlib_landmark - - -def default_transform_way(transform, sample): - return [transform(sample[0]), *sample[1:]] - - -class LMDBDataset(Dataset): - def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True, - **lmdb_kwargs): - self.path = lmdb_path - self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly, - lock=False, **lmdb_kwargs) - - with self.db.begin(write=False) as txn: - self._len = pickle.loads(txn.get(b"$$len$$")) - self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$")) - if pipeline is None: - self.not_done_pipeline = [] - else: - self.not_done_pipeline = self._remain_pipeline(pipeline) - self.transform = transform_pipeline(self.not_done_pipeline) - self.transform_way = transform_way - essential_attr = pickle.loads(txn.get(b"$$essential_attr$$")) - for ea in essential_attr: - setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8")))) - - def _remain_pipeline(self, pipeline): - for i, dp in enumerate(self.done_pipeline): - if pipeline[i] != dp: - raise ValueError( - f"pipeline {self.done_pipeline} saved in this lmdb database is not match with pipeline:{pipeline}") - return pipeline[len(self.done_pipeline):] - - def __repr__(self): - return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}" - - def __len__(self): - return self._len - - def __getitem__(self, idx): - with self.db.begin(write=False) as txn: - sample = pickle.loads(txn.get("{}".format(idx).encode())) - sample = self.transform_way(self.transform, sample) - return sample - - @staticmethod - def lmdbify(dataset, done_pipeline, lmdb_path): - env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path)) - with env.begin(write=True) as txn: - for i in tqdm(range(len(dataset)), ncols=0): - txn.put("{}".format(i).encode(), pickle.dumps(dataset[i])) - txn.put(b"$$len$$", pickle.dumps(len(dataset))) - txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline)) - essential_attr = getattr(dataset, "essential_attr", list()) - txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr)) - for ea in essential_attr: - txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea))) - - -@DATASET.register_module() -class ImprovedImageFolder(ImageFolder): - def __init__(self, root, pipeline): - super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x) - self.classes_list = defaultdict(list) - self.essential_attr = ["classes_list"] - for i in range(len(self)): - self.classes_list[self.samples[i][-1]].append(i) - assert len(self.classes_list) == len(self.classes) - - -class EpisodicDataset(Dataset): - def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes): - self.origin = origin_dataset - self.num_class = num_class - assert self.num_class < len(self.origin.classes_list) - self.num_query = num_query # K - self.num_support = num_support # K - self.num_episodes = num_episodes - - def _fetch_list_data(self, id_list): - return [self.origin[i][0] for i in id_list] - - def __len__(self): - return self.num_episodes - - def __getitem__(self, _): - random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist() - support_set = [] - query_set = [] - target_set = [] - for tag, c in enumerate(random_classes): - image_list = self.origin.classes_list[c] - - if len(image_list) >= self.num_query + self.num_support: - # have enough images belong to this class - idx_list = torch.randperm(len(image_list))[:self.num_query + self.num_support].tolist() - else: - idx_list = torch.randint(high=len(image_list), size=(self.num_query + self.num_support,)).tolist() - - support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support])) - query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:])) - support_set.extend(support) - query_set.extend(query) - target_set.extend([tag] * self.num_query) - return { - "support": torch.stack(support_set), - "query": torch.stack(query_set), - "target": torch.tensor(target_set) - } - - def __repr__(self): - return f"" - - -@DATASET.register_module() -class SingleFolderDataset(Dataset): - def __init__(self, root, pipeline, with_path=False): - assert os.path.isdir(root) - self.root = root - samples = [] - for r, _, fns in sorted(os.walk(self.root, followlinks=True)): - for fn in sorted(fns): - path = os.path.join(r, fn) - if has_file_allowed_extension(path, IMG_EXTENSIONS): - samples.append(path) - self.samples = samples - self.pipeline = transform_pipeline(pipeline) - self.with_path = with_path - - def __len__(self): - return len(self.samples) - - def __getitem__(self, idx): - if not self.with_path: - return self.pipeline(self.samples[idx]) - else: - return self.pipeline(self.samples[idx]), self.samples[idx] - - def __repr__(self): - return f"" - - -@DATASET.register_module() -class GenerationUnpairedDataset(Dataset): - def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False): - self.A = SingleFolderDataset(root_a, pipeline, with_path) - self.B = SingleFolderDataset(root_b, pipeline, with_path) - self.random_pair = random_pair - - def __getitem__(self, idx): - a_idx = idx % len(self.A) - b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item() - return dict(a=self.A[a_idx], b=self.B[b_idx]) - - def __len__(self): - return max(len(self.A), len(self.B)) - - def __repr__(self): - return f"\nPipeline:\n{self.A.pipeline}" - - -def normalize_tensor(tensor): - tensor = tensor.float() - tensor -= tensor.min() - tensor /= tensor.max() - return tensor - - -@DATASET.register_module() -class GenerationUnpairedDatasetWithEdge(Dataset): - def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256), - with_path=False): - assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"] - self.edge_type = edge_type - self.size = size - self.edges_path = Path(edges_path) - self.landmarks_path = Path(landmarks_path) - assert self.edges_path.exists() - assert self.landmarks_path.exists() - self.A = SingleFolderDataset(root_a, pipeline, with_path=True) - self.B = SingleFolderDataset(root_b, pipeline, with_path=True) - self.random_pair = random_pair - self.with_path = with_path - - def get_edge(self, origin_path): - op = Path(origin_path) - if self.edge_type.startswith("landmark_"): - edge_type = self.edge_type.lstrip("landmark_") - use_landmark = op.parent.name.endswith("A") - else: - edge_type = self.edge_type - use_landmark = False - edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png" - origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size, Image.BILINEAR)) - if not use_landmark: - return origin_edge - else: - landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt" - key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size) - - dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size))) - part_labels = normalize_tensor(torch.from_numpy(part_labels)) - part_edge = torch.from_numpy(part_edge).unsqueeze(0).float() - # edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face - # edges = part_edge + edges - return torch.cat([origin_edge, part_edge, dist_tensor, part_labels]) - - def __getitem__(self, idx): - a_idx = idx % len(self.A) - b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item() - output = dict(a={}, b={}) - output["a"]["img"], output["a"]["path"] = self.A[a_idx] - output["b"]["img"], output["b"]["path"] = self.B[b_idx] - for p in "ab": - output[p]["edge"] = self.get_edge(output[p]["path"]) - return output - - def __len__(self): - return max(len(self.A), len(self.B)) - - def __repr__(self): - return f"\nPipeline:\n{self.A.pipeline}" - - -@DATASET.register_module() -class PoseFacesWithSingleAnime(Dataset): - def __init__(self, root_face, root_anime, landmark_path, num_face, face_pipeline, anime_pipeline, img_size, - with_order=True): - self.num_face = num_face - self.landmark_path = Path(landmark_path) - self.with_order = with_order - self.root_face = Path(root_face) - self.root_anime = Path(root_anime) - self.img_size = img_size - self.face_samples = self.iter_folders() - self.face_pipeline = transform_pipeline(face_pipeline) - self.B = SingleFolderDataset(root_anime, anime_pipeline, with_path=True) - - def iter_folders(self): - pics_per_person = defaultdict(list) - for p in self.root_face.glob("*.jpg"): - pics_per_person[p.stem[:7]].append(p.stem) - data = [] - for p in pics_per_person: - if len(pics_per_person[p]) >= self.num_face: - if self.with_order: - data.extend(list(combinations(pics_per_person[p], self.num_face))) - else: - data.extend(list(permutations(pics_per_person[p], self.num_face))) - return data - - def read_pose(self, pose_txt): - key_points, part_labels, part_edge = dlib_landmark.read_keypoints(pose_txt, size=self.img_size) - dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.img_size))) - part_labels = normalize_tensor(torch.from_numpy(part_labels)) - part_edge = torch.from_numpy(part_edge).unsqueeze(0).float() - return torch.cat([part_labels, part_edge, dist_tensor]) - - def __len__(self): - return len(self.face_samples) - - def __getitem__(self, idx): - output = dict() - output["anime_img"], output["anime_path"] = self.B[torch.randint(len(self.B), (1,)).item()] - for i, f in enumerate(self.face_samples[idx]): - output[f"face_{i}"] = self.face_pipeline(self.root_face / f"{f}.jpg") - output[f"pose_{i}"] = self.read_pose(self.landmark_path / self.root_face.name / f"{f}.txt") - output[f"stem_{i}"] = f - return output diff --git a/data/dataset/__init__.py b/data/dataset/__init__.py new file mode 100644 index 0000000..7195a56 --- /dev/null +++ b/data/dataset/__init__.py @@ -0,0 +1,3 @@ +from util.misc import import_submodules + +__all__ = import_submodules(__name__).keys() \ No newline at end of file diff --git a/data/dataset/few-shot.py b/data/dataset/few-shot.py new file mode 100644 index 0000000..d38e05f --- /dev/null +++ b/data/dataset/few-shot.py @@ -0,0 +1,63 @@ +from collections import defaultdict + +import torch +from torch.utils.data import Dataset +from torchvision.datasets import ImageFolder + +from data.registry import DATASET +from data.transform import transform_pipeline + + +@DATASET.register_module() +class ImprovedImageFolder(ImageFolder): + def __init__(self, root, pipeline): + super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x) + self.classes_list = defaultdict(list) + self.essential_attr = ["classes_list"] + for i in range(len(self)): + self.classes_list[self.samples[i][-1]].append(i) + assert len(self.classes_list) == len(self.classes) + + +class EpisodicDataset(Dataset): + def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes): + self.origin = origin_dataset + self.num_class = num_class + assert self.num_class < len(self.origin.classes_list) + self.num_query = num_query # K + self.num_support = num_support # K + self.num_episodes = num_episodes + + def _fetch_list_data(self, id_list): + return [self.origin[i][0] for i in id_list] + + def __len__(self): + return self.num_episodes + + def __getitem__(self, _): + random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist() + support_set = [] + query_set = [] + target_set = [] + for tag, c in enumerate(random_classes): + image_list = self.origin.classes_list[c] + + if len(image_list) >= self.num_query + self.num_support: + # have enough images belong to this class + idx_list = torch.randperm(len(image_list))[:self.num_query + self.num_support].tolist() + else: + idx_list = torch.randint(high=len(image_list), size=(self.num_query + self.num_support,)).tolist() + + support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support])) + query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:])) + support_set.extend(support) + query_set.extend(query) + target_set.extend([tag] * self.num_query) + return { + "support": torch.stack(support_set), + "query": torch.stack(query_set), + "target": torch.tensor(target_set) + } + + def __repr__(self): + return f"" diff --git a/data/dataset/image_translation.py b/data/dataset/image_translation.py new file mode 100644 index 0000000..be22098 --- /dev/null +++ b/data/dataset/image_translation.py @@ -0,0 +1,62 @@ +import os + +import torch +from torch.utils.data import Dataset +from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS + +from data.registry import DATASET +from data.transform import transform_pipeline + + +@DATASET.register_module() +class SingleFolderDataset(Dataset): + def __init__(self, root, pipeline, with_path=False): + assert os.path.isdir(root) + self.root = root + samples = [] + for r, _, fns in sorted(os.walk(self.root, followlinks=True)): + for fn in sorted(fns): + path = os.path.join(r, fn) + if has_file_allowed_extension(path, IMG_EXTENSIONS): + samples.append(path) + self.samples = samples + self.pipeline = transform_pipeline(pipeline) + self.with_path = with_path + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + output = dict(img=self.pipeline(self.samples[idx])) + if self.with_path: + output["path"] = self.samples[idx] + return output + + def __repr__(self): + return f"" + + +@DATASET.register_module() +class GenerationUnpairedDataset(Dataset): + def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False): + self.A = SingleFolderDataset(root_a, pipeline, with_path) + self.B = SingleFolderDataset(root_b, pipeline, with_path) + self.with_path = with_path + self.random_pair = random_pair + + def __getitem__(self, idx): + a_idx = idx % len(self.A) + b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item() + output_a = self.A[a_idx] + output_b = self.B[b_idx] + output = dict(a=output_a["img"], b=output_b["img"]) + if self.with_path: + output["a_path"] = output_a["path"] + output["b_path"] = output_b["path"] + return output + + def __len__(self): + return max(len(self.A), len(self.B)) + + def __repr__(self): + return f"\nPipeline:\n{self.A.pipeline}" diff --git a/data/dataset/lmdb.py b/data/dataset/lmdb.py new file mode 100644 index 0000000..12ccb9f --- /dev/null +++ b/data/dataset/lmdb.py @@ -0,0 +1,65 @@ +import os +import pickle + +import lmdb +from torch.utils.data import Dataset +from tqdm import tqdm + +from data.transform import transform_pipeline + + +def default_transform_way(transform, sample): + return [transform(sample[0]), *sample[1:]] + + +class LMDBDataset(Dataset): + def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True, + **lmdb_kwargs): + self.path = lmdb_path + self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly, + lock=False, **lmdb_kwargs) + + with self.db.begin(write=False) as txn: + self._len = pickle.loads(txn.get(b"$$len$$")) + self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$")) + if pipeline is None: + self.not_done_pipeline = [] + else: + self.not_done_pipeline = self._remain_pipeline(pipeline) + self.transform = transform_pipeline(self.not_done_pipeline) + self.transform_way = transform_way + essential_attr = pickle.loads(txn.get(b"$$essential_attr$$")) + for ea in essential_attr: + setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8")))) + + def _remain_pipeline(self, pipeline): + for i, dp in enumerate(self.done_pipeline): + if pipeline[i] != dp: + raise ValueError( + f"pipeline {self.done_pipeline} saved in this lmdb database is not match with pipeline:{pipeline}") + return pipeline[len(self.done_pipeline):] + + def __repr__(self): + return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}" + + def __len__(self): + return self._len + + def __getitem__(self, idx): + with self.db.begin(write=False) as txn: + sample = pickle.loads(txn.get("{}".format(idx).encode())) + sample = self.transform_way(self.transform, sample) + return sample + + @staticmethod + def lmdbify(dataset, done_pipeline, lmdb_path): + env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path)) + with env.begin(write=True) as txn: + for i in tqdm(range(len(dataset)), ncols=0): + txn.put("{}".format(i).encode(), pickle.dumps(dataset[i])) + txn.put(b"$$len$$", pickle.dumps(len(dataset))) + txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline)) + essential_attr = getattr(dataset, "essential_attr", list()) + txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr)) + for ea in essential_attr: + txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea))) diff --git a/data/dataset/pose_transfer.py b/data/dataset/pose_transfer.py new file mode 100644 index 0000000..fe6a758 --- /dev/null +++ b/data/dataset/pose_transfer.py @@ -0,0 +1,122 @@ +from collections import defaultdict +from itertools import permutations, combinations +from pathlib import Path + +import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision.transforms import functional as F + +from data.registry import DATASET +from data.transform import transform_pipeline +from data.util import dlib_landmark + + +def normalize_tensor(tensor): + tensor = tensor.float() + tensor -= tensor.min() + tensor /= tensor.max() + return tensor + + +@DATASET.register_module() +class GenerationUnpairedDatasetWithEdge(Dataset): + def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256), + with_path=False): + assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"] + self.edge_type = edge_type + self.size = size + self.edges_path = Path(edges_path) + self.landmarks_path = Path(landmarks_path) + assert self.edges_path.exists() + assert self.landmarks_path.exists() + self.A = SingleFolderDataset(root_a, pipeline, with_path=True) + self.B = SingleFolderDataset(root_b, pipeline, with_path=True) + self.random_pair = random_pair + self.with_path = with_path + + def get_edge(self, origin_path): + op = Path(origin_path) + if self.edge_type.startswith("landmark_"): + edge_type = self.edge_type.lstrip("landmark_") + use_landmark = op.parent.name.endswith("A") + else: + edge_type = self.edge_type + use_landmark = False + edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png" + origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size, Image.BILINEAR)) + if not use_landmark: + return origin_edge + else: + landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt" + key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size) + + dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size))) + part_labels = normalize_tensor(torch.from_numpy(part_labels)) + part_edge = torch.from_numpy(part_edge).unsqueeze(0).float() + # edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face + # edges = part_edge + edges + return torch.cat([origin_edge, part_edge, dist_tensor, part_labels]) + + def __getitem__(self, idx): + a_idx = idx % len(self.A) + b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item() + output = dict(a={}, b={}) + output["a"]["img"], output["a"]["path"] = self.A[a_idx] + output["b"]["img"], output["b"]["path"] = self.B[b_idx] + for p in "ab": + output[p]["edge"] = self.get_edge(output[p]["path"]) + return output + + def __len__(self): + return max(len(self.A), len(self.B)) + + def __repr__(self): + return f"\nPipeline:\n{self.A.pipeline}" + + +@DATASET.register_module() +class PoseFacesWithSingleAnime(Dataset): + def __init__(self, root_face, root_anime, landmark_path, num_face, face_pipeline, anime_pipeline, img_size, + with_order=True): + self.num_face = num_face + self.landmark_path = Path(landmark_path) + self.with_order = with_order + self.root_face = Path(root_face) + self.root_anime = Path(root_anime) + self.img_size = img_size + self.face_samples = self.iter_folders() + self.face_pipeline = transform_pipeline(face_pipeline) + self.B = SingleFolderDataset(root_anime, anime_pipeline, with_path=True) + + def iter_folders(self): + pics_per_person = defaultdict(list) + for p in self.root_face.glob("*.jpg"): + pics_per_person[p.stem[:7]].append(p.stem) + data = [] + for p in pics_per_person: + if len(pics_per_person[p]) >= self.num_face: + if self.with_order: + data.extend(list(combinations(pics_per_person[p], self.num_face))) + else: + data.extend(list(permutations(pics_per_person[p], self.num_face))) + return data + + def read_pose(self, pose_txt): + key_points, part_labels, part_edge = dlib_landmark.read_keypoints(pose_txt, size=self.img_size) + dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.img_size))) + part_labels = normalize_tensor(torch.from_numpy(part_labels)) + part_edge = torch.from_numpy(part_edge).unsqueeze(0).float() + return torch.cat([part_labels, part_edge, dist_tensor]) + + def __len__(self): + return len(self.face_samples) + + def __getitem__(self, idx): + output = dict() + output["anime_img"], output["anime_path"] = self.B[torch.randint(len(self.B), (1,)).item()] + for i, f in enumerate(self.face_samples[idx]): + output[f"face_{i}"] = self.face_pipeline(self.root_face / f"{f}.jpg") + output[f"pose_{i}"] = self.read_pose(self.landmark_path / self.root_face.name / f"{f}.txt") + output[f"stem_{i}"] = f + return output diff --git a/data/transform.py b/data/transform.py index 1aa6c06..fad7275 100644 --- a/data/transform.py +++ b/data/transform.py @@ -28,7 +28,7 @@ class Load: def transform_pipeline(pipeline_description): - if len(pipeline_description) == 0: + if pipeline_description is None or len(pipeline_description) == 0: return lambda x: x transform_list = [TRANSFORM.build_with(pd) for pd in pipeline_description] return transforms.Compose(transform_list) diff --git a/model/GAN/__init__.py b/model/GAN/__init__.py index e69de29..7195a56 100644 --- a/model/GAN/__init__.py +++ b/model/GAN/__init__.py @@ -0,0 +1,3 @@ +from util.misc import import_submodules + +__all__ = import_submodules(__name__).keys() \ No newline at end of file diff --git a/model/__init__.py b/model/__init__.py index 20c1c47..aef3eef 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -1,10 +1,4 @@ from model.registry import MODEL, NORMALIZATION -import model.GAN.CycleGAN -import model.GAN.MUNIT -import model.GAN.TAFG -import model.GAN.TSIT -import model.GAN.UGATIT -import model.GAN.base -import model.GAN.wrapper +import model.GAN import model.base.normalization diff --git a/util/misc.py b/util/misc.py index ca8b49c..80a61ca 100644 --- a/util/misc.py +++ b/util/misc.py @@ -1,4 +1,6 @@ +import importlib import logging +import pkgutil from pathlib import Path from typing import Optional @@ -12,6 +14,24 @@ def add_spectral_norm(module): return module +def import_submodules(package, recursive=True): + """ Import all submodules of a module, recursively, including subpackages + + :param package: package (name or actual module) + :type package: str | module + :rtype: dict[str, types.ModuleType] + """ + if isinstance(package, str): + package = importlib.import_module(package) + results = {} + for loader, name, is_pkg in pkgutil.walk_packages(package.__path__): + full_name = package.__name__ + '.' + name + results[name] = importlib.import_module(full_name) + if recursive and is_pkg: + results.update(import_submodules(full_name)) + return results + + def setup_logger( name: Optional[str] = None, level: int = logging.INFO,