improve run.sh

This commit is contained in:
Ray Wong 2020-08-10 08:50:58 +08:00
parent a5133e6795
commit 649f2244f7
3 changed files with 23 additions and 7 deletions

View File

@@ -30,7 +30,10 @@ def running(local_rank, config, task, backup_config=False, setup_output_dir=Fals
output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
config.output_dir = str(output_dir)
if output_dir.exists():
assert not any(output_dir.iterdir()), "output_dir must be empty"
# assert not any(output_dir.iterdir()), "output_dir must be empty"
contains = list(output_dir.iterdir())
assert (len(contains) == 0) or (len(contains) == 1 and contains[0].name == "config.yml"), \
f"output_dir must by empty or only contains config.yml, but now got {len(contains)} files"
else:
if idist.get_rank() == 0:
output_dir.mkdir(parents=True)

14
run.sh
View File

@@ -1,8 +1,14 @@
#!/usr/bin/env bash
# Launch a main.py task with torch.distributed.launch, one process per GPU.
#
# Usage: run.sh CONFIG TASK GPUS
#   CONFIG - path to the config file passed to main.py
#   TASK   - main.py subcommand (e.g. "train")
#   GPUS   - comma-separated GPU ids, e.g. "0,2,3"
set -euo pipefail

usage="usage: ${0##*/} CONFIG TASK GPUS"
CONFIG=${1:?$usage}
TASK=${2:?$usage}
GPUS=${3:?$usage}

# Derive the process count from the comma-separated GPU list in pure bash.
# (Previously this interpolated $GPUS into a python3 -c string literal, which
# forked an interpreter and allowed shell->python injection via quotes in $GPUS.)
IFS=',' read -r -a gpu_ids <<< "$GPUS"
GPU_COUNT=${#gpu_ids[@]}
echo "GPU_COUNT:${GPU_COUNT}"

CUDA_VISIBLE_DEVICES="$GPUS" \
PYTHONPATH=.:${PYTHONPATH:-} OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
  main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed

View File

@@ -3,13 +3,13 @@ from pathlib import Path
import torch
import ignite.distributed as idist
from ignite.engine import Events
from ignite.engine import Events, Engine
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler
def setup_common_handlers(
trainer,
trainer: Engine,
output_dir=None,
stop_on_nan=True,
use_profiler=True,
@@ -39,6 +39,11 @@ def setup_common_handlers(
:param checkpoint_kwargs:
:return:
"""
@trainer.on(Events.STARTED)
@idist.one_rank_only()
def print_dataloader_size(engine):
engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")
if stop_on_nan:
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
@@ -68,6 +73,8 @@ def setup_common_handlers(
def print_interval(engine):
print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
for m in metrics_to_print:
if m not in engine.state.metrics:
continue
print_str += f"{m}={engine.state.metrics[m]:.3f} "
engine.logger.info(print_str)