add lmdb dataset support and EpisodicDataset

This commit is contained in:
Ray Wong 2020-08-10 10:51:24 +08:00
parent 8102651a28
commit 323bf2f6ab
4 changed files with 142 additions and 36 deletions

View File

@ -1,4 +1,4 @@
name: cross-domain
name: cross-domain-1
engine: crossdomain
result_dir: ./result
@ -33,7 +33,9 @@ baseline:
dataset:
train:
path: /data/few-shot/mini_imagenet_full_size/train
lmdb_path: /data/few-shot/lmdb/mini-ImageNet/train.lmdb
pipeline:
- Load
- RandomResizedCrop:
size: [256, 256]
- ColorJitter:
@ -47,7 +49,9 @@ baseline:
std: [0.229, 0.224, 0.225]
val:
path: /data/few-shot/mini_imagenet_full_size/val
lmdb_path: /data/few-shot/lmdb/mini-ImageNet/val.lmdb
pipeline:
- Load
- Resize:
size: [286, 286]
- RandomCrop:

View File

@ -1,23 +1,52 @@
import os
import pickle
from collections import defaultdict
import torch
from torch.utils.data import Dataset
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader
import lmdb
from tqdm import tqdm
from .transform import transform_pipeline
from .registry import DATASET
def default_transform_way(transform, sample):
return [transform(sample[0]), *sample[1:]]
class LMDBDataset(Dataset):
def __init__(self, lmdb_path, output_transform=None, map_size=2 ** 40, readonly=True, **lmdb_kwargs):
def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True,
**lmdb_kwargs):
self.path = lmdb_path
self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
**lmdb_kwargs)
self.output_transform = output_transform
lock=False, **lmdb_kwargs)
with self.db.begin(write=False) as txn:
self._len = pickle.loads(txn.get(b"__len__"))
self._len = pickle.loads(txn.get(b"$$len$$"))
self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
if pipeline is None:
self.not_done_pipeline = []
else:
self.not_done_pipeline = self._remain_pipeline(pipeline)
self.transform = transform_pipeline(self.not_done_pipeline)
self.transform_way = transform_way
essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
for ea in essential_attr:
setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8"))))
def _remain_pipeline(self, pipeline):
for i, dp in enumerate(self.done_pipeline):
if pipeline[i] != dp:
raise ValueError(
f"pipeline {self.done_pipeline} saved in this lmdb database is not match with pipeline:{pipeline}")
return pipeline[len(self.done_pipeline):]
def __repr__(self):
return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"
def __len__(self):
return self._len
@ -25,10 +54,77 @@ class LMDBDataset(Dataset):
def __getitem__(self, idx):
with self.db.begin(write=False) as txn:
sample = pickle.loads(txn.get("{}".format(idx).encode()))
if self.output_transform is not None:
sample = self.output_transform(sample)
sample = self.transform_way(self.transform, sample)
return sample
@staticmethod
def lmdbify(dataset, done_pipeline, lmdb_path):
env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
with env.begin(write=True) as txn:
for i in tqdm(range(len(dataset)), ncols=0):
txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
txn.put(b"$$len$$", pickle.dumps(len(dataset)))
txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
essential_attr = getattr(dataset, "essential_attr", list())
txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
for ea in essential_attr:
txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))
@DATASET.register_module()
class ImprovedImageFolder(ImageFolder):
def __init__(self, root, pipeline):
super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
self.classes_list = defaultdict(list)
self.essential_attr = ["classes_list"]
for i in range(len(self)):
self.classes_list[self.samples[i][-1]].append(i)
assert len(self.classes_list) == len(self.classes)
class EpisodicDataset(Dataset):
def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
self.origin = origin_dataset
self.num_class = num_class
assert self.num_class < len(self.origin.classes_list)
self.num_query = num_query # K
self.num_support = num_support # K
self.num_episodes = num_episodes
def _fetch_list_data(self, id_list):
return [self.origin[i][0] for i in id_list]
def __len__(self):
return self.num_episodes
def __getitem__(self, _):
random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
support_set_list = []
query_set_list = []
target_list = []
for tag, c in enumerate(random_classes):
image_list = self.origin.classes_list[c]
if len(image_list) >= self.num_query + self.num_support:
# have enough images belong to this class
idx_list = torch.randperm(len(image_list))[:self.num_query + self.num_support].tolist()
else:
idx_list = torch.randint(high=len(image_list), size=(self.num_query + self.num_support,)).tolist()
support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
support_set_list.extend(support)
query_set_list.extend(query)
target_list.extend([tag] * self.num_query)
return {
"support": torch.stack(support_set_list),
"query": torch.stack(query_set_list),
"target": torch.tensor(target_list)
}
def __repr__(self):
return f"<EpisodicDataset NE{self.num_episodes} NC{self.num_class} NS{self.num_support} NQ{self.num_query}>"
@DATASET.register_module()
class SingleFolderDataset(Dataset):

