few-shot/data/lmdbify.py
2020-07-29 00:03:16 +08:00

35 lines
970 B
Python
Executable File

import os
import pickle
from PIL import Image
import lmdb
from data.dataset import ImprovedImageFolder
from tqdm import tqdm
import fire
def content_loader(path, size=(256, 256)):
    """Load the image at *path* as an RGB PIL image resized to *size*.

    Args:
        path: filesystem path of the image file.
        size: ``(width, height)`` of the returned image; defaults to
            ``(256, 256)``, matching the previously hard-coded value.

    Returns:
        A new, fully loaded ``PIL.Image.Image`` in ``"RGB"`` mode that does
        not depend on the source file staying open.
    """
    # The context manager releases the file handle even if decoding fails;
    # ``resize`` always returns a new in-memory image, so the result remains
    # valid after the source is closed.
    with Image.open(path) as source:
        # Convert *before* resizing: Pillow always resizes palette ("P") and
        # bilevel ("1") images with nearest-neighbour sampling, so resizing
        # first and converting afterwards produces blocky results.
        rgb = source if source.mode == "RGB" else source.convert("RGB")
        return rgb.resize(size)
def dataset_to_lmdb(dataset, lmdb_path, map_size=1099511627776 * 2):
    """Serialize every sample of *dataset* into an LMDB database at *lmdb_path*.

    Each sample ``dataset[i]`` is pickled and stored under the ASCII key
    ``str(i)``. Two metadata entries are also written: ``b"classes_list"``
    (pickled ``dataset.classes_list``) and ``b"__len__"`` (pickled sample
    count).

    Args:
        dataset: object supporting ``len()``, integer indexing and a
            ``classes_list`` attribute.
        lmdb_path: target path. If it is an existing directory the database
            is created inside it; otherwise the path is treated as a file
            name (py-lmdb ``subdir=False``).
        map_size: maximum size of the LMDB memory map in bytes
            (default 2 TiB, the previously hard-coded value).
    """
    env = lmdb.open(lmdb_path, map_size=map_size, subdir=os.path.isdir(lmdb_path))
    try:
        # One write transaction: it is committed when the ``with`` block exits
        # normally and aborted if an exception is raised inside it.
        with env.begin(write=True) as txn:
            num_samples = len(dataset)
            for i in tqdm(range(num_samples), ncols=50):
                txn.put(str(i).encode(), pickle.dumps(dataset[i]))
            txn.put(b"classes_list", pickle.dumps(dataset.classes_list))
            txn.put(b"__len__", pickle.dumps(num_samples))
    finally:
        # Release the memory map and lock file; previously the environment
        # was never closed.
        env.close()
def transform(save_path, dataset_path):
    """Build an ImprovedImageFolder over *dataset_path* and dump it to LMDB.

    Images are read through ``content_loader``; the resulting database is
    written to *save_path*. Both paths are echoed to stdout first.
    """
    print(save_path, dataset_path)
    dataset_to_lmdb(
        ImprovedImageFolder(dataset_path, loader=content_loader),
        save_path,
    )
if __name__ == "__main__":
    # Let python-fire expose `transform` as the command-line interface.
    fire.Fire(transform)