move the same content to hander.py

This commit is contained in:
Ray Wong 2020-08-22 15:07:36 +08:00
parent 1a1cb9b00f
commit ccc3d7614a
6 changed files with 135 additions and 150 deletions

View File

@ -1,7 +1,7 @@
name: selfie2anime name: selfie2anime-origin
engine: UGATIT engine: UGATIT
result_dir: ./result result_dir: ./result
max_iteration: 100000 max_pairs: 1000000
distributed: distributed:
model: model:
@ -10,8 +10,15 @@ distributed:
misc: misc:
random_seed: 324 random_seed: 324
checkpoints: checkpoint:
interval: 1000 epoch_interval: 1 # one checkpoint every 1 epoch
n_saved: 5
interval:
print_per_iteration: 10 # print once per 10 iteration
tensorboard:
scalar: 10
image: 1000
model: model:
generator: generator:
@ -26,12 +33,12 @@ model:
_type: UGATIT-Discriminator _type: UGATIT-Discriminator
in_channels: 3 in_channels: 3
base_channels: 64 base_channels: 64
num_blocks: 3 num_blocks: 5
global_discriminator: global_discriminator:
_type: UGATIT-Discriminator _type: UGATIT-Discriminator
in_channels: 3 in_channels: 3
base_channels: 64 base_channels: 64
num_blocks: 5 num_blocks: 7
loss: loss:
gan: gan:
@ -62,9 +69,12 @@ optimizers:
data: data:
train: train:
scheduler:
start_proportion: 0.5
target_lr: 0
buffer_size: 50 buffer_size: 50
dataloader: dataloader:
batch_size: 8 batch_size: 4
shuffle: True shuffle: True
num_workers: 2 num_workers: 2
pin_memory: True pin_memory: True
@ -85,9 +95,6 @@ data:
- Normalize: - Normalize:
mean: [0.5, 0.5, 0.5] mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5] std: [0.5, 0.5, 0.5]
scheduler:
start: 50000
target_lr: 0
test: test:
dataloader: dataloader:
batch_size: 4 batch_size: 4

View File

