improve run.sh

parent a5133e6795
commit 649f2244f7

main.py  (5 lines changed)
@@ -30,7 +30,10 @@ def running(local_rank, config, task, backup_config=False, setup_output_dir=Fals
     output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
     config.output_dir = str(output_dir)
     if output_dir.exists():
-        assert not any(output_dir.iterdir()), "output_dir must be empty"
+        # assert not any(output_dir.iterdir()), "output_dir must be empty"
+        contains = list(output_dir.iterdir())
+        assert (len(contains) == 0) or (len(contains) == 1 and contains[0].name == "config.yml"), \
+            f"output_dir must be empty or only contain config.yml, but got {len(contains)} files"
     else:
         if idist.get_rank() == 0:
             output_dir.mkdir(parents=True)
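The relaxed assertion lets a run proceed when the output directory already contains a backed-up config.yml (presumably written before training by the --backup_config step) while still rejecting any other leftover files. A minimal standalone sketch of that check, using a temporary directory purely for illustration:

# Hedged sketch (not part of the commit): the directory check in isolation.
# An existing output_dir passes only if it is empty or holds just config.yml.
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    output_dir = Path(tmp)
    (output_dir / "config.yml").touch()  # e.g. a config backed up by an earlier step
    contains = list(output_dir.iterdir())
    assert (len(contains) == 0) or (len(contains) == 1 and contains[0].name == "config.yml"), \
        f"output_dir must be empty or only contain config.yml, but got {len(contains)} files"
    print("an output_dir holding only config.yml is accepted")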
run.sh  (14 lines changed)
@@ -1,8 +1,14 @@
 #!/usr/bin/env bash
 
 CONFIG=$1
-GPUS=$2
+TASK=$2
+GPUS=$3
 
-# CUDA_VISIBLE_DEVICES=$GPUS \
-PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPUS" \
-main.py train "$CONFIG" --backup_config --setup_output_dir --setup_random_seed
+_command="print(len('${GPUS}'.split(',')))"
+GPU_COUNT=$(python3 -c "${_command}")
+
+echo "GPU_COUNT:${GPU_COUNT}"
+
+CUDA_VISIBLE_DEVICES=$GPUS \
+PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
+main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed
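For illustration, a hypothetical invocation of the updated script (the config path and GPU ids below are placeholders, not taken from the repository):

# hypothetical example: arguments are CONFIG, TASK, then a comma-separated GPU id list
./run.sh configs/example.yml train 0,2,3
# inside the script, GPUS="0,2,3" yields GPU_COUNT=3, so torch.distributed.launch
# starts three processes while CUDA_VISIBLE_DEVICES restricts them to GPUs 0, 2 and 3

Deriving --nproc_per_node from the GPU list keeps the process count and the visible devices in sync, instead of requiring the caller to pass a matching number separately.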
@@ -3,13 +3,13 @@ from pathlib import Path
 import torch
 
 import ignite.distributed as idist
-from ignite.engine import Events
+from ignite.engine import Events, Engine
 from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
 from ignite.contrib.handlers import BasicTimeProfiler
 
 
 def setup_common_handlers(
-    trainer,
+    trainer: Engine,
     output_dir=None,
     stop_on_nan=True,
     use_profiler=True,
@@ -39,6 +39,11 @@ def setup_common_handlers(
     :param checkpoint_kwargs:
     :return:
     """
+    @trainer.on(Events.STARTED)
+    @idist.one_rank_only()
+    def print_dataloader_size(engine):
+        engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")
+
     if stop_on_nan:
         trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
 
@@ -68,6 +73,8 @@ def setup_common_handlers(
     def print_interval(engine):
         print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
         for m in metrics_to_print:
+            if m not in engine.state.metrics:
+                continue
             print_str += f"{m}={engine.state.metrics[m]:.3f} "
         engine.logger.info(print_str)
 
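A minimal, self-contained sketch (assuming pytorch-ignite is installed; the dummy step function and toy dataloader are invented for illustration) of how a rank-0 Events.STARTED handler in the style of the newly added print_dataloader_size reports the dataloader length:

# Hedged, self-contained sketch (not from the repo): a rank-0 STARTED handler that
# logs the dataloader length, mirroring the added print_dataloader_size.
import logging

import ignite.distributed as idist
from ignite.engine import Engine, Events

logging.basicConfig(level=logging.INFO)

def step(engine, batch):
    return batch  # dummy process function, just for illustration

trainer = Engine(step)

@trainer.on(Events.STARTED)
@idist.one_rank_only()  # in a distributed run, only rank 0 logs
def print_dataloader_size(engine):
    engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")

trainer.run(list(range(10)), max_epochs=1)  # toy "dataloader": a list of 10 batches

The other addition in these hunks, the "if m not in engine.state.metrics: continue" guard in print_interval, simply skips metrics that have not been computed yet instead of raising a KeyError.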