diff --git a/data/dataset.py b/data/dataset.py
index 0c72fb2..2db4f46 100755
--- a/data/dataset.py
+++ b/data/dataset.py
@@ -46,7 +46,7 @@ class EpisodicDataset(Dataset):
         self.origin = origin_dataset
         self.num_class = num_class
         assert self.num_class < len(self.origin.classes_list)
-        self.num_set = num_set*2  # 2*K
+        self.num_set = num_set  # K
         self.num_episodes = num_episodes
 
     def __len__(self):
@@ -54,15 +54,23 @@ class EpisodicDataset(Dataset):
 
     def __getitem__(self, _):
         random_classes = torch.randint(high=len(self.origin.classes_list), size=(self.num_class,)).tolist()
-        item = {}
-        for i in random_classes:
-            image_list = self.origin.classes_list[i]
-            if len(image_list) > self.num_set:
-                idx_list = torch.randperm(len(image_list))[:self.num_set].tolist()
+        support_set_list = []
+        query_set_list = []
+        target_list = []
+        for i, c in enumerate(random_classes):
+            image_list = self.origin.classes_list[c]
+            if len(image_list) > self.num_set * 2:  # enough images: sample without replacement
+                idx_list = torch.randperm(len(image_list))[:self.num_set*2].tolist()
             else:
-                idx_list = torch.randint(high=len(image_list), size=(self.num_set,)).tolist()
-            item[i] = [self.origin[idx] for idx in idx_list]
-        return item
+                idx_list = torch.randint(high=len(image_list), size=(self.num_set*2,)).tolist()
+            support_set_list.extend([self.origin[idx] for idx in idx_list[:self.num_set]])
+            query_set_list.extend([self.origin[idx] for idx in idx_list[self.num_set:]])
+            target_list.extend([i]*self.num_set)  # episode-local label i, grouped class by class
+        return {
+            "support": torch.stack(support_set_list),
+            "query": torch.stack(query_set_list),
+            "target": torch.tensor(target_list)
+        }
 
     def __repr__(self):
         return "<EpisodicDataset: num_class={}, num_set={}, num_episodes={}>".format(self.num_class, self.num_set, self.num_episodes)
diff --git a/test.py b/test.py
index 698815e..9a45c8d 100755
--- a/test.py
+++ b/test.py
@@ -3,6 +3,7 @@
 from torch.utils.data import DataLoader
 import torchvision
 from data import dataset
 import torch.nn as nn
+from ignite.utils import convert_tensor
 
 def setup_seed(seed):
@@ -31,15 +32,17 @@ def euclidean_dist(x, y):
 
 def evaluate(query, target, support):
     """
-    :param query: NK x D vector
-    :param target: NK x 1 vector
-    :param support: N x K x D vector
+    :param query: NK x D query embeddings for a single episode
+    :param target: NK vector of episode-local class labels
+    :param support: N x K x D support embeddings for a single episode
     :return:
     """
+    K = support.size(1)
     prototypes = support.mean(1)
     distance = euclidean_dist(query, prototypes)
     indices = distance.argmin(1)
-    return torch.eq(target, indices).float().mean()
+    y_hat = torch.tensor([target[i * K] for i in indices])  # first query of class i carries label i
+    return torch.eq(target.to("cpu"), y_hat).float().mean()
 
 
 class Flatten(nn.Module):
@@ -49,6 +52,7 @@ class Flatten(nn.Module):
     def forward(self, x):
         return x.view(x.size(0), -1)
 
+
 # def make_extractor():
 #     resnet50 = torchvision.models.resnet50(pretrained=True)
 #     resnet50.to(torch.device("cuda"))
@@ -97,41 +101,32 @@ def test():
         torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
     ])
     origin_dataset = dataset.CARS("/data/few-shot/STANFORD-CARS/", transform=data_transform)
+    # origin_dataset = dataset.ImprovedImageFolder("/data/few-shot/mini_imagenet_full_size/train", transform=data_transform)
+    N = torch.randint(5, 10, (1,)).item()  # N-way
+    K = torch.randint(1, 10, (1,)).item()  # K-shot
     episodic_dataset = dataset.EpisodicDataset(
         origin_dataset,  # dataset to sample episodes from
-        torch.randint(5, 10, (1,)).tolist()[0],  # N
-        torch.randint(1, 10, (1,)).tolist()[0],  # K
-        5  # number of episodes
+        N,  # N
+        K,  # K
+        100  # number of episodes
     )
     print(episodic_dataset)
-    data_loader = DataLoader(episodic_dataset)
+    data_loader = DataLoader(episodic_dataset, batch_size=2)
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     extractor = make_extractor()
     accs = []
     for item in data_loader:
-        support_list = []
-        query_list = []
-        class_id_list = []
-        for class_id in item:
-            for i in range(len(item[class_id])):
-                item[class_id][i] = item[class_id][i].to(device)
-            num_support_set = len(item[class_id]) // 2
-            num_query_set = len(item[class_id]) - num_support_set
-            support_list.append(torch.stack([extractor(pair) for pair in item[class_id][:num_support_set]]))
-            query_list.append(torch.stack([extractor(pair) for pair in item[class_id][num_support_set:]]))
-            class_id_list.extend([class_id]*num_query_set)
-        query = torch.squeeze(torch.cat(query_list)).to(device)
-        support = torch.squeeze(torch.stack(support_list)).to(device)
-        target = torch.squeeze(torch.tensor(class_id_list)).to(device)
-
-        accs.append(evaluate(query, target, support))
-
-    print(accs)
+        item = convert_tensor(item, device)  # move the whole episode dict to the target device
+        query_batch = [extractor(images) for images in item["query"]]
+        support_batch = [torch.stack(torch.split(extractor(images), K)) for images in item["support"]]  # per episode: N x K x D
+        for i in range(len(query_batch)):
+            accs.append(evaluate(query_batch[i], item["target"][i], support_batch[i]))
+    print(torch.tensor(accs).mean().item())
 
 
 if __name__ == '__main__':
-    setup_seed(10)
-    test()
\ No newline at end of file
+    setup_seed(100)
+    test()