# raycv/engine/crossdomain.py
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
import ignite.distributed as idist
from ignite.contrib.metrics.gpu_info import GpuInfo
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, global_step_from_engine, OutputHandler, \
WeightsScalarHandler, GradsHistHandler, WeightsHistHandler, GradsScalarHandler
from ignite.engine import create_supervised_evaluator, create_supervised_trainer, Events
from ignite.metrics import Accuracy, Loss, RunningAverage
from ignite.contrib.engines.common import save_best_model_by_val_score
from ignite.contrib.handlers import ProgressBar
from util.build import build_model, build_optimizer
from util.handler import setup_common_handlers
from data.transform import transform_pipeline
def baseline_trainer(config, logger, val_loader):
    """Build the supervised baseline trainer.

    Wires up a supervised ignite trainer and evaluator with progress bars,
    per-epoch validation, TensorBoard logging (rank 0 only), periodic
    checkpointing and best-model saving by validation accuracy.

    Args:
        config: experiment config; reads ``model``, ``distributed.model``,
            ``baseline.optimizers`` and ``output_dir``.
        logger: logger used for per-epoch reporting; also installed as the
            trainer's engine logger.
        val_loader: dataloader the evaluator runs after every training epoch.

    Returns:
        The configured ignite trainer engine (not yet run).
    """
    model = build_model(config.model, config.distributed.model)
    optimizer = build_optimizer(model.parameters(), config.baseline.optimizers)
    loss_fn = nn.CrossEntropyLoss()

    trainer = create_supervised_trainer(model, optimizer, loss_fn, idist.device(), non_blocking=True)
    trainer.logger = logger
    # Trainer output is the per-batch loss scalar; expose its running average as "loss".
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    ProgressBar(ncols=0).attach(trainer)

    val_metrics = {
        "accuracy": Accuracy(),
        "nll": Loss(loss_fn),
    }
    evaluator = create_supervised_evaluator(model, val_metrics, idist.device())
    ProgressBar(ncols=0).attach(evaluator)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_loss(engine):
        logger.info(f"Epoch[{engine.state.epoch}] Loss: {engine.state.output:.2f}")
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        # BUG FIX: the original format string had two placeholders but three
        # arguments, so "Avg accuracy" printed the epoch and "Avg loss" printed
        # the accuracy while nll was dropped. Also relabelled "Training Results"
        # -> "Validation Results": the evaluator just ran on val_loader.
        logger.info("Epoch[{}] Validation Results - Avg accuracy: {:.2f} Avg loss: {:.2f}"
                    .format(trainer.state.epoch, metrics["accuracy"], metrics["nll"]))

    # TensorBoard logging and GPU monitoring only on the main process.
    if idist.get_rank() == 0:
        GpuInfo().attach(trainer, name='gpu')
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_logger.attach(
            evaluator,
            log_handler=OutputHandler(
                tag="val",
                metric_names='all',
                # Plot validation metrics against the trainer's global step.
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.EPOCH_COMPLETED
        )
        # Weight/grad norms every 10 epochs; full histograms (heavier) every 25.
        tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=10))
        tb_logger.attach(trainer, log_handler=WeightsHistHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=25))
        tb_logger.attach(trainer, log_handler=GradsScalarHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=10))
        tb_logger.attach(trainer, log_handler=GradsHistHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=25))

        @trainer.on(Events.COMPLETED)
        def _close_tb_logger():
            tb_logger.close()

    # Periodic checkpoints of model/optimizer/trainer state, plus best-by-accuracy.
    to_save = dict(model=model, optimizer=optimizer, trainer=trainer)
    setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.EPOCH_COMPLETED,
                          to_save=to_save,
                          save_interval_event=Events.EPOCH_COMPLETED(every=25), n_saved=5,
                          metrics_to_print=["loss"])
    save_best_model_by_val_score(config.output_dir, evaluator, model, "accuracy", 1, trainer)
    return trainer
def run(task, config, logger):
    """Entry point: dispatch *task* to its training routine.

    Args:
        task: task name; currently only ``"baseline"`` is implemented.
        config: experiment config (dataset paths/pipelines, dataloader kwargs).
        logger: logger for progress and failure reporting.

    Raises:
        NotImplementedError: if *task* is not a recognized task name.
    """
    assert torch.backends.cudnn.enabled
    # Input sizes are fixed, so let cudnn benchmark for the fastest conv algorithms.
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    if task == "baseline":
        train_dataset = ImageFolder(config.baseline.data.dataset.train.path,
                                    transform=transform_pipeline(config.baseline.data.dataset.train.pipeline))
        val_dataset = ImageFolder(config.baseline.data.dataset.val.path,
                                  transform=transform_pipeline(config.baseline.data.dataset.val.pipeline))
        logger.info(f"train with dataset:\n{train_dataset}")
        # NOTE(review): train and val share the same dataloader kwargs (incl.
        # any shuffle flag) — confirm that is intended.
        train_data_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader)
        val_data_loader = idist.auto_dataloader(val_dataset, **config.baseline.data.dataloader)
        trainer = baseline_trainer(config, logger, val_data_loader)
        try:
            trainer.run(train_data_loader, max_epochs=400)
        except Exception:
            # BUG FIX: original printed the traceback to stdout; report the
            # failure (with traceback) through the task logger instead.
            logger.exception("trainer run failed")
    else:
        # BUG FIX: original did `return NotImplemented(...)` — the
        # NotImplemented singleton is not callable, so that line raised a
        # confusing TypeError. Raise the intended exception instead.
        raise NotImplementedError(f"invalid task: {task}")