import traceback

import torch
import torch.nn as nn

import ignite.distributed as idist
from ignite.contrib.handlers import ProgressBar
from ignite.contrib.handlers.tensorboard_logger import (
    TensorboardLogger,
    global_step_from_engine,
    OutputHandler,
    WeightsScalarHandler,
    WeightsHistHandler,
    GradsScalarHandler,
    GradsHistHandler,
)
from ignite.contrib.metrics.gpu_info import GpuInfo
from ignite.engine import create_supervised_trainer, Events
from ignite.metrics import Accuracy, RunningAverage

from util.build import build_model, build_optimizer
from util.handler import setup_common_handlers
from data.dataset import LMDBDataset


def warmup_trainer(config, logger):
    """Build the supervised warmup trainer with running metrics, a progress
    bar, TensorBoard logging (rank 0 only) and common checkpoint handlers."""
    model = build_model(config.model, config.distributed.model)
    optimizer = build_optimizer(model.parameters(), config.baseline.optimizers)
    loss_fn = nn.CrossEntropyLoss()

    # output_transform exposes (loss, y_pred, y) so both the running loss
    # and the accuracy metric can be computed from the trainer's output.
    trainer = create_supervised_trainer(
        model,
        optimizer,
        loss_fn,
        device=idist.device(),
        non_blocking=True,
        output_transform=lambda x, y, y_pred, loss: (loss.item(), y_pred, y),
    )
    trainer.logger = logger

    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss")
    Accuracy(output_transform=lambda x: (x[1], x[2])).attach(trainer, "acc")
    ProgressBar(ncols=0).attach(trainer)

    # GPU stats and TensorBoard logging only on the main process.
    if idist.get_rank() == 0:
        GpuInfo().attach(trainer, name="gpu")

        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="train",
                metric_names="all",
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.EPOCH_COMPLETED,
        )
        # Scalar weight/grad norms every 10 epochs; full histograms only
        # every 25 epochs since they are more expensive to log.
        tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=10))
        tb_logger.attach(trainer, log_handler=WeightsHistHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=25))
        tb_logger.attach(trainer, log_handler=GradsScalarHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=10))
        tb_logger.attach(trainer, log_handler=GradsHistHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=25))

        @trainer.on(Events.COMPLETED)
        def _close_tb_logger():
            tb_logger.close()

    # Periodic console output plus checkpointing of model/optimizer/trainer state.
    to_save = dict(model=model, optimizer=optimizer, trainer=trainer)
    setup_common_handlers(
        trainer,
        config.output_dir,
        print_interval_event=Events.EPOCH_COMPLETED,
        to_save=to_save,
        save_interval_event=Events.EPOCH_COMPLETED(every=25),
        n_saved=5,
        metrics_to_print=["loss", "acc"],
    )
    return trainer


def run(task, config, logger):
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")

    if task == "warmup":
        train_dataset = LMDBDataset(
            config.baseline.data.dataset.train.lmdb_path,
            pipeline=config.baseline.data.dataset.train.pipeline,
        )
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader)
        trainer = warmup_trainer(config, logger)
        try:
            trainer.run(train_data_loader, max_epochs=400)
        except Exception:
            logger.error(traceback.format_exc())
    elif task == "protonet-wo":
        pass  # not implemented yet
    elif task == "protonet-w":
        pass  # not implemented yet
    else:
        raise ValueError(f"invalid task: {task}")
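

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal entry point showing how `run` might be invoked. The CLI flags,
# the config path, and the `load_config` helper below are assumptions for
# illustration only; the real project may wire these up differently (e.g.
# launching via ignite.distributed.Parallel for multi-GPU runs).
if __name__ == "__main__":
    import argparse
    import logging

    parser = argparse.ArgumentParser()
    parser.add_argument("task", choices=["warmup", "protonet-wo", "protonet-w"])
    parser.add_argument("--config", default="configs/baseline.yaml")  # hypothetical path
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    main_logger = logging.getLogger(__name__)

    # `load_config` is a hypothetical helper; the project presumably loads an
    # attribute-style config (config.model, config.baseline, ...) from a file.
    from util.config import load_config  # hypothetical module

    main_config = load_config(args.config)
    run(args.task, main_config, main_logger)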