diff --git a/main.py b/main.py
index 2edb96e..05bcf5c 100644
--- a/main.py
+++ b/main.py
@@ -30,7 +30,10 @@ def running(local_rank, config, task, backup_config=False, setup_output_dir=Fals
     output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
     config.output_dir = str(output_dir)
     if output_dir.exists():
-        assert not any(output_dir.iterdir()), "output_dir must be empty"
+        # assert not any(output_dir.iterdir()), "output_dir must be empty"
+        contains = list(output_dir.iterdir())
+        assert (len(contains) == 0) or (len(contains) == 1 and contains[0].name == "config.yml"), \
+            f"output_dir must be empty or only contain config.yml, but got {len(contains)} files"
     else:
         if idist.get_rank() == 0:
             output_dir.mkdir(parents=True)
diff --git a/run.sh b/run.sh
index d075595..21127a1 100644
--- a/run.sh
+++ b/run.sh
@@ -1,8 +1,14 @@
 #!/usr/bin/env bash
 
 CONFIG=$1
-GPUS=$2
+TASK=$2
+GPUS=$3
 
-# CUDA_VISIBLE_DEVICES=$GPUS \
-PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPUS" \
-    main.py train "$CONFIG" --backup_config --setup_output_dir --setup_random_seed
+_command="print(len('${GPUS}'.split(',')))"
+GPU_COUNT=$(python3 -c "${_command}")
+
+echo "GPU_COUNT:${GPU_COUNT}"
+
+CUDA_VISIBLE_DEVICES=$GPUS \
+PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
+    main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed
diff --git a/util/handler.py b/util/handler.py
index db1d3d1..180b872 100644
--- a/util/handler.py
+++ b/util/handler.py
@@ -3,13 +3,13 @@ from pathlib import Path
 import torch
 
 import ignite.distributed as idist
-from ignite.engine import Events
+from ignite.engine import Events, Engine
 from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
 from ignite.contrib.handlers import BasicTimeProfiler
 
 
 def setup_common_handlers(
-    trainer,
+    trainer: Engine,
     output_dir=None,
     stop_on_nan=True,
     use_profiler=True,
@@ -39,6 +39,11 @@ def setup_common_handlers(
     :param checkpoint_kwargs:
     :return:
     """
+    @trainer.on(Events.STARTED)
+    @idist.one_rank_only()
+    def print_dataloader_size(engine):
+        engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")
+
     if stop_on_nan:
         trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
@@ -68,6 +73,8 @@ def setup_common_handlers(
     def print_interval(engine):
         print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
         for m in metrics_to_print:
+            if m not in engine.state.metrics:
+                continue
             print_str += f"{m}={engine.state.metrics[m]:.3f} "
         engine.logger.info(print_str)
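
Usage sketch for the updated run.sh (the config path configs/example.yml below is illustrative; "train" is the task name that was previously hard-coded in the script):

    # hypothetical invocation: config, task, then a comma-separated GPU list
    bash run.sh configs/example.yml train 0,1

This exports CUDA_VISIBLE_DEVICES=0,1, derives GPU_COUNT=2 from the comma-separated GPU list, and launches torch.distributed.launch with --nproc_per_node=2, passing "train" through to main.py as the task argument.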