import torch
import torch.nn as nn

import ignite.distributed as idist
from ignite.contrib.metrics.gpu_info import GpuInfo
from ignite.contrib.handlers.tensorboard_logger import (
    TensorboardLogger,
    global_step_from_engine,
    OutputHandler,
    WeightsScalarHandler,
    GradsHistHandler,
    WeightsHistHandler,
    GradsScalarHandler,
)
from ignite.engine import create_supervised_trainer, Events
from ignite.metrics import Accuracy, RunningAverage
from ignite.contrib.handlers import ProgressBar

from util.build import build_model, build_optimizer
from util.handler import setup_common_handlers
from data.dataset import LMDBDataset

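# Builds the warm-up (baseline pre-training) trainer. The supervised trainer's
# output_transform exposes (loss, y_pred, y) so that running metrics can be
# attached directly to the training engine itself.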
def warmup_trainer(config, logger):
    model = build_model(config.model, config.distributed.model)
    optimizer = build_optimizer(model.parameters(), config.baseline.optimizers)
    loss_fn = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(
        model, optimizer, loss_fn, idist.device(), non_blocking=True,
        output_transform=lambda x, y, y_pred, loss: (loss.item(), y_pred, y),
    )
    trainer.logger = logger

    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss")
    Accuracy(output_transform=lambda x: (x[1], x[2])).attach(trainer, "acc")
    ProgressBar(ncols=0).attach(trainer)

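    # Attach GPU monitoring and TensorBoard logging on the main process only,
    # so that distributed runs do not write conflicting event files.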
    if idist.get_rank() == 0:
        GpuInfo().attach(trainer, name="gpu")

        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="train",
                metric_names="all",
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.EPOCH_COMPLETED,
        )
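        # Parameter and gradient statistics: cheap scalar norms every 10 epochs,
        # full histograms only every 25 since histogram writes are much heavier.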
        tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=10))
        tb_logger.attach(trainer, log_handler=WeightsHistHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=25))
        tb_logger.attach(trainer, log_handler=GradsScalarHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=10))
        tb_logger.attach(trainer, log_handler=GradsHistHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=25))

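        # Flush and close the event file once training finishes.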
        @trainer.on(Events.COMPLETED)
        def _():
            tb_logger.close()

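    # setup_common_handlers (project-local, util.handler) presumably wires up
    # per-epoch metric printing and periodic checkpointing of the
    # model/optimizer/trainer state (every 25 epochs, keeping the last 5).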
    to_save = dict(model=model, optimizer=optimizer, trainer=trainer)
    setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.EPOCH_COMPLETED,
                          to_save=to_save, save_interval_event=Events.EPOCH_COMPLETED(every=25),
                          n_saved=5, metrics_to_print=["loss", "acc"])
    return trainer


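# Entry point. Dispatches on `task`; only "warmup" is implemented here, the
# protonet variants are still stubs.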
def run(task, config, logger):
    assert torch.backends.cudnn.enabled
    # cuDNN autotuning: cache the fastest kernels per input shape
    # (assumes fixed input sizes across iterations).
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    if task == "warmup":
        train_dataset = LMDBDataset(config.baseline.data.dataset.train.lmdb_path,
                                    pipeline=config.baseline.data.dataset.train.pipeline)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader)
        trainer = warmup_trainer(config, logger)
        try:
            trainer.run(train_data_loader, max_epochs=400)
        except Exception:
            # Print the full traceback instead of letting a worker die silently.
            import traceback
            print(traceback.format_exc())
    elif task == "protonet-wo":
        pass
    elif task == "protonet-w":
        pass
    else:
        raise ValueError(f"invalid task: {task}")