from collections import defaultdict

import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder

from data.registry import DATASET
from data.transform import transform_pipeline


@DATASET.register_module()
class ImprovedImageFolder(ImageFolder):
    """ImageFolder that additionally indexes sample positions by class."""

    def __init__(self, root, pipeline):
        # The loader returns the raw path; the transform pipeline is responsible for loading.
        super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
        self.classes_list = defaultdict(list)
        self.essential_attr = ["classes_list"]
        # Group sample indices by class label so episodes can be sampled per class.
        for i in range(len(self)):
            self.classes_list[self.samples[i][-1]].append(i)
        assert len(self.classes_list) == len(self.classes)


class EpisodicDataset(Dataset):
    """Wraps a class-indexed dataset and yields N-way K-shot episodes."""

    def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
        self.origin = origin_dataset
        self.num_class = num_class  # N: classes per episode (N-way)
        assert self.num_class < len(self.origin.classes_list)
        self.num_query = num_query  # Q: query images per class
        self.num_support = num_support  # K: support images per class (K-shot)
        self.num_episodes = num_episodes

    def _fetch_list_data(self, id_list):
        return [self.origin[i][0] for i in id_list]

    def __len__(self):
        return self.num_episodes

    def __getitem__(self, _):
        # Pick N distinct classes at random for this episode.
        random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
        support_set = []
        query_set = []
        target_set = []
        for tag, c in enumerate(random_classes):
            image_list = self.origin.classes_list[c]
            if len(image_list) >= self.num_query + self.num_support:
                # Enough images in this class: sample without replacement.
                idx_list = torch.randperm(len(image_list))[:self.num_query + self.num_support].tolist()
            else:
                # Too few images: fall back to sampling with replacement.
                idx_list = torch.randint(high=len(image_list), size=(self.num_query + self.num_support,)).tolist()
            support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
            query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
            support_set.extend(support)
            query_set.extend(query)
            target_set.extend([tag] * self.num_query)
        return {
            "support": torch.stack(support_set),
            "query": torch.stack(query_set),
            "target": torch.tensor(target_set),
        }

    def __repr__(self):
        return (f"{self.__class__.__name__}(num_class={self.num_class}, "
                f"num_support={self.num_support}, num_query={self.num_query}, "
                f"num_episodes={self.num_episodes})")
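

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions: the root path and pipeline config below are
# hypothetical placeholders, and the exact pipeline keys depend on
# data.transform.transform_pipeline, which must load each path and return image
# tensors of a common shape (otherwise torch.stack in __getitem__ will fail).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    # Hypothetical pipeline config; substitute the keys your transform_pipeline expects.
    pipeline = [{"type": "LoadImage"}, {"type": "Resize", "size": 84}, {"type": "ToTensor"}]
    base = ImprovedImageFolder(root="data/train", pipeline=pipeline)

    episodic = EpisodicDataset(
        origin_dataset=base,
        num_class=5,       # 5-way
        num_support=1,     # 1-shot
        num_query=15,      # 15 query images per class
        num_episodes=100,  # one pass over the loader yields 100 episodes
    )

    # Each item is already a full episode, so batch_size=1 yields one episode per step.
    loader = DataLoader(episodic, batch_size=1, num_workers=0)
    for episode in loader:
        support = episode["support"]  # shape [1, num_class * num_support, C, H, W]
        query = episode["query"]      # shape [1, num_class * num_query, C, H, W]
        target = episode["target"]    # shape [1, num_class * num_query]
        break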