improve run.sh
parent a5133e6795
commit 649f2244f7
main.py (5 changed lines)
@@ -30,7 +30,10 @@ def running(local_rank, config, task, backup_config=False, setup_output_dir=False
     output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
     config.output_dir = str(output_dir)
     if output_dir.exists():
-        assert not any(output_dir.iterdir()), "output_dir must be empty"
+        # assert not any(output_dir.iterdir()), "output_dir must be empty"
+        contains = list(output_dir.iterdir())
+        assert (len(contains) == 0) or (len(contains) == 1 and contains[0].name == "config.yml"), \
+            f"output_dir must be empty or contain only config.yml, but got {len(contains)} files"
     else:
         if idist.get_rank() == 0:
             output_dir.mkdir(parents=True)
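Note on the relaxed check: the old assertion required an empty output_dir, while the new one also accepts a directory that holds nothing but a backed-up config.yml. A minimal, self-contained sketch of that behavior (check_output_dir is a hypothetical name used only for illustration):

    from pathlib import Path
    import tempfile

    def check_output_dir(output_dir: Path) -> None:
        # Accept an existing directory only if it is empty or holds just config.yml.
        if output_dir.exists():
            contains = list(output_dir.iterdir())
            assert len(contains) == 0 or (len(contains) == 1 and contains[0].name == "config.yml"), \
                f"output_dir must be empty or contain only config.yml, but got {len(contains)} files"

    tmp = Path(tempfile.mkdtemp())
    check_output_dir(tmp)            # empty directory: passes
    (tmp / "config.yml").touch()
    check_output_dir(tmp)            # only config.yml: still passes
    (tmp / "events.log").touch()
    # check_output_dir(tmp)          # two entries: would raise AssertionError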
run.sh (14 changed lines)
@@ -1,8 +1,14 @@
 #!/usr/bin/env bash
 
 CONFIG=$1
-GPUS=$2
+TASK=$2
+GPUS=$3
 
-# CUDA_VISIBLE_DEVICES=$GPUS \
-PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPUS" \
-    main.py train "$CONFIG" --backup_config --setup_output_dir --setup_random_seed
+_command="print(len('${GPUS}'.split(',')))"
+GPU_COUNT=$(python3 -c "${_command}")
+
+echo "GPU_COUNT:${GPU_COUNT}"
+
+CUDA_VISIBLE_DEVICES=$GPUS \
+PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
+    main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed
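The process count is now derived from the device list instead of being passed as a raw number: the `python3 -c` one-liner counts the comma-separated entries in $GPUS, which then drives both CUDA_VISIBLE_DEVICES and --nproc_per_node. A sketch of the same computation in plain Python (values are illustrative):

    def gpu_count(gpus: str) -> int:
        # "0,1,3" -> 3 processes; mirrors print(len('${GPUS}'.split(',')))
        return len(gpus.split(","))

    assert gpu_count("0") == 1
    assert gpu_count("0,1,3") == 3
    # Caveat: an empty string still yields 1, since "".split(",") == [""].

So an invocation like `bash run.sh <config> train 0,1` (illustrative) pins the run to GPUs 0 and 1 and launches two processes, one per listed device.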
@@ -3,13 +3,13 @@ from pathlib import Path
 import torch
 
 import ignite.distributed as idist
-from ignite.engine import Events
+from ignite.engine import Events, Engine
 from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
 from ignite.contrib.handlers import BasicTimeProfiler
 
 
 def setup_common_handlers(
-    trainer,
+    trainer: Engine,
     output_dir=None,
     stop_on_nan=True,
     use_profiler=True,
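The annotation only documents the expected type of the first argument. A sketch of constructing an Engine that could be passed in (the no-op update function is illustrative, not from the repo):

    from ignite.engine import Engine

    def update(engine, batch):
        return batch  # placeholder update step

    trainer = Engine(update)
    # setup_common_handlers(trainer, output_dir=None, stop_on_nan=True, use_profiler=True)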
@@ -39,6 +39,11 @@ def setup_common_handlers(
     :param checkpoint_kwargs:
     :return:
     """
+    @trainer.on(Events.STARTED)
+    @idist.one_rank_only()
+    def print_dataloader_size(engine):
+        engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")
+
     if stop_on_nan:
         trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
 
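How the added handler fires: @trainer.on registers the function for Events.STARTED, and @idist.one_rank_only() turns it into a no-op on every rank except rank 0, so the dataloader length is logged once per distributed run rather than once per process. A self-contained sketch of the same pattern (the announce handler is illustrative):

    import logging
    import ignite.distributed as idist
    from ignite.engine import Engine, Events

    logging.basicConfig(level=logging.INFO)

    trainer = Engine(lambda engine, batch: batch)

    @trainer.on(Events.STARTED)
    @idist.one_rank_only()
    def announce(engine):
        # Runs once, on rank 0 only, when trainer.run() starts.
        engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")

    trainer.run([1, 2, 3], max_epochs=1)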
@@ -68,6 +73,8 @@ def setup_common_handlers(
     def print_interval(engine):
         print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
         for m in metrics_to_print:
+            if m not in engine.state.metrics:
+                continue
             print_str += f"{m}={engine.state.metrics[m]:.3f} "
         engine.logger.info(print_str)
 
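The guard makes the printout tolerant of metrics that have not been computed yet at this point in the run; without it, the lookup in the f-string would raise KeyError. A sketch with plain dicts (values are illustrative, metrics_to_print is assumed from the surrounding function):

    metrics = {"loss": 0.251}           # stand-in for engine.state.metrics
    metrics_to_print = ["loss", "acc"]  # "acc" is not available yet

    print_str = "epoch:1 iter:10\t"
    for m in metrics_to_print:
        if m not in metrics:
            continue  # skip metrics that are not computed yet
        print_str += f"{m}={metrics[m]:.3f} "
    print(print_str)  # epoch:1 iter:10    loss=0.251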