import random
import warnings

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader

from data import dataset


def setup_seed(seed):
    """Seed the RNGs used by this script and make cuDNN deterministic.

    Args:
        seed: integer seed applied to torch (CPU and all CUDA devices) and to
            Python's ``random`` module.
    """
    # Python's RNG is seeded too in case dataset sampling uses ``random``
    # (not visible from this file — harmless if unused).
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode may pick different (non-deterministic) algorithms per run.
    torch.backends.cudnn.benchmark = False


def euclidean_dist(x, y):
    """Compute pairwise squared Euclidean distances between two sets of vectors.

    Args:
        x: N x D tensor.
        y: M x D tensor.

    Returns:
        N x M tensor whose entry (i, j) is ``||x[i] - y[j]||^2``.

    Raises:
        ValueError: if the feature dimensions of ``x`` and ``y`` differ.
    """
    n, d = x.size(0), x.size(1)
    m = y.size(0)
    if d != y.size(1):
        raise ValueError(f"feature dims differ: x has {d}, y has {y.size(1)}")
    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)
    return torch.pow(x - y, 2).sum(2)


def evaluate(query, target, support):
    """Return nearest-prototype classification accuracy for one episode.

    Each class prototype is the mean of that class's support embeddings; each
    query is assigned to the prototype with the smallest squared Euclidean
    distance.

    :param query: Q x D tensor of query embeddings.
    :param target: length-Q (or Q x 1) tensor of labels. A label is the index
        of the query's class along dim 0 of ``support`` (0..N-1), not a raw
        dataset class id.
    :param support: N x S x D tensor of support embeddings (S per class).
    :return: 0-dim float tensor, the fraction of queries classified correctly.
    """
    prototypes = support.mean(1)
    distance = euclidean_dist(query, prototypes)
    indices = distance.argmin(1)
    # Flatten so a Q x 1 target cannot broadcast against the length-Q
    # predictions into a Q x Q comparison.
    return torch.eq(target.reshape(-1), indices).float().mean()


class Flatten(nn.Module):
    """Flatten every dimension except the leading batch dimension."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.view(x.size(0), -1)


def make_extractor(device=None):
    """Build a frozen ResNet-18 feature extractor.

    Args:
        device: device to run the model on. Defaults to CUDA when available,
            otherwise CPU.

    Returns:
        A function mapping a batch of images on ``device`` to a batch of
        feature vectors, evaluated without gradient tracking.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = resnet18()
    model.to(device)
    model.eval()

    def extract(images):
        with torch.no_grad():
            return model(images)

    return extract


def resnet18(model_path="ResNet18Official.pth"):
    """Construct a ResNet-18 backbone with its final fully connected layer removed.

    Args:
        model_path (str): path to a saved state dict. Either torchvision
            ResNet key names ("conv1.weight", "layer1.0...") or the key names
            of the returned ``nn.Sequential`` ("0.weight", ...) are accepted.

    Returns:
        An ``nn.Sequential`` in eval mode that maps a batch of images to
        flattened (batch, features) vectors.
    """
    model_w_fc = torchvision.models.resnet18(pretrained=False)
    seq = list(model_w_fc.children())[:-1]
    seq.append(Flatten())
    model = torch.nn.Sequential(*seq)

    state_dict = torch.load(model_path, map_location="cpu")
    # The Sequential renames parameters by position ("0.weight", ...), so a
    # torchvision-format checkpoint would silently load nothing under
    # strict=False. Both modules share the same layer objects, so load into
    # whichever one's naming matches the checkpoint.
    sequential_keys = model.state_dict().keys()
    if any(key in sequential_keys for key in state_dict):
        destination = model
    else:
        destination = model_w_fc
    # strict=False tolerates the presence/absence of the unused fc weights.
    result = destination.load_state_dict(state_dict, strict=False)
    if state_dict and len(result.unexpected_keys) == len(state_dict):
        warnings.warn(
            f"No parameters from {model_path!r} matched the ResNet-18 model; "
            "the extractor is using random weights."
        )
    model.eval()
    return model


def test():
    """Run nearest-prototype evaluation of ResNet-18 features on Stanford Cars episodes."""
    data_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize([int(224 * 1.15), int(224 * 1.15)]),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    origin_dataset = dataset.CARS("/data/few-shot/STANFORD-CARS/", transform=data_transform)
    episodic_dataset = dataset.EpisodicDataset(
        origin_dataset,  # source dataset to sample episodes from
        torch.randint(5, 10, (1,)).tolist()[0],  # N (presumably classes per episode — confirm)
        torch.randint(1, 10, (1,)).tolist()[0],  # K (presumably samples per class — confirm)
        5,  # number of tasks
    )
    print(episodic_dataset)
    data_loader = DataLoader(episodic_dataset)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    extractor = make_extractor(device)

    accs = []
    for item in data_loader:
        support_list = []
        query_list = []
        target_list = []
        # The label of each class is its position in this episode, which is
        # exactly the row it occupies in ``support`` (what ``evaluate`` predicts).
        for episode_label, class_id in enumerate(item):
            # DataLoader's default batch_size=1 gives every sample a leading
            # batch dim of 1, so concatenation yields a (count, C, H, W) batch.
            samples = [sample.to(device) for sample in item[class_id]]
            num_support_set = len(samples) // 2
            num_query_set = len(samples) - num_support_set
            if num_support_set == 0 or num_query_set == 0:
                raise ValueError(
                    f"class {class_id!r} has {len(samples)} sample(s); at least 2 "
                    "are needed to split into support and query sets"
                )
            # torch.cat keeps the (count, D) shape even when count == 1,
            # unlike the squeeze of a stacked tensor.
            support_list.append(extractor(torch.cat(samples[:num_support_set])))
            query_list.append(extractor(torch.cat(samples[num_support_set:])))
            target_list.extend([episode_label] * num_query_set)

        query = torch.cat(query_list)  # (total_queries, D)
        support = torch.stack(support_list)  # (N, num_support_set, D)
        target = torch.tensor(target_list, device=device)
        accs.append(evaluate(query, target, support).item())

    print(accs)
    if accs:
        print(sum(accs) / len(accs))


if __name__ == '__main__':
    setup_seed(10)
    test()