diff --git a/configs/few-shot/crossdomain.yml b/configs/few-shot/crossdomain.yml
index 8b4cb56..b777b17 100644
--- a/configs/few-shot/crossdomain.yml
+++ b/configs/few-shot/crossdomain.yml
@@ -1,4 +1,4 @@
-name: cross-domain
+name: cross-domain-1
 engine: crossdomain
 
 result_dir: ./result
@@ -33,7 +33,9 @@ baseline:
     dataset:
       train:
         path: /data/few-shot/mini_imagenet_full_size/train
+        lmdb_path: /data/few-shot/lmdb/mini-ImageNet/train.lmdb
         pipeline:
+          - Load
           - RandomResizedCrop:
               size: [256, 256]
           - ColorJitter:
@@ -47,7 +49,9 @@ baseline:
               std: [0.229, 0.224, 0.225]
       val:
         path: /data/few-shot/mini_imagenet_full_size/val
+        lmdb_path: /data/few-shot/lmdb/mini-ImageNet/val.lmdb
         pipeline:
+          - Load
           - Resize:
               size: [286, 286]
           - RandomCrop:
diff --git a/data/dataset.py b/data/dataset.py
index 87c0789..3cb9800 100644
--- a/data/dataset.py
+++ b/data/dataset.py
@@ -1,23 +1,52 @@
 import os
 import pickle
+from collections import defaultdict
 
 import torch
 from torch.utils.data import Dataset
-from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
+from torchvision.datasets import ImageFolder
+from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader
 import lmdb
+from tqdm import tqdm
 
 from .transform import transform_pipeline
 from .registry import DATASET
 
 
+def default_transform_way(transform, sample):
+    return [transform(sample[0]), *sample[1:]]
+
+
 class LMDBDataset(Dataset):
-    def __init__(self, lmdb_path, output_transform=None, map_size=2 ** 40, readonly=True, **lmdb_kwargs):
+    def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True,
+                 **lmdb_kwargs):
+        self.path = lmdb_path
         self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
-                            **lmdb_kwargs)
-        self.output_transform = output_transform
+                            lock=False, **lmdb_kwargs)
+
         with self.db.begin(write=False) as txn:
-            self._len = pickle.loads(txn.get(b"__len__"))
+            self._len = pickle.loads(txn.get(b"$$len$$"))
+            self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
+            if pipeline is None:
+                self.not_done_pipeline = []
+            else:
+                self.not_done_pipeline = self._remain_pipeline(pipeline)
+            self.transform = transform_pipeline(self.not_done_pipeline)
+            self.transform_way = transform_way
+            essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
+            for ea in essential_attr:
+                setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8"))))
+
+    def _remain_pipeline(self, pipeline):
+        for i, dp in enumerate(self.done_pipeline):
+            if pipeline[i] != dp:
+                raise ValueError(
+                    f"pipeline {self.done_pipeline} saved in this lmdb database does not match pipeline: {pipeline}")
+        return pipeline[len(self.done_pipeline):]
+
+    def __repr__(self):
+        return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"
 
     def __len__(self):
         return self._len
@@ -25,10 +54,77 @@ class LMDBDataset(Dataset):
     def __getitem__(self, idx):
         with self.db.begin(write=False) as txn:
             sample = pickle.loads(txn.get("{}".format(idx).encode()))
-        if self.output_transform is not None:
-            sample = self.output_transform(sample)
+        sample = self.transform_way(self.transform, sample)
         return sample
 
+    @staticmethod
+    def lmdbify(dataset, done_pipeline, lmdb_path):
+        env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
+        with env.begin(write=True) as txn:
+            for i in tqdm(range(len(dataset)), ncols=0):
+                txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
+            txn.put(b"$$len$$", pickle.dumps(len(dataset)))
+            txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
+            essential_attr = getattr(dataset, "essential_attr", list())
+            txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
+            for ea in essential_attr:
+                txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))
+
+
+@DATASET.register_module()
+class ImprovedImageFolder(ImageFolder):
+    def __init__(self, root, pipeline):
+        super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
+        self.classes_list = defaultdict(list)
+        self.essential_attr = ["classes_list"]
+        for i in range(len(self)):
+            self.classes_list[self.samples[i][-1]].append(i)
+        assert len(self.classes_list) == len(self.classes)
+
+
+class EpisodicDataset(Dataset):
+    def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
+        self.origin = origin_dataset
+        self.num_class = num_class
+        assert self.num_class < len(self.origin.classes_list)
+        self.num_query = num_query  # Q: query images per class
+        self.num_support = num_support  # K: support images per class
+        self.num_episodes = num_episodes
+
+    def _fetch_list_data(self, id_list):
+        return [self.origin[i][0] for i in id_list]
+
+    def __len__(self):
+        return self.num_episodes
+
+    def __getitem__(self, _):
+        random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
+        support_set_list = []
+        query_set_list = []
+        target_list = []
+        for tag, c in enumerate(random_classes):
+            image_list = self.origin.classes_list[c]
+
+            if len(image_list) >= self.num_query + self.num_support:
+                # this class has enough images to sample without replacement
+                idx_list = torch.randperm(len(image_list))[:self.num_query + self.num_support].tolist()
+            else:
+                idx_list = torch.randint(high=len(image_list), size=(self.num_query + self.num_support,)).tolist()
+
+            support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
+            query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
+            support_set_list.extend(support)
+            query_set_list.extend(query)
+            target_list.extend([tag] * self.num_query)
+        return {
+            "support": torch.stack(support_set_list),
+            "query": torch.stack(query_set_list),
+            "target": torch.tensor(target_list)
+        }
+
+    def __repr__(self):
+        return f"EpisodicDataset: {self.num_class}-way, {self.num_support} support / {self.num_query} query per class, {self.num_episodes} episodes"
+
 
 @DATASET.register_module()
 class SingleFolderDataset(Dataset):
diff --git a/engine/crossdomain.py b/engine/crossdomain.py
index be3db60..1e9bf57 100644
--- a/engine/crossdomain.py
+++ b/engine/crossdomain.py
@@ -14,40 +14,28 @@ from ignite.contrib.handlers import ProgressBar
 from util.build import build_model, build_optimizer
 from util.handler import setup_common_handlers
 from data.transform import transform_pipeline
+from data.dataset import LMDBDataset
 
 
-def baseline_trainer(config, logger, val_loader):
+def baseline_trainer(config, logger):
     model = build_model(config.model, config.distributed.model)
     optimizer = build_optimizer(model.parameters(), config.baseline.optimizers)
     loss_fn = nn.CrossEntropyLoss()
 
-    trainer = create_supervised_trainer(model, optimizer, loss_fn, idist.device(), non_blocking=True)
+    trainer = create_supervised_trainer(model, optimizer, loss_fn, idist.device(), non_blocking=True,
+                                        output_transform=lambda x, y, y_pred, loss: (loss.item(), y_pred, y))
     trainer.logger = logger
-    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
+    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss")
+    Accuracy(output_transform=lambda x: (x[1], x[2])).attach(trainer, "acc")
     ProgressBar(ncols=0).attach(trainer)
 
-    val_metrics = {
-        "accuracy": Accuracy(),
-        "nll": Loss(loss_fn)
-    }
-    evaluator = create_supervised_evaluator(model, val_metrics, idist.device())
-    ProgressBar(ncols=0).attach(evaluator)
-
-    @trainer.on(Events.EPOCH_COMPLETED)
-    def log_training_loss(engine):
-        logger.info(f"Epoch[{engine.state.epoch}] Loss: {engine.state.output:.2f}")
-        evaluator.run(val_loader)
-        metrics = evaluator.state.metrics
-        logger.info("Training Results - Avg accuracy: {:.2f} Avg loss: {:.2f}"
-                    .format(trainer.state.epoch, metrics["accuracy"], metrics["nll"]))
-
     if idist.get_rank() == 0:
         GpuInfo().attach(trainer, name='gpu')
         tb_logger = TensorboardLogger(log_dir=config.output_dir)
         tb_logger.attach(
-            evaluator,
+            trainer,
             log_handler=OutputHandler(
-                tag="val",
+                tag="train",
                 metric_names='all',
                 global_step_transform=global_step_from_engine(trainer),
             ),
@@ -70,8 +58,7 @@ def baseline_trainer(config, logger, val_loader):
     to_save = dict(model=model, optimizer=optimizer, trainer=trainer)
     setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.EPOCH_COMPLETED, to_save=to_save,
                           save_interval_event=Events.EPOCH_COMPLETED(every=25), n_saved=5,
-                          metrics_to_print=["loss"])
-    save_best_model_by_val_score(config.output_dir, evaluator, model, "accuracy", 1, trainer)
+                          metrics_to_print=["loss", "acc"])
     return trainer
 
 
@@ -80,14 +67,13 @@ def run(task, config, logger):
     torch.backends.cudnn.benchmark = True
     logger.info(f"start task {task}")
     if task == "baseline":
-        train_dataset = ImageFolder(config.baseline.data.dataset.train.path,
-                                    transform=transform_pipeline(config.baseline.data.dataset.train.pipeline))
-        val_dataset = ImageFolder(config.baseline.data.dataset.val.path,
-                                  transform=transform_pipeline(config.baseline.data.dataset.val.pipeline))
+        train_dataset = LMDBDataset(config.baseline.data.dataset.train.lmdb_path,
+                                    pipeline=config.baseline.data.dataset.train.pipeline)
+        # train_dataset = ImageFolder(config.baseline.data.dataset.train.path,
+        #                             transform=transform_pipeline(config.baseline.data.dataset.train.pipeline))
         logger.info(f"train with dataset:\n{train_dataset}")
         train_data_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader)
-        val_data_loader = idist.auto_dataloader(val_dataset, **config.baseline.data.dataloader)
-        trainer = baseline_trainer(config, logger, val_data_loader)
+        trainer = baseline_trainer(config, logger)
     try:
         trainer.run(train_data_loader, max_epochs=400)
     except Exception:
diff --git a/tool/lmdbify.py b/tool/lmdbify.py
new file mode 100644
index 0000000..5a55f6b
--- /dev/null
+++ b/tool/lmdbify.py
@@ -0,0 +1,20 @@
+import fire
+from omegaconf import OmegaConf
+from data.dataset import ImprovedImageFolder, LMDBDataset
+
+pipeline = """
+pipeline:
+  - Load
+"""
+
+
+def transform(dataset_path, save_path):
+    print(save_path, dataset_path)
+    conf = OmegaConf.create(pipeline)
+    print(conf.pipeline.pretty())
+    origin_dataset = ImprovedImageFolder(dataset_path, conf.pipeline)
+    LMDBDataset.lmdbify(origin_dataset, conf.pipeline, save_path)
+
+
+if __name__ == '__main__':
+    fire.Fire(transform)
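
Usage note (not part of the patch): a minimal sketch of how the new pieces are expected to fit together, assuming the repository root is on PYTHONPATH; the LMDB and dataset paths are the placeholders from the config above, not files this patch creates, and the transform names depend on data/transform.py.

# Sketch only, not a verified run.
# 1) Bake the "Load" step into an LMDB copy of the ImageFolder-style dataset:
#      python tool/lmdbify.py --dataset_path /data/few-shot/mini_imagenet_full_size/train \
#                             --save_path /data/few-shot/lmdb/mini-ImageNet/train.lmdb
# 2) Read it back through LMDBDataset with the full pipeline from the config; the
#    leading "Load" matches the pipeline stored in the LMDB, so only the remaining
#    transforms are applied in __getitem__.
from omegaconf import OmegaConf

from data.dataset import LMDBDataset

conf = OmegaConf.create("""
pipeline:
  - Load
  - RandomResizedCrop:
      size: [256, 256]
""")
train_dataset = LMDBDataset("/data/few-shot/lmdb/mini-ImageNet/train.lmdb",
                            pipeline=conf.pipeline)
print(train_dataset)  # __repr__ reports the path, length, and the not-yet-applied transforms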