few-shot/data/lmdbify.py
2020-07-23 22:32:28 +08:00

40 lines
1.2 KiB
Python
Executable File

import os
import pickle
import argparse
from PIL import Image
import lmdb
from data.dataset import ImprovedImageFolder
from tqdm import tqdm
def content_loader(path, size=(256, 256)):
    """Load the image at *path* as an RGB ``PIL.Image`` resized to *size*.

    The file is opened in a ``with`` block so its handle is released as soon
    as the pixels have been read, rather than relying on PIL's lazy loading
    to close it. The image is converted to RGB *before* resizing: Pillow
    always uses nearest-neighbour resampling when resizing "P" (palette) or
    "1" (bilevel) images, so converting first gives a smoother result.

    Args:
        path: Filesystem path of the image file.
        size: Target ``(width, height)`` passed to ``Image.resize``;
            defaults to ``(256, 256)``.

    Returns:
        A new, fully loaded RGB image that does not depend on the source
        file staying open.
    """
    with Image.open(path) as im:
        if im.mode != "RGB":
            im = im.convert("RGB")
        # resize() returns a new, loaded image, so the result stays valid
        # after the with-block closes the source file.
        return im.resize(size)
def dataset_to_lmdb(dataset, lmdb_path, map_size=1099511627776 * 2):
    """Serialize every sample of *dataset* into an LMDB database at *lmdb_path*.

    Keys written (all values are ``pickle.dumps`` output):
      * ``b"0"``, ``b"1"``, ... -- ``dataset[i]`` for each index ``i``;
      * ``b"classes_list"`` -- ``dataset.classes_list``;
      * ``b"__len__"`` -- ``len(dataset)``.

    Everything is written in a single write transaction, which is committed
    when the ``with env.begin(...)`` block exits normally. The environment
    is always closed afterwards, even if serialization fails part-way.

    Args:
        dataset: Indexable object supporting ``len()`` and exposing a
            ``classes_list`` attribute.
        lmdb_path: Target path. If it is an existing directory, the database
            files are stored inside it (``subdir=True``). Otherwise the path
            is used as a filename prefix (``subdir=False``).
        map_size: Maximum size of the LMDB memory map in bytes; defaults to
            2 TiB (``2 * 2**40``).
    """
    env = lmdb.open(lmdb_path, map_size=map_size, subdir=os.path.isdir(lmdb_path))
    try:
        with env.begin(write=True) as txn:
            for i in tqdm(range(len(dataset)), ncols=50):
                txn.put(str(i).encode(), pickle.dumps(dataset[i]))
            txn.put(b"classes_list", pickle.dumps(dataset.classes_list))
            txn.put(b"__len__", pickle.dumps(len(dataset)))
    finally:
        # Release the memory map and file handles so they are not leaked.
        env.close()
def transform(save_path, dataset_path):
    """Convert the image folder at *dataset_path* into an LMDB database at *save_path*.

    Echoes both paths to stdout, wraps the folder in ``ImprovedImageFolder``
    with ``content_loader`` as the per-file loader, and hands the result to
    ``dataset_to_lmdb``.
    """
    print(save_path, dataset_path)
    dataset_to_lmdb(
        ImprovedImageFolder(dataset_path, loader=content_loader),
        save_path,
    )
if __name__ == '__main__':
    # Command-line entry point: both the output path and the source image
    # folder must be supplied.
    cli = argparse.ArgumentParser(description="transform dataset to lmdb database")
    for flag in ('--save', '--dataset'):
        cli.add_argument(flag, required=True)
    options = cli.parse_args()
    transform(options.save, options.dataset)