from scipy.io import loadmat import torch from torch.utils.data import Dataset from torchvision.datasets.folder import default_loader from torchvision.datasets import ImageFolder from pathlib import Path from collections import defaultdict class CARS(Dataset): def __init__(self, root, loader=default_loader, transform=None): self.root = Path(root) self.transform = transform self.loader = loader self.annotations = loadmat(self.root / "devkit/cars_train_annos.mat")["annotations"][0] self.annotations = {d[-1].item(): d[-2].item() - 1 for d in self.annotations} self.classes_list = defaultdict(list) for i in range(len(self.annotations)): self.classes_list[self.annotations["{:05d}.jpg".format(i + 1)]].append(i) def __len__(self): return len(self.annotations) def __getitem__(self, item): file_name = "{:05d}.jpg".format(item+1) target = self.annotations[file_name] sample = self.loader(self.root / "cars_train" / file_name) if self.transform is not None: sample = self.transform(sample) return sample class ImprovedImageFolder(ImageFolder): def __init__(self, root, loader=default_loader, transform=None): super().__init__(root, transform, loader=loader) self.classes_list = defaultdict(list) for i in range(len(self)): self.classes_list[self.samples[i][-1]].append(i) def __getitem__(self, item): return super().__getitem__(item)[0] class EpisodicDataset(Dataset): def __init__(self, origin_dataset, num_class, num_set, num_episodes): self.origin = origin_dataset self.num_class = num_class assert self.num_class < len(self.origin.classes_list) self.num_set = num_set # K self.num_episodes = num_episodes def __len__(self): return self.num_episodes def __getitem__(self, _): random_classes = torch.randint(high=len(self.origin.classes_list), size=(self.num_class,)).tolist() support_set_list = [] query_set_list = [] target_list = [] for i, c in enumerate(random_classes): image_list = self.origin.classes_list[c] if len(image_list) > self.num_set * 2: idx_list = torch.randperm(len(image_list))[:self.num_set*2].tolist() else: idx_list = torch.randint(high=len(image_list), size=(self.num_set*2,)).tolist() support_set_list.extend([self.origin[idx] for idx in idx_list[:self.num_set]]) query_set_list.extend([self.origin[idx] for idx in idx_list[self.num_set:]]) target_list.extend([i]*self.num_set) return { "support": torch.stack(support_set_list), "query": torch.stack(query_set_list), "target": torch.tensor(target_list) } def __repr__(self): return "".format(self.num_class, self.num_set, self.num_episodes)