@ -1,20 +1,18 @@
from itertools import chain from itertools import chain
from pathlib import Path from math import ceil
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchvision.utils
import ignite.distributed as idist import ignite.distributed as idist
from ignite.engine import Events, Engine from ignite.engine import Events, Engine
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from ignite.metrics import RunningAverage from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar
from ignite.utils import convert_tensor from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from omegaconf import OmegaConf from omegaconf import OmegaConf, read_write
import data import data
from loss.gan import GANLoss from loss.gan import GANLoss
@ -22,7 +20,7 @@ from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer from model.GAN.residual_generator import GANImageBuffer
from model.GAN.UGATIT import RhoClipper from model.GAN.UGATIT import RhoClipper
from util.image import make_2d_grid from util.image import make_2d_grid
from util.handler import setup_common_handlers from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_model, build_optimizer from util.build import build_model, build_optimizer
@ -49,14 +47,14 @@ def get_trainer(config, logger):
milestones_values = [ milestones_values = [
(0, config.optimizers.generator.lr), (0, config.optimizers.generator.lr),
(config.data.train.scheduler.start, config.optimizers.generator.lr), (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr) (config.max_iteration, config.data.train.scheduler.target_lr)
] ]
lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values) lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
milestones_values = [ milestones_values = [
(0, config.optimizers.discriminator.lr), (0, config.optimizers.discriminator.lr),
(config.data.train.scheduler.start, config.optimizers.discriminator.lr), (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr) (config.max_iteration, config.data.train.scheduler.target_lr)
] ]
lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values) lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
@ -66,18 +64,18 @@ def get_trainer(config, logger):
gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss() cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss() id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
bce_loss = nn.BCEWithLogitsLoss().to(idist.device())
mse_loss = lambda x, t: F.mse_loss(x, x.new_ones(x.size()) if t else x.new_zeros(x.size()))
bce_loss = lambda x, t: F.binary_cross_entropy_with_logits(x, x.new_ones(x.size()) if t else x.new_zeros(x.size()))
image_buffers = { def mse_loss(x, target_flag):
k: GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50) for k in return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
discriminators.keys()}
def bce_loss(x, target_flag):
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()}
rho_clipper = RhoClipper(0, 1) rho_clipper = RhoClipper(0, 1)
def cal_generator_loss(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l, def criterion_generator(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l,
discriminator_g): discriminator_g):
discriminator_g.requires_grad_(False) discriminator_g.requires_grad_(False)
discriminator_l.requires_grad_(False) discriminator_l.requires_grad_(False)
pred_fake_g, cam_gd_pred = discriminator_g(fake) pred_fake_g, cam_gd_pred = discriminator_g(fake)
@ -92,7 +90,7 @@ def get_trainer(config, logger):
f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True), f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True),
} }
def cal_discriminator_loss(name, discriminator, real, fake): def criterion_discriminator(name, discriminator, real, fake):
pred_real, cam_real = discriminator(real) pred_real, cam_real = discriminator(real)
pred_fake, cam_fake = discriminator(fake) pred_fake, cam_fake = discriminator(fake)
# TODO: origin do not divide 2, but I think it better to divide 2. # TODO: origin do not divide 2, but I think it better to divide 2.
@ -100,9 +98,8 @@ def get_trainer(config, logger):
loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False) loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False)
return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam} return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam}
def _step(engine, batch): def _step(engine, real):
batch = convert_tensor(batch, idist.device()) real = convert_tensor(real, idist.device())
real_a, real_b = batch["a"], batch["b"]
fake = dict() fake = dict()
cam_generator_pred = dict() cam_generator_pred = dict()
@ -111,18 +108,18 @@ def get_trainer(config, logger):
cam_identity_pred = dict() cam_identity_pred = dict()
heatmap = dict() heatmap = dict()
fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real_a) fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real["a"])
fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real_b) fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real["b"])
rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"]) rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"])
rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"]) rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"])
identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real_a) identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"])
identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real_b) identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"])
optimizer_g.zero_grad() optimizer_g.zero_grad()
loss_g = dict() loss_g = dict()
for n in ["a", "b"]: for n in ["a", "b"]:
loss_g.update(cal_generator_loss(n, batch[n], fake[n], rec[n], identity[n], cam_generator_pred[n], loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n])) cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
sum(loss_g.values()).backward() sum(loss_g.values()).backward()
optimizer_g.step() optimizer_g.step()
for generator in generators.values(): for generator in generators.values():
@ -135,13 +132,14 @@ def get_trainer(config, logger):
for k in discriminators.keys(): for k in discriminators.keys():
n = k[-1] # "a" or "b" n = k[-1] # "a" or "b"
loss_d.update( loss_d.update(
cal_discriminator_loss(k, discriminators[k], batch[n], image_buffers[k].query(fake[n].detach()))) criterion_discriminator(k, discriminators[k], real[n], image_buffers[k].query(fake[n].detach())))
sum(loss_d.values()).backward() sum(loss_d.values()).backward()
optimizer_d.step() optimizer_d.step()
for h in heatmap: for h in heatmap:
heatmap[h] = heatmap[h].detach() heatmap[h] = heatmap[h].detach()
generated_img = {f"fake_{k}": fake[k].detach() for k in fake} generated_img = {f"real_{k}": real[k].detach() for k in real}
generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
generated_img.update({f"id_{k}": identity[k].detach() for k in identity}) generated_img.update({f"id_{k}": identity[k].detach() for k in identity})
generated_img.update({f"rec_{k}": rec[k].detach() for k in rec}) generated_img.update({f"rec_{k}": rec[k].detach() for k in rec})
@ -169,64 +167,41 @@ def get_trainer(config, logger):
to_save.update({f"generator_{k}": generators[k] for k in generators}) to_save.update({f"generator_{k}": generators[k] for k in generators})
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators}) to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5, setup_common_handlers(trainer, config, to_save=to_save, metrics_to_print=["loss_g", "loss_d"],
filename_prefix=config.name, to_save=to_save, clear_cuda_cache=True, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
metrics_to_print=["loss_g", "loss_d"],
save_interval_event=Events.ITERATION_COMPLETED(
every=config.checkpoints.interval) | Events.COMPLETED)
@trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration)) def output_transform(output):
def terminate(engine): loss = dict()
engine.terminate() for tl in output["loss"]:
if isinstance(output["loss"][tl], dict):
for l in output["loss"][tl]:
loss[f"{tl}_{l}"] = output["loss"][tl][l]
else:
loss[tl] = output["loss"][tl]
return loss
if idist.get_rank() == 0: tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform)
# Create a logger if tensorboard_handler is not None:
tb_logger = TensorboardLogger(log_dir=config.output_dir) tensorboard_handler.attach(
tb_writer = tb_logger.writer
# Attach the logger to the trainer to log training loss at each iteration
def global_step_transform(*args, **kwargs):
return trainer.state.iteration
def output_transform(output):
loss = dict()
for tl in output["loss"]:
if isinstance(output["loss"][tl], dict):
for l in output["loss"][tl]:
loss[f"{tl}_{l}"] = output["loss"][tl][l]
else:
loss[tl] = output["loss"][tl]
return loss
tb_logger.attach(
trainer,
log_handler=OutputHandler(
tag="loss",
metric_names=["loss_g", "loss_d"],
global_step_transform=global_step_transform,
output_transform=output_transform
),
event_name=Events.ITERATION_COMPLETED(every=50)
)
tb_logger.attach(
trainer, trainer,
log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"), log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
event_name=Events.ITERATION_STARTED(every=50) event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
) )
@trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval)) @trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
def show_images(engine): def show_images(engine):
tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]["generated"].values()), image_a_order = ["real_a", "fake_b", "rec_a", "id_a"]
engine.state.iteration) image_b_order = ["real_b", "fake_a", "rec_b", "id_b"]
tb_writer.add_image("train/heatmap", make_2d_grid(engine.state.output["img"]["heatmap"].values()), tensorboard_handler.writer.add_image(
engine.state.iteration) "train/a",
make_2d_grid([engine.state.output["img"]["generated"][o] for o in image_a_order]),
@trainer.on(Events.COMPLETED) engine.state.iteration
@idist.one_rank_only() )
def _(): tensorboard_handler.writer.add_image(
# We need to close the logger with we are done "train/b",
tb_logger.close() make_2d_grid([engine.state.output["img"]["generated"][o] for o in image_b_order]),
engine.state.iteration
)
return trainer return trainer
@ -235,13 +210,16 @@ def run(task, config, logger):
assert torch.backends.cudnn.enabled assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
logger.info(f"start task {task}") logger.info(f"start task {task}")
with read_write(config):
config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
if task == "train": if task == "train":
train_dataset = data.DATASET.build_with(config.data.train.dataset) train_dataset = data.DATASET.build_with(config.data.train.dataset)
logger.info(f"train with dataset:\n{train_dataset}") logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader) train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
trainer = get_trainer(config, logger) trainer = get_trainer(config, logger)
try: try:
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1) trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
except Exception: except Exception:
import traceback import traceback
print(traceback.format_exc()) print(traceback.format_exc())

