few-shot/data/lmdbify.py
2020-07-07 19:18:17 +08:00

36 lines
1.2 KiB
Python
Executable File

import torch
import lmdb
import os
import pickle
from io import BytesIO
from data.dataset import CARS, ImprovedImageFolder
import torchvision
from tqdm import tqdm
def dataset_to_lmdb(dataset, lmdb_path, map_size=1099511627776 * 2):
    """Serialize every item of ``dataset`` into an LMDB database at ``lmdb_path``.

    Each item ``dataset[i]`` is serialized with ``torch.save`` and stored under
    the key ``str(i).encode()``. Two metadata entries are written next to the
    items:

    - ``b"classes_list"``: ``pickle.dumps(dataset.classes_list)``
    - ``b"__len__"``: ``pickle.dumps(len(dataset))``

    Args:
        dataset: indexable object supporting ``len()``, ``dataset[i]`` and a
            ``classes_list`` attribute.
        lmdb_path: target path. If it is an existing directory, LMDB keeps its
            files inside it. Otherwise it is used as the data-file path, and
            its parent directory is created when missing.
        map_size: maximum size of the LMDB memory map in bytes
            (default 2 TiB, as before).
    """
    use_subdir = os.path.isdir(lmdb_path)
    if not use_subdir:
        parent = os.path.dirname(lmdb_path)
        if parent:
            # With subdir=False the path names a file; its parent directory
            # has to exist before lmdb.open can create the data file.
            os.makedirs(parent, exist_ok=True)
    env = lmdb.open(lmdb_path, map_size=map_size, subdir=use_subdir)
    try:
        # All items go into one write transaction.
        # The context manager commits on normal exit.
        with env.begin(write=True) as txn:
            for i in tqdm(range(len(dataset))):
                buffer = BytesIO()
                torch.save(dataset[i], buffer)
                txn.put(f"{i}".encode(), buffer.getvalue())
            txn.put(b"classes_list", pickle.dumps(dataset.classes_list))
            txn.put(b"__len__", pickle.dumps(len(dataset)))
    finally:
        # Release the environment (and its lock file) even if writing fails.
        env.close()
def main(image_root="/data/few-shot/CUB_200_2011/CUB_200_2011/images",
         lmdb_path="/data/few-shot/lmdb/CUB_200_2011/data.lmdb",
         image_size=224):
    """Convert an image-folder dataset into an LMDB database.

    The transform pipeline applied to each image:

    1. Resize to an ``int(image_size * 1.15)`` square.
    2. Center-crop to ``image_size``.
    3. Convert to a tensor.
    4. Normalize with fixed per-channel mean/std values.

    The resulting dataset is written with ``dataset_to_lmdb``. All defaults
    reproduce the original hard-coded behaviour.

    Args:
        image_root: root directory passed to ``ImprovedImageFolder``.
        lmdb_path: destination passed to ``dataset_to_lmdb``.
        image_size: side length of the final center crop.
    """
    # Resize ~15% larger than the crop so CenterCrop trims the borders.
    resize_to = int(image_size * 1.15)
    data_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize([resize_to, resize_to]),
        torchvision.transforms.CenterCrop(image_size),
        torchvision.transforms.ToTensor(),
        # Mean/std values commonly used with ImageNet-pretrained torchvision models.
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    origin_dataset = ImprovedImageFolder(image_root, transform=data_transform)
    dataset_to_lmdb(origin_dataset, lmdb_path)
# Run the conversion only when executed as a script, never on import.
if __name__ == "__main__":
    main()