64 lines
2.4 KiB
Python
64 lines
2.4 KiB
Python
from collections import defaultdict
|
|
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
from torchvision.datasets import ImageFolder
|
|
|
|
from data.registry import DATASET
|
|
from data.transform import transform_pipeline
|
|
|
|
|
|
@DATASET.register_module()
class ImprovedImageFolder(ImageFolder):
    """``ImageFolder`` that additionally indexes its samples by class.

    After construction, ``classes_list`` maps every class index to the list of
    positions (indices into ``self.samples``) of the samples with that label.

    Args:
        root: Root directory forwarded to ``ImageFolder``.
        pipeline: Transform description; it is converted with
            ``transform_pipeline`` and used as the dataset transform.
    """

    def __init__(self, root, pipeline):
        # The loader is the identity function, so whatever ImageFolder passes
        # to it (presumably the file path -- confirm against torchvision) is
        # handed to the transform pipeline untouched.
        super().__init__(root, transform_pipeline(pipeline), loader=lambda path: path)

        # NOTE(review): looks like a list of attribute names that consumers of
        # this dataset rely on -- confirm where ``essential_attr`` is read.
        self.essential_attr = ["classes_list"]

        # class index -> positions of the samples carrying that label
        self.classes_list = defaultdict(list)
        for position, sample in enumerate(self.samples):
            # the label is the last element of each sample entry
            self.classes_list[sample[-1]].append(position)

        # every class known to ImageFolder must own at least one sample
        assert len(self.classes_list) == len(self.classes)
|
|
|
|
|
|
class EpisodicDataset(Dataset):
    """Dataset whose items are randomly sampled few-shot episodes.

    Each item is built independently of the requested index: ``num_class``
    classes are drawn at random from ``origin_dataset`` and, for every drawn
    class, ``num_support`` support samples and ``num_query`` query samples are
    drawn from that class.

    Args:
        origin_dataset: Source dataset. It must expose ``classes_list``, a
            mapping whose keys are ``0 .. len(classes_list) - 1`` and whose
            values are lists of sample indices of that class, and
            ``origin_dataset[i][0]`` must be a tensor (all of the same shape,
            since they are stacked).
        num_class: Number of classes per episode.
        num_query: Number of query samples per class.
        num_support: Number of support samples per class.
        num_episodes: Reported length of the dataset.

    Raises:
        ValueError: If ``num_class`` is larger than the number of classes
            available in ``origin_dataset``.
    """

    def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
        self.origin = origin_dataset
        self.num_class = num_class
        available = len(self.origin.classes_list)
        # Using every available class in an episode is valid, so only a
        # request for *more* classes than exist is rejected. An explicit raise
        # is used because ``assert`` is stripped when Python runs with -O.
        if self.num_class > available:
            raise ValueError(
                f"num_class ({self.num_class}) exceeds the number of available classes ({available})"
            )
        self.num_query = num_query  # query samples drawn per class
        self.num_support = num_support  # support samples drawn per class
        self.num_episodes = num_episodes

    def _fetch_list_data(self, id_list):
        """Return the data part (element 0) of the origin items at ``id_list``."""
        return [self.origin[i][0] for i in id_list]

    def __len__(self):
        """Return the configured number of episodes."""
        return self.num_episodes

    def __getitem__(self, _):
        """Sample one episode; the index argument is ignored.

        Returns:
            A dict with
            ``"support"``: stacked support samples, ``num_class * num_support`` rows,
            ``"query"``: stacked query samples, ``num_class * num_query`` rows,
            ``"target"``: for every query sample, the position (``0 ..
            num_class - 1``) of its class within this episode. Samples are
            grouped by class in the same order in all three tensors.
        """
        per_class = self.num_query + self.num_support
        total_classes = len(self.origin.classes_list)
        chosen_classes = torch.randperm(total_classes)[:self.num_class].tolist()

        support_set = []
        query_set = []
        target_set = []
        for tag, class_id in enumerate(chosen_classes):
            image_list = self.origin.classes_list[class_id]

            if len(image_list) >= per_class:
                # enough images in this class: draw without replacement
                picks = torch.randperm(len(image_list))[:per_class].tolist()
            else:
                # too few images: draw with replacement, so duplicates (also
                # between support and query) are possible
                picks = torch.randint(high=len(image_list), size=(per_class,)).tolist()

            sample_ids = [image_list[p] for p in picks]
            support_set.extend(self._fetch_list_data(sample_ids[:self.num_support]))
            query_set.extend(self._fetch_list_data(sample_ids[self.num_support:]))
            # targets are episode-local tags, not the original class ids
            target_set.extend([tag] * self.num_query)

        return {
            "support": torch.stack(support_set),
            "query": torch.stack(query_set),
            "target": torch.tensor(target_set),
        }

    def __repr__(self):
        return f"<EpisodicDataset NE{self.num_episodes} NC{self.num_class} NS{self.num_support} NQ{self.num_query}>"
|