View File

@ -14,40 +14,28 @@ from ignite.contrib.handlers import ProgressBar
from util.build import build_model, build_optimizer
from util.handler import setup_common_handlers
from data.transform import transform_pipeline
from data.dataset import LMDBDataset
def baseline_trainer(config, logger, val_loader):
def baseline_trainer(config, logger):
model = build_model(config.model, config.distributed.model)
optimizer = build_optimizer(model.parameters(), config.baseline.optimizers)
loss_fn = nn.CrossEntropyLoss()
trainer = create_supervised_trainer(model, optimizer, loss_fn, idist.device(), non_blocking=True)
trainer = create_supervised_trainer(model, optimizer, loss_fn, idist.device(), non_blocking=True,
output_transform=lambda x, y, y_pred, loss: (loss.item(), y_pred, y))
trainer.logger = logger
RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss")
Accuracy(output_transform=lambda x: (x[1], x[2])).attach(trainer, "acc")
ProgressBar(ncols=0).attach(trainer)
val_metrics = {
"accuracy": Accuracy(),
"nll": Loss(loss_fn)
}
evaluator = create_supervised_evaluator(model, val_metrics, idist.device())
ProgressBar(ncols=0).attach(evaluator)
@trainer.on(Events.EPOCH_COMPLETED)
def log_training_loss(engine):
logger.info(f"Epoch[{engine.state.epoch}] Loss: {engine.state.output:.2f}")
evaluator.run(val_loader)
metrics = evaluator.state.metrics
logger.info("Training Results - Avg accuracy: {:.2f} Avg loss: {:.2f}"
.format(trainer.state.epoch, metrics["accuracy"], metrics["nll"]))
if idist.get_rank() == 0:
GpuInfo().attach(trainer, name='gpu')
tb_logger = TensorboardLogger(log_dir=config.output_dir)
tb_logger.attach(
evaluator,
trainer,
log_handler=OutputHandler(
tag="val",
tag="train",
metric_names='all',
global_step_transform=global_step_from_engine(trainer),
),
@ -70,8 +58,7 @@ def baseline_trainer(config, logger, val_loader):
to_save = dict(model=model, optimizer=optimizer, trainer=trainer)
setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.EPOCH_COMPLETED, to_save=to_save,
save_interval_event=Events.EPOCH_COMPLETED(every=25), n_saved=5,
metrics_to_print=["loss"])
save_best_model_by_val_score(config.output_dir, evaluator, model, "accuracy", 1, trainer)
metrics_to_print=["loss", "acc"])
return trainer
@ -80,14 +67,13 @@ def run(task, config, logger):
torch.backends.cudnn.benchmark = True
logger.info(f"start task {task}")
if task == "baseline":
train_dataset = ImageFolder(config.baseline.data.dataset.train.path,
transform=transform_pipeline(config.baseline.data.dataset.train.pipeline))
val_dataset = ImageFolder(config.baseline.data.dataset.val.path,
transform=transform_pipeline(config.baseline.data.dataset.val.pipeline))
train_dataset = LMDBDataset(config.baseline.data.dataset.train.lmdb_path,
pipeline=config.baseline.data.dataset.train.pipeline)
# train_dataset = ImageFolder(config.baseline.data.dataset.train.path,
# transform=transform_pipeline(config.baseline.data.dataset.train.pipeline))
logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader)
val_data_loader = idist.auto_dataloader(val_dataset, **config.baseline.data.dataloader)
trainer = baseline_trainer(config, logger, val_data_loader)
trainer = baseline_trainer(config, logger)
try:
trainer.run(train_data_loader, max_epochs=400)
except Exception:

20
tool/lmdbify.py Normal file
View File

@ -0,0 +1,20 @@
import fire
from omegaconf import OmegaConf
from data.dataset import ImprovedImageFolder, LMDBDataset
pipeline = """
pipeline:
- Load
"""
def transform(dataset_path, save_path):
print(save_path, dataset_path)
conf = OmegaConf.create(pipeline)
print(conf.pipeline.pretty())
origin_dataset = ImprovedImageFolder(dataset_path, conf.pipeline)
LMDBDataset.lmdbify(origin_dataset, conf.pipeline, save_path)
if __name__ == '__main__':
fire.Fire(transform)