View File

@ -90,22 +90,6 @@ class Generator(nn.Module):
padding_mode="reflect", bias=False), padding_mode="reflect", bias=False),
nn.Tanh()] nn.Tanh()]
self.up_decoder = nn.Sequential(*up_decoder) self.up_decoder = nn.Sequential(*up_decoder)
# self.up_decoder = nn.ModuleDict({
# "up_1": nn.Upsample(scale_factor=2, mode='nearest'),
# "up_conv_1": nn.Sequential(
# nn.Conv2d(base_channels * 4, base_channels * 4 // 2, kernel_size=3, stride=1,
# padding=1, padding_mode="reflect", bias=False),
# ILN(base_channels * 4 // 2),
# nn.ReLU(True)),
# "up_2": nn.Upsample(scale_factor=2, mode='nearest'),
# "up_conv_2": nn.Sequential(
# nn.Conv2d(base_channels * 2, base_channels * 2 // 2, kernel_size=3, stride=1,
# padding=1, padding_mode="reflect", bias=False),
# ILN(base_channels * 2 // 2),
# nn.ReLU(True)),
# "up_end": nn.Sequential(nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3,
# padding_mode="reflect", bias=False), nn.Tanh())
# })
def forward(self, x): def forward(self, x):
x = self.down_encoder(x) x = self.down_encoder(x)

2
run.sh
View File

@ -16,5 +16,5 @@ PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --
CUDA_VISIBLE_DEVICES=$GPUS \ CUDA_VISIBLE_DEVICES=$GPUS \
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \ PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
main.py "$TASK" "$CONFIG" "$MORE_ARG" --backup_config --setup_output_dir --setup_random_seed main.py "$TASK" "$CONFIG" $MORE_ARG --backup_config --setup_output_dir --setup_random_seed

View File

@ -5,38 +5,33 @@ import torch
import ignite.distributed as idist import ignite.distributed as idist
from ignite.engine import Events, Engine from ignite.engine import Events, Engine
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler
def setup_common_handlers( def empty_cuda_cache(_):
trainer: Engine, torch.cuda.empty_cache()
output_dir=None, import gc
stop_on_nan=True,
use_profiler=True, gc.collect()
print_interval_event=None,
metrics_to_print=None,
to_save=None, def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
resume_from=None, to_save=None, metrics_to_print=None, end_event=None):
save_interval_event=None,
**checkpoint_kwargs
):
""" """
Helper method to setup trainer with common handlers. Helper method to setup trainer with common handlers.
1. TerminateOnNan 1. TerminateOnNan
2. BasicTimeProfiler 2. BasicTimeProfiler
3. Print 3. Print
4. Checkpoint 4. Checkpoint
:param trainer: trainer engine. Output of trainer's `update_function` should be a dictionary :param trainer:
or sequence or a single tensor. :param config:
:param output_dir: output path to indicate where `to_save` objects are stored. Argument is mutually :param stop_on_nan:
:param stop_on_nan: if True, :class:`~ignite.handlers.TerminateOnNan` handler is added to the trainer. :param clear_cuda_cache:
:param use_profiler: :param use_profiler:
:param print_interval_event:
:param metrics_to_print:
:param to_save: :param to_save:
:param resume_from: :param metrics_to_print:
:param save_interval_event: :param end_event:
:param checkpoint_kwargs:
:return: :return:
""" """
@ -48,28 +43,24 @@ def setup_common_handlers(
if stop_on_nan: if stop_on_nan:
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
if torch.cuda.is_available() and clear_cuda_cache:
trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)
if use_profiler: if use_profiler:
# Create an object of the profiler and attach an engine to it # Create an object of the profiler and attach an engine to it
profiler = BasicTimeProfiler() profiler = BasicTimeProfiler()
profiler.attach(trainer) profiler.attach(trainer)
@trainer.on(Events.EPOCH_COMPLETED(once=1)) @trainer.on(Events.EPOCH_COMPLETED(once=1) | Events.COMPLETED)
@idist.one_rank_only() @idist.one_rank_only()
def log_intermediate_results(): def log_intermediate_results():
profiler.print_results(profiler.get_results()) profiler.print_results(profiler.get_results())
@trainer.on(Events.COMPLETED) print_interval_event = Events.ITERATION_COMPLETED(every=config.interval.print_per_iteration) | Events.COMPLETED
@idist.one_rank_only()
def _(): ProgressBar(ncols=0).attach(trainer, "all")
profiler.print_results(profiler.get_results())
# profiler.write_results(f"{output_dir}/time_profiling.csv")
if metrics_to_print is not None: if metrics_to_print is not None:
if print_interval_event is None:
raise ValueError(
"If metrics_to_print argument is provided then print_interval_event arguments should be also defined"
)
@trainer.on(print_interval_event) @trainer.on(print_interval_event)
def print_interval(engine): def print_interval(engine):
print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t" print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
@ -77,19 +68,44 @@ def setup_common_handlers(
if m not in engine.state.metrics: if m not in engine.state.metrics:
continue continue
print_str += f"{m}={engine.state.metrics[m]:.3f} " print_str += f"{m}={engine.state.metrics[m]:.3f} "
engine.logger.info(print_str) engine.logger.debug(print_str)
if to_save is not None: if to_save is not None:
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir, require_empty=False), checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False),
**checkpoint_kwargs) n_saved=config.checkpoint.n_saved, filename_prefix=config.name)
if resume_from is not None: if config.resume_from is not None:
@trainer.on(Events.STARTED) @trainer.on(Events.STARTED)
def resume(engine): def resume(engine):
checkpoint_path = Path(resume_from) checkpoint_path = Path(config.resume_from)
if not checkpoint_path.exists(): if not checkpoint_path.exists():
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found") raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu") ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp) Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
engine.logger.info(f"resume from a checkpoint {checkpoint_path}") engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
if save_interval_event is not None: trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
trainer.add_event_handler(save_interval_event, checkpoint_handler) checkpoint_handler)
if end_event is not None:
@trainer.on(end_event)
def terminate(engine):
engine.terminate()
def setup_tensorboard_handler(trainer: Engine, config, output_transform):
if config.interval.tensorboard is None:
return None
if idist.get_rank() == 0:
# Create a logger
tb_logger = TensorboardLogger(log_dir=config.output_dir)
tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"),
event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform),
event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
@trainer.on(Events.COMPLETED)
@idist.one_rank_only()
def _():
# We need to close the logger with we are done
tb_logger.close()
return tb_logger
return None

View File

@ -69,7 +69,7 @@ def setup_logger(
if distributed_rank > 0: if distributed_rank > 0:
logger.addHandler(logging.NullHandler()) logger.addHandler(logging.NullHandler())
else: else:
logger.setLevel(level) logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler() ch = logging.StreamHandler()
ch.setLevel(level) ch.setLevel(level)
@ -78,7 +78,7 @@ def setup_logger(
if filepath is not None: if filepath is not None:
fh = logging.FileHandler(filepath) fh = logging.FileHandler(filepath)
fh.setLevel(file_level) fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter) fh.setFormatter(formatter)
logger.addHandler(fh) logger.addHandler(fh)