Compare commits

...

3 Commits

Author SHA1 Message Date
323bf2f6ab add lmdb dataset support and EpisodicDataset 2020-08-10 10:51:24 +08:00
8102651a28 add code for few-shot baseline 2020-08-10 08:51:26 +08:00
649f2244f7 improve run.sh 2020-08-10 08:50:58 +08:00
9 changed files with 401 additions and 14 deletions

View File

@@ -0,0 +1,66 @@
name: cross-domain-1
engine: crossdomain
result_dir: ./result

distributed:
  model:
    # broadcast_buffers: False

misc:
  random_seed: 1004

checkpoints:
  interval: 2000

log:
  logger:
    level: 20 # DEBUG(10) INFO(20)

model:
  _type: resnet10

baseline:
  plusplus: False
  optimizers:
    _type: Adam
  data:
    dataloader:
      batch_size: 1024
      shuffle: True
      num_workers: 16
      pin_memory: True
      drop_last: True
    dataset:
      train:
        path: /data/few-shot/mini_imagenet_full_size/train
        lmdb_path: /data/few-shot/lmdb/mini-ImageNet/train.lmdb
        pipeline:
          - Load
          - RandomResizedCrop:
              size: [256, 256]
          - ColorJitter:
              brightness: 0.4
              contrast: 0.4
              saturation: 0.4
          - RandomHorizontalFlip
          - ToTensor
          - Normalize:
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
      val:
        path: /data/few-shot/mini_imagenet_full_size/val
        lmdb_path: /data/few-shot/lmdb/mini-ImageNet/val.lmdb
        pipeline:
          - Load
          - Resize:
              size: [286, 286]
          - RandomCrop:
              size: [256, 256]
          - ToTensor
          - Normalize:
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
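
A minimal sketch of how this config is consumed (assuming it is loaded with OmegaConf, as tool/lmdbify.py below does; the file path here is hypothetical):

from omegaconf import OmegaConf

config = OmegaConf.load("config/cross-domain-1.yml")      # assumed location of the file above
print(config.model._type)                                 # "resnet10", passed to build_model
print(config.baseline.optimizers._type)                   # "Adam", passed to build_optimizer
print(config.baseline.data.dataset.train.lmdb_path)       # read by engine/crossdomain.py for the baseline task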

View File

@@ -1,23 +1,52 @@
import os
import pickle
from collections import defaultdict

import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader
import lmdb
from tqdm import tqdm

from .transform import transform_pipeline
from .registry import DATASET


def default_transform_way(transform, sample):
    # apply the transform to the image and pass the remaining fields (e.g. the label) through unchanged
    return [transform(sample[0]), *sample[1:]]


class LMDBDataset(Dataset):
    def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True,
                 **lmdb_kwargs):
        self.path = lmdb_path
        self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
                            lock=False, **lmdb_kwargs)
        with self.db.begin(write=False) as txn:
            self._len = pickle.loads(txn.get(b"$$len$$"))
            self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
            if pipeline is None:
                self.not_done_pipeline = []
            else:
                self.not_done_pipeline = self._remain_pipeline(pipeline)
            self.transform = transform_pipeline(self.not_done_pipeline)
            self.transform_way = transform_way
            essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
            for ea in essential_attr:
                setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8"))))

    def _remain_pipeline(self, pipeline):
        # the pipeline stored in the database must be a prefix of the requested pipeline;
        # only the steps not applied at build time remain to run online
        for i, dp in enumerate(self.done_pipeline):
            if pipeline[i] != dp:
                raise ValueError(
                    f"pipeline {self.done_pipeline} saved in this lmdb database does not match pipeline: {pipeline}")
        return pipeline[len(self.done_pipeline):]

    def __repr__(self):
        return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"

    def __len__(self):
        return self._len
@@ -25,10 +54,77 @@ class LMDBDataset(Dataset):

    def __getitem__(self, idx):
        with self.db.begin(write=False) as txn:
            sample = pickle.loads(txn.get("{}".format(idx).encode()))
        sample = self.transform_way(self.transform, sample)
        return sample

    @staticmethod
    def lmdbify(dataset, done_pipeline, lmdb_path):
        # serialize every sample of `dataset` (already processed by `done_pipeline`) into an LMDB database
        env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
        with env.begin(write=True) as txn:
            for i in tqdm(range(len(dataset)), ncols=0):
                txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
            txn.put(b"$$len$$", pickle.dumps(len(dataset)))
            txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
            essential_attr = getattr(dataset, "essential_attr", list())
            txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
            for ea in essential_attr:
                txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))


@DATASET.register_module()
class ImprovedImageFolder(ImageFolder):
    def __init__(self, root, pipeline):
        super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
        # classes_list maps a class index to the indices of all samples of that class
        self.classes_list = defaultdict(list)
        self.essential_attr = ["classes_list"]
        for i in range(len(self)):
            self.classes_list[self.samples[i][-1]].append(i)
        assert len(self.classes_list) == len(self.classes)


class EpisodicDataset(Dataset):
    def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
        self.origin = origin_dataset
        self.num_class = num_class  # N ways per episode
        assert self.num_class < len(self.origin.classes_list)
        self.num_query = num_query  # queries per class
        self.num_support = num_support  # K shots per class
        self.num_episodes = num_episodes

    def _fetch_list_data(self, id_list):
        return [self.origin[i][0] for i in id_list]

    def __len__(self):
        return self.num_episodes

    def __getitem__(self, _):
        random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
        support_set_list = []
        query_set_list = []
        target_list = []
        for tag, c in enumerate(random_classes):
            image_list = self.origin.classes_list[c]
            if len(image_list) >= self.num_query + self.num_support:
                # enough images belong to this class: sample without replacement
                idx_list = torch.randperm(len(image_list))[:self.num_query + self.num_support].tolist()
            else:
                # otherwise sample with replacement
                idx_list = torch.randint(high=len(image_list), size=(self.num_query + self.num_support,)).tolist()
            support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
            query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
            support_set_list.extend(support)
            query_set_list.extend(query)
            target_list.extend([tag] * self.num_query)
        return {
            "support": torch.stack(support_set_list),
            "query": torch.stack(query_set_list),
            "target": torch.tensor(target_list)
        }

    def __repr__(self):
        return f"<EpisodicDataset NE{self.num_episodes} NC{self.num_class} NS{self.num_support} NQ{self.num_query}>"


@DATASET.register_module()
class SingleFolderDataset(Dataset):
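
A hedged usage sketch (not part of the diff) of the new episodic sampler: a 5-way 5-shot episode over an ImprovedImageFolder, with the pipeline built the same way tool/lmdbify.py builds its pipeline and the Resize size chosen so the images can be stacked:

from omegaconf import OmegaConf
from data.dataset import ImprovedImageFolder, EpisodicDataset

pipeline = OmegaConf.create("""
pipeline:
  - Load
  - Resize:
      size: [256, 256]
  - ToTensor
""").pipeline
base = ImprovedImageFolder("/data/few-shot/mini_imagenet_full_size/val", pipeline)
episodes = EpisodicDataset(base, num_class=5, num_query=15, num_support=5, num_episodes=600)
batch = episodes[0]
# batch["support"]: (25, 3, 256, 256), batch["query"]: (75, 3, 256, 256), batch["target"]: (75,)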

engine/crossdomain.py Normal file
View File

@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
import ignite.distributed as idist
from ignite.contrib.metrics.gpu_info import GpuInfo
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, global_step_from_engine, OutputHandler, \
    WeightsScalarHandler, GradsHistHandler, WeightsHistHandler, GradsScalarHandler
from ignite.engine import create_supervised_evaluator, create_supervised_trainer, Events
from ignite.metrics import Accuracy, Loss, RunningAverage
from ignite.contrib.engines.common import save_best_model_by_val_score
from ignite.contrib.handlers import ProgressBar

from util.build import build_model, build_optimizer
from util.handler import setup_common_handlers
from data.transform import transform_pipeline
from data.dataset import LMDBDataset


def baseline_trainer(config, logger):
    model = build_model(config.model, config.distributed.model)
    optimizer = build_optimizer(model.parameters(), config.baseline.optimizers)
    loss_fn = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model, optimizer, loss_fn, idist.device(), non_blocking=True,
                                        output_transform=lambda x, y, y_pred, loss: (loss.item(), y_pred, y))
    trainer.logger = logger
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss")
    Accuracy(output_transform=lambda x: (x[1], x[2])).attach(trainer, "acc")
    ProgressBar(ncols=0).attach(trainer)
    if idist.get_rank() == 0:
        GpuInfo().attach(trainer, name='gpu')
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="train",
                metric_names='all',
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.EPOCH_COMPLETED
        )
        tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=10))
        tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=25))
        tb_logger.attach(trainer, log_handler=GradsScalarHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=10))
        tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=25))

        @trainer.on(Events.COMPLETED)
        def _():
            tb_logger.close()

    to_save = dict(model=model, optimizer=optimizer, trainer=trainer)
    setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.EPOCH_COMPLETED, to_save=to_save,
                          save_interval_event=Events.EPOCH_COMPLETED(every=25), n_saved=5,
                          metrics_to_print=["loss", "acc"])
    return trainer


def run(task, config, logger):
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    if task == "baseline":
        train_dataset = LMDBDataset(config.baseline.data.dataset.train.lmdb_path,
                                    pipeline=config.baseline.data.dataset.train.pipeline)
        # train_dataset = ImageFolder(config.baseline.data.dataset.train.path,
        #                             transform=transform_pipeline(config.baseline.data.dataset.train.pipeline))
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader)
        trainer = baseline_trainer(config, logger)
        try:
            trainer.run(train_data_loader, max_epochs=400)
        except Exception:
            import traceback
            print(traceback.format_exc())
    else:
        raise NotImplementedError(f"invalid task: {task}")
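
A hedged, single-process sketch of invoking this engine; main.py is not part of this diff, so the config path and the output_dir handling below are assumptions (normally filled in by the runner in util/running.py):

import logging
from omegaconf import OmegaConf
from engine.crossdomain import run

config = OmegaConf.load("config/cross-domain-1.yml")  # hypothetical location of the config above
config.output_dir = "./result/cross-domain-1"         # assumed; derived from result_dir/name by the runner
run("baseline", config, logging.getLogger("crossdomain"))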

View File

@@ -30,7 +30,10 @@ def running(local_rank, config, task, backup_config=False, setup_output_dir=Fals
    output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
    config.output_dir = str(output_dir)
    if output_dir.exists():
        # assert not any(output_dir.iterdir()), "output_dir must be empty"
        contains = list(output_dir.iterdir())
        assert (len(contains) == 0) or (len(contains) == 1 and contains[0].name == "config.yml"), \
            f"output_dir must be empty or only contain config.yml, but now got {len(contains)} files"
    else:
        if idist.get_rank() == 0:
            output_dir.mkdir(parents=True)

View File

@@ -1,2 +1,3 @@
from model.registry import MODEL
import model.residual_generator
import model.fewshot

model/fewshot.py Normal file
View File

@@ -0,0 +1,105 @@
import math

import torch.nn as nn

from .registry import MODEL


# --- gaussian initialize ---
def init_layer(l):
    # Initialization using fan-in
    if isinstance(l, nn.Conv2d):
        n = l.kernel_size[0] * l.kernel_size[1] * l.out_channels
        l.weight.data.normal_(0, math.sqrt(2.0 / float(n)))
    elif isinstance(l, nn.BatchNorm2d):
        l.weight.data.fill_(1)
        l.bias.data.fill_(0)
    elif isinstance(l, nn.Linear):
        l.bias.data.fill_(0)


class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        return x.view(x.size(0), -1)


class SimpleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, half_res, leakyrelu=False):
        super(SimpleBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True) if not leakyrelu else nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.relu = nn.ReLU(inplace=True) if not leakyrelu else nn.LeakyReLU(0.2, inplace=True)
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, 2 if half_res else 1, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        o = self.block(x)
        return self.relu(o + self.shortcut(x))


class ResNet(nn.Module):
    def __init__(self, block, layers, dims, num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
        super().__init__()
        assert len(layers) == 4, 'Can have only four stages'
        self.inplanes = 64
        self.start = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(self.inplanes),
            nn.ReLU(inplace=True) if not leakyrelu else nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        trunk = []
        in_channels = self.inplanes
        for i in range(4):
            for j in range(layers[i]):
                half_res = i >= 1 and j == 0
                trunk.append(block(in_channels, dims[i], half_res, leakyrelu))
                in_channels = dims[i]
        if flatten:
            trunk.append(nn.AvgPool2d(7))
            trunk.append(Flatten())
        if num_classes is not None:
            if classifier_type == "linear":
                trunk.append(nn.Linear(in_channels, num_classes))
            elif classifier_type == "distlinear":
                pass
            else:
                raise ValueError(f"invalid classifier_type:{classifier_type}")
        self.trunk = nn.Sequential(*trunk)
        self.apply(init_layer)

    def forward(self, x):
        return self.trunk(self.start(x))


@MODEL.register_module()
def resnet10(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
    return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], num_classes, classifier_type, flatten, leakyrelu)


@MODEL.register_module()
def resnet18(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
    return ResNet(SimpleBlock, [2, 2, 2, 2], [64, 128, 256, 512], num_classes, classifier_type, flatten, leakyrelu)


@MODEL.register_module()
def resnet34(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
    return ResNet(SimpleBlock, [3, 4, 6, 3], [64, 128, 256, 512], num_classes, classifier_type, flatten, leakyrelu)
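
A hedged shape check (not part of the diff), using the 256x256 crop size from the config above and assuming MODEL.register_module() returns the decorated function unchanged:

import torch
from model.fewshot import resnet10

net = resnet10(num_classes=64)  # e.g. the 64 mini-ImageNet training classes (assumed here)
logits = net(torch.randn(2, 3, 256, 256))
print(logits.shape)  # torch.Size([2, 64]): 256 -> 8x8 feature map -> AvgPool2d(7) -> 512-d -> Linear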

run.sh
View File

@@ -1,8 +1,14 @@
#!/usr/bin/env bash
# usage: bash run.sh <config.yml> <task> <comma-separated gpu ids, e.g. 0,1,2,3>

CONFIG=$1
TASK=$2
GPUS=$3

_command="print(len('${GPUS}'.split(',')))"
GPU_COUNT=$(python3 -c "${_command}")
echo "GPU_COUNT:${GPU_COUNT}"

CUDA_VISIBLE_DEVICES=$GPUS \
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
  main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed

tool/lmdbify.py Normal file
View File

@@ -0,0 +1,20 @@
import fire
from omegaconf import OmegaConf

from data.dataset import ImprovedImageFolder, LMDBDataset

pipeline = """
pipeline:
- Load
"""


def transform(dataset_path, save_path):
    print(save_path, dataset_path)
    conf = OmegaConf.create(pipeline)
    print(conf.pipeline.pretty())
    origin_dataset = ImprovedImageFolder(dataset_path, conf.pipeline)
    LMDBDataset.lmdbify(origin_dataset, conf.pipeline, save_path)


if __name__ == '__main__':
    fire.Fire(transform)
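
A hedged example of converting the train split with this tool, with paths taken from the config above; the same call can also be issued through the fire CLI (PYTHONPATH=. python tool/lmdbify.py <dataset_path> <save_path>):

from tool.lmdbify import transform

transform("/data/few-shot/mini_imagenet_full_size/train",
          "/data/few-shot/lmdb/mini-ImageNet/train.lmdb")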

View File

@@ -3,13 +3,13 @@ from pathlib import Path
import torch
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler


def setup_common_handlers(
        trainer: Engine,
        output_dir=None,
        stop_on_nan=True,
        use_profiler=True,
@@ -39,6 +39,11 @@ def setup_common_handlers(
    :param checkpoint_kwargs:
    :return:
    """
    @trainer.on(Events.STARTED)
    @idist.one_rank_only()
    def print_dataloader_size(engine):
        engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
@@ -68,6 +73,8 @@ def setup_common_handlers(
    def print_interval(engine):
        print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
        for m in metrics_to_print:
            if m not in engine.state.metrics:
                continue
            print_str += f"{m}={engine.state.metrics[m]:.3f} "
        engine.logger.info(print_str)