move the same content to hander.py
This commit is contained in:
parent
1a1cb9b00f
commit
ccc3d7614a
@ -1,7 +1,7 @@
|
|||||||
name: selfie2anime
|
name: selfie2anime-origin
|
||||||
engine: UGATIT
|
engine: UGATIT
|
||||||
result_dir: ./result
|
result_dir: ./result
|
||||||
max_iteration: 100000
|
max_pairs: 1000000
|
||||||
|
|
||||||
distributed:
|
distributed:
|
||||||
model:
|
model:
|
||||||
@ -10,8 +10,15 @@ distributed:
|
|||||||
misc:
|
misc:
|
||||||
random_seed: 324
|
random_seed: 324
|
||||||
|
|
||||||
checkpoints:
|
checkpoint:
|
||||||
interval: 1000
|
epoch_interval: 1 # one checkpoint every 1 epoch
|
||||||
|
n_saved: 5
|
||||||
|
|
||||||
|
interval:
|
||||||
|
print_per_iteration: 10 # print once per 10 iteration
|
||||||
|
tensorboard:
|
||||||
|
scalar: 10
|
||||||
|
image: 1000
|
||||||
|
|
||||||
model:
|
model:
|
||||||
generator:
|
generator:
|
||||||
@ -26,12 +33,12 @@ model:
|
|||||||
_type: UGATIT-Discriminator
|
_type: UGATIT-Discriminator
|
||||||
in_channels: 3
|
in_channels: 3
|
||||||
base_channels: 64
|
base_channels: 64
|
||||||
num_blocks: 3
|
num_blocks: 5
|
||||||
global_discriminator:
|
global_discriminator:
|
||||||
_type: UGATIT-Discriminator
|
_type: UGATIT-Discriminator
|
||||||
in_channels: 3
|
in_channels: 3
|
||||||
base_channels: 64
|
base_channels: 64
|
||||||
num_blocks: 5
|
num_blocks: 7
|
||||||
|
|
||||||
loss:
|
loss:
|
||||||
gan:
|
gan:
|
||||||
@ -62,9 +69,12 @@ optimizers:
|
|||||||
|
|
||||||
data:
|
data:
|
||||||
train:
|
train:
|
||||||
|
scheduler:
|
||||||
|
start_proportion: 0.5
|
||||||
|
target_lr: 0
|
||||||
buffer_size: 50
|
buffer_size: 50
|
||||||
dataloader:
|
dataloader:
|
||||||
batch_size: 8
|
batch_size: 4
|
||||||
shuffle: True
|
shuffle: True
|
||||||
num_workers: 2
|
num_workers: 2
|
||||||
pin_memory: True
|
pin_memory: True
|
||||||
@ -85,9 +95,6 @@ data:
|
|||||||
- Normalize:
|
- Normalize:
|
||||||
mean: [0.5, 0.5, 0.5]
|
mean: [0.5, 0.5, 0.5]
|
||||||
std: [0.5, 0.5, 0.5]
|
std: [0.5, 0.5, 0.5]
|
||||||
scheduler:
|
|
||||||
start: 50000
|
|
||||||
target_lr: 0
|
|
||||||
test:
|
test:
|
||||||
dataloader:
|
dataloader:
|
||||||
batch_size: 4
|
batch_size: 4
|
||||||
|
|||||||
140
engine/UGATIT.py
140
engine/UGATIT.py
@ -1,20 +1,18 @@
|
|||||||
from itertools import chain
|
from itertools import chain
|
||||||
from pathlib import Path
|
from math import ceil
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchvision.utils
|
|
||||||
|
|
||||||
import ignite.distributed as idist
|
import ignite.distributed as idist
|
||||||
from ignite.engine import Events, Engine
|
from ignite.engine import Events, Engine
|
||||||
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
|
|
||||||
from ignite.metrics import RunningAverage
|
from ignite.metrics import RunningAverage
|
||||||
from ignite.contrib.handlers import ProgressBar
|
|
||||||
from ignite.utils import convert_tensor
|
from ignite.utils import convert_tensor
|
||||||
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
|
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
|
||||||
|
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
|
||||||
|
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf, read_write
|
||||||
|
|
||||||
import data
|
import data
|
||||||
from loss.gan import GANLoss
|
from loss.gan import GANLoss
|
||||||
@ -22,7 +20,7 @@ from model.weight_init import generation_init_weights
|
|||||||
from model.GAN.residual_generator import GANImageBuffer
|
from model.GAN.residual_generator import GANImageBuffer
|
||||||
from model.GAN.UGATIT import RhoClipper
|
from model.GAN.UGATIT import RhoClipper
|
||||||
from util.image import make_2d_grid
|
from util.image import make_2d_grid
|
||||||
from util.handler import setup_common_handlers
|
from util.handler import setup_common_handlers, setup_tensorboard_handler
|
||||||
from util.build import build_model, build_optimizer
|
from util.build import build_model, build_optimizer
|
||||||
|
|
||||||
|
|
||||||
@ -49,14 +47,14 @@ def get_trainer(config, logger):
|
|||||||
|
|
||||||
milestones_values = [
|
milestones_values = [
|
||||||
(0, config.optimizers.generator.lr),
|
(0, config.optimizers.generator.lr),
|
||||||
(config.data.train.scheduler.start, config.optimizers.generator.lr),
|
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
|
||||||
(config.max_iteration, config.data.train.scheduler.target_lr)
|
(config.max_iteration, config.data.train.scheduler.target_lr)
|
||||||
]
|
]
|
||||||
lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
|
lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
|
||||||
|
|
||||||
milestones_values = [
|
milestones_values = [
|
||||||
(0, config.optimizers.discriminator.lr),
|
(0, config.optimizers.discriminator.lr),
|
||||||
(config.data.train.scheduler.start, config.optimizers.discriminator.lr),
|
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
|
||||||
(config.max_iteration, config.data.train.scheduler.target_lr)
|
(config.max_iteration, config.data.train.scheduler.target_lr)
|
||||||
]
|
]
|
||||||
lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
|
lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
|
||||||
@ -66,18 +64,18 @@ def get_trainer(config, logger):
|
|||||||
gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||||
cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
|
cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
|
||||||
id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
|
id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
|
||||||
bce_loss = nn.BCEWithLogitsLoss().to(idist.device())
|
|
||||||
mse_loss = lambda x, t: F.mse_loss(x, x.new_ones(x.size()) if t else x.new_zeros(x.size()))
|
|
||||||
bce_loss = lambda x, t: F.binary_cross_entropy_with_logits(x, x.new_ones(x.size()) if t else x.new_zeros(x.size()))
|
|
||||||
|
|
||||||
image_buffers = {
|
def mse_loss(x, target_flag):
|
||||||
k: GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50) for k in
|
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
|
||||||
discriminators.keys()}
|
|
||||||
|
|
||||||
|
def bce_loss(x, target_flag):
|
||||||
|
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
|
||||||
|
|
||||||
|
image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()}
|
||||||
rho_clipper = RhoClipper(0, 1)
|
rho_clipper = RhoClipper(0, 1)
|
||||||
|
|
||||||
def cal_generator_loss(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l,
|
def criterion_generator(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l,
|
||||||
discriminator_g):
|
discriminator_g):
|
||||||
discriminator_g.requires_grad_(False)
|
discriminator_g.requires_grad_(False)
|
||||||
discriminator_l.requires_grad_(False)
|
discriminator_l.requires_grad_(False)
|
||||||
pred_fake_g, cam_gd_pred = discriminator_g(fake)
|
pred_fake_g, cam_gd_pred = discriminator_g(fake)
|
||||||
@ -92,7 +90,7 @@ def get_trainer(config, logger):
|
|||||||
f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True),
|
f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True),
|
||||||
}
|
}
|
||||||
|
|
||||||
def cal_discriminator_loss(name, discriminator, real, fake):
|
def criterion_discriminator(name, discriminator, real, fake):
|
||||||
pred_real, cam_real = discriminator(real)
|
pred_real, cam_real = discriminator(real)
|
||||||
pred_fake, cam_fake = discriminator(fake)
|
pred_fake, cam_fake = discriminator(fake)
|
||||||
# TODO: origin do not divide 2, but I think it better to divide 2.
|
# TODO: origin do not divide 2, but I think it better to divide 2.
|
||||||
@ -100,9 +98,8 @@ def get_trainer(config, logger):
|
|||||||
loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False)
|
loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False)
|
||||||
return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam}
|
return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam}
|
||||||
|
|
||||||
def _step(engine, batch):
|
def _step(engine, real):
|
||||||
batch = convert_tensor(batch, idist.device())
|
real = convert_tensor(real, idist.device())
|
||||||
real_a, real_b = batch["a"], batch["b"]
|
|
||||||
|
|
||||||
fake = dict()
|
fake = dict()
|
||||||
cam_generator_pred = dict()
|
cam_generator_pred = dict()
|
||||||
@ -111,18 +108,18 @@ def get_trainer(config, logger):
|
|||||||
cam_identity_pred = dict()
|
cam_identity_pred = dict()
|
||||||
heatmap = dict()
|
heatmap = dict()
|
||||||
|
|
||||||
fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real_a)
|
fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real["a"])
|
||||||
fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real_b)
|
fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real["b"])
|
||||||
rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"])
|
rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"])
|
||||||
rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"])
|
rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"])
|
||||||
identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real_a)
|
identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"])
|
||||||
identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real_b)
|
identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"])
|
||||||
|
|
||||||
optimizer_g.zero_grad()
|
optimizer_g.zero_grad()
|
||||||
loss_g = dict()
|
loss_g = dict()
|
||||||
for n in ["a", "b"]:
|
for n in ["a", "b"]:
|
||||||
loss_g.update(cal_generator_loss(n, batch[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
|
loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
|
||||||
cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
|
cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
|
||||||
sum(loss_g.values()).backward()
|
sum(loss_g.values()).backward()
|
||||||
optimizer_g.step()
|
optimizer_g.step()
|
||||||
for generator in generators.values():
|
for generator in generators.values():
|
||||||
@ -135,13 +132,14 @@ def get_trainer(config, logger):
|
|||||||
for k in discriminators.keys():
|
for k in discriminators.keys():
|
||||||
n = k[-1] # "a" or "b"
|
n = k[-1] # "a" or "b"
|
||||||
loss_d.update(
|
loss_d.update(
|
||||||
cal_discriminator_loss(k, discriminators[k], batch[n], image_buffers[k].query(fake[n].detach())))
|
criterion_discriminator(k, discriminators[k], real[n], image_buffers[k].query(fake[n].detach())))
|
||||||
sum(loss_d.values()).backward()
|
sum(loss_d.values()).backward()
|
||||||
optimizer_d.step()
|
optimizer_d.step()
|
||||||
|
|
||||||
for h in heatmap:
|
for h in heatmap:
|
||||||
heatmap[h] = heatmap[h].detach()
|
heatmap[h] = heatmap[h].detach()
|
||||||
generated_img = {f"fake_{k}": fake[k].detach() for k in fake}
|
generated_img = {f"real_{k}": real[k].detach() for k in real}
|
||||||
|
generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
|
||||||
generated_img.update({f"id_{k}": identity[k].detach() for k in identity})
|
generated_img.update({f"id_{k}": identity[k].detach() for k in identity})
|
||||||
generated_img.update({f"rec_{k}": rec[k].detach() for k in rec})
|
generated_img.update({f"rec_{k}": rec[k].detach() for k in rec})
|
||||||
|
|
||||||
@ -169,64 +167,41 @@ def get_trainer(config, logger):
|
|||||||
to_save.update({f"generator_{k}": generators[k] for k in generators})
|
to_save.update({f"generator_{k}": generators[k] for k in generators})
|
||||||
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
|
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
|
||||||
|
|
||||||
setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
|
setup_common_handlers(trainer, config, to_save=to_save, metrics_to_print=["loss_g", "loss_d"],
|
||||||
filename_prefix=config.name, to_save=to_save,
|
clear_cuda_cache=True, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
|
||||||
print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
|
|
||||||
metrics_to_print=["loss_g", "loss_d"],
|
|
||||||
save_interval_event=Events.ITERATION_COMPLETED(
|
|
||||||
every=config.checkpoints.interval) | Events.COMPLETED)
|
|
||||||
|
|
||||||
@trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
|
def output_transform(output):
|
||||||
def terminate(engine):
|
loss = dict()
|
||||||
engine.terminate()
|
for tl in output["loss"]:
|
||||||
|
if isinstance(output["loss"][tl], dict):
|
||||||
|
for l in output["loss"][tl]:
|
||||||
|
loss[f"{tl}_{l}"] = output["loss"][tl][l]
|
||||||
|
else:
|
||||||
|
loss[tl] = output["loss"][tl]
|
||||||
|
return loss
|
||||||
|
|
||||||
if idist.get_rank() == 0:
|
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform)
|
||||||
# Create a logger
|
if tensorboard_handler is not None:
|
||||||
tb_logger = TensorboardLogger(log_dir=config.output_dir)
|
tensorboard_handler.attach(
|
||||||
tb_writer = tb_logger.writer
|
|
||||||
|
|
||||||
# Attach the logger to the trainer to log training loss at each iteration
|
|
||||||
def global_step_transform(*args, **kwargs):
|
|
||||||
return trainer.state.iteration
|
|
||||||
|
|
||||||
def output_transform(output):
|
|
||||||
loss = dict()
|
|
||||||
for tl in output["loss"]:
|
|
||||||
if isinstance(output["loss"][tl], dict):
|
|
||||||
for l in output["loss"][tl]:
|
|
||||||
loss[f"{tl}_{l}"] = output["loss"][tl][l]
|
|
||||||
else:
|
|
||||||
loss[tl] = output["loss"][tl]
|
|
||||||
return loss
|
|
||||||
|
|
||||||
tb_logger.attach(
|
|
||||||
trainer,
|
|
||||||
log_handler=OutputHandler(
|
|
||||||
tag="loss",
|
|
||||||
metric_names=["loss_g", "loss_d"],
|
|
||||||
global_step_transform=global_step_transform,
|
|
||||||
output_transform=output_transform
|
|
||||||
),
|
|
||||||
event_name=Events.ITERATION_COMPLETED(every=50)
|
|
||||||
)
|
|
||||||
tb_logger.attach(
|
|
||||||
trainer,
|
trainer,
|
||||||
log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
|
log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
|
||||||
event_name=Events.ITERATION_STARTED(every=50)
|
event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
|
||||||
)
|
)
|
||||||
|
|
||||||
@trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
|
@trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
|
||||||
def show_images(engine):
|
def show_images(engine):
|
||||||
tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]["generated"].values()),
|
image_a_order = ["real_a", "fake_b", "rec_a", "id_a"]
|
||||||
engine.state.iteration)
|
image_b_order = ["real_b", "fake_a", "rec_b", "id_b"]
|
||||||
tb_writer.add_image("train/heatmap", make_2d_grid(engine.state.output["img"]["heatmap"].values()),
|
tensorboard_handler.writer.add_image(
|
||||||
engine.state.iteration)
|
"train/a",
|
||||||
|
make_2d_grid([engine.state.output["img"]["generated"][o] for o in image_a_order]),
|
||||||
@trainer.on(Events.COMPLETED)
|
engine.state.iteration
|
||||||
@idist.one_rank_only()
|
)
|
||||||
def _():
|
tensorboard_handler.writer.add_image(
|
||||||
# We need to close the logger with we are done
|
"train/b",
|
||||||
tb_logger.close()
|
make_2d_grid([engine.state.output["img"]["generated"][o] for o in image_b_order]),
|
||||||
|
engine.state.iteration
|
||||||
|
)
|
||||||
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
@ -235,13 +210,16 @@ def run(task, config, logger):
|
|||||||
assert torch.backends.cudnn.enabled
|
assert torch.backends.cudnn.enabled
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
logger.info(f"start task {task}")
|
logger.info(f"start task {task}")
|
||||||
|
with read_write(config):
|
||||||
|
config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
|
||||||
|
|
||||||
if task == "train":
|
if task == "train":
|
||||||
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
||||||
logger.info(f"train with dataset:\n{train_dataset}")
|
logger.info(f"train with dataset:\n{train_dataset}")
|
||||||
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
||||||
trainer = get_trainer(config, logger)
|
trainer = get_trainer(config, logger)
|
||||||
try:
|
try:
|
||||||
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
|
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
|
||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
import traceback
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|||||||
@ -90,22 +90,6 @@ class Generator(nn.Module):
|
|||||||
padding_mode="reflect", bias=False),
|
padding_mode="reflect", bias=False),
|
||||||
nn.Tanh()]
|
nn.Tanh()]
|
||||||
self.up_decoder = nn.Sequential(*up_decoder)
|
self.up_decoder = nn.Sequential(*up_decoder)
|
||||||
# self.up_decoder = nn.ModuleDict({
|
|
||||||
# "up_1": nn.Upsample(scale_factor=2, mode='nearest'),
|
|
||||||
# "up_conv_1": nn.Sequential(
|
|
||||||
# nn.Conv2d(base_channels * 4, base_channels * 4 // 2, kernel_size=3, stride=1,
|
|
||||||
# padding=1, padding_mode="reflect", bias=False),
|
|
||||||
# ILN(base_channels * 4 // 2),
|
|
||||||
# nn.ReLU(True)),
|
|
||||||
# "up_2": nn.Upsample(scale_factor=2, mode='nearest'),
|
|
||||||
# "up_conv_2": nn.Sequential(
|
|
||||||
# nn.Conv2d(base_channels * 2, base_channels * 2 // 2, kernel_size=3, stride=1,
|
|
||||||
# padding=1, padding_mode="reflect", bias=False),
|
|
||||||
# ILN(base_channels * 2 // 2),
|
|
||||||
# nn.ReLU(True)),
|
|
||||||
# "up_end": nn.Sequential(nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3,
|
|
||||||
# padding_mode="reflect", bias=False), nn.Tanh())
|
|
||||||
# })
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.down_encoder(x)
|
x = self.down_encoder(x)
|
||||||
|
|||||||
2
run.sh
2
run.sh
@ -16,5 +16,5 @@ PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --
|
|||||||
|
|
||||||
CUDA_VISIBLE_DEVICES=$GPUS \
|
CUDA_VISIBLE_DEVICES=$GPUS \
|
||||||
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
|
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
|
||||||
main.py "$TASK" "$CONFIG" "$MORE_ARG" --backup_config --setup_output_dir --setup_random_seed
|
main.py "$TASK" "$CONFIG" $MORE_ARG --backup_config --setup_output_dir --setup_random_seed
|
||||||
|
|
||||||
|
|||||||
@ -5,38 +5,33 @@ import torch
|
|||||||
import ignite.distributed as idist
|
import ignite.distributed as idist
|
||||||
from ignite.engine import Events, Engine
|
from ignite.engine import Events, Engine
|
||||||
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
|
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
|
||||||
from ignite.contrib.handlers import BasicTimeProfiler
|
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
|
||||||
|
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler
|
||||||
|
|
||||||
|
|
||||||
def setup_common_handlers(
|
def empty_cuda_cache(_):
|
||||||
trainer: Engine,
|
torch.cuda.empty_cache()
|
||||||
output_dir=None,
|
import gc
|
||||||
stop_on_nan=True,
|
|
||||||
use_profiler=True,
|
gc.collect()
|
||||||
print_interval_event=None,
|
|
||||||
metrics_to_print=None,
|
|
||||||
to_save=None,
|
def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
|
||||||
resume_from=None,
|
to_save=None, metrics_to_print=None, end_event=None):
|
||||||
save_interval_event=None,
|
|
||||||
**checkpoint_kwargs
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Helper method to setup trainer with common handlers.
|
Helper method to setup trainer with common handlers.
|
||||||
1. TerminateOnNan
|
1. TerminateOnNan
|
||||||
2. BasicTimeProfiler
|
2. BasicTimeProfiler
|
||||||
3. Print
|
3. Print
|
||||||
4. Checkpoint
|
4. Checkpoint
|
||||||
:param trainer: trainer engine. Output of trainer's `update_function` should be a dictionary
|
:param trainer:
|
||||||
or sequence or a single tensor.
|
:param config:
|
||||||
:param output_dir: output path to indicate where `to_save` objects are stored. Argument is mutually
|
:param stop_on_nan:
|
||||||
:param stop_on_nan: if True, :class:`~ignite.handlers.TerminateOnNan` handler is added to the trainer.
|
:param clear_cuda_cache:
|
||||||
:param use_profiler:
|
:param use_profiler:
|
||||||
:param print_interval_event:
|
|
||||||
:param metrics_to_print:
|
|
||||||
:param to_save:
|
:param to_save:
|
||||||
:param resume_from:
|
:param metrics_to_print:
|
||||||
:param save_interval_event:
|
:param end_event:
|
||||||
:param checkpoint_kwargs:
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -48,28 +43,24 @@ def setup_common_handlers(
|
|||||||
if stop_on_nan:
|
if stop_on_nan:
|
||||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
|
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
|
||||||
|
|
||||||
|
if torch.cuda.is_available() and clear_cuda_cache:
|
||||||
|
trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)
|
||||||
|
|
||||||
if use_profiler:
|
if use_profiler:
|
||||||
# Create an object of the profiler and attach an engine to it
|
# Create an object of the profiler and attach an engine to it
|
||||||
profiler = BasicTimeProfiler()
|
profiler = BasicTimeProfiler()
|
||||||
profiler.attach(trainer)
|
profiler.attach(trainer)
|
||||||
|
|
||||||
@trainer.on(Events.EPOCH_COMPLETED(once=1))
|
@trainer.on(Events.EPOCH_COMPLETED(once=1) | Events.COMPLETED)
|
||||||
@idist.one_rank_only()
|
@idist.one_rank_only()
|
||||||
def log_intermediate_results():
|
def log_intermediate_results():
|
||||||
profiler.print_results(profiler.get_results())
|
profiler.print_results(profiler.get_results())
|
||||||
|
|
||||||
@trainer.on(Events.COMPLETED)
|
print_interval_event = Events.ITERATION_COMPLETED(every=config.interval.print_per_iteration) | Events.COMPLETED
|
||||||
@idist.one_rank_only()
|
|
||||||
def _():
|
ProgressBar(ncols=0).attach(trainer, "all")
|
||||||
profiler.print_results(profiler.get_results())
|
|
||||||
# profiler.write_results(f"{output_dir}/time_profiling.csv")
|
|
||||||
|
|
||||||
if metrics_to_print is not None:
|
if metrics_to_print is not None:
|
||||||
if print_interval_event is None:
|
|
||||||
raise ValueError(
|
|
||||||
"If metrics_to_print argument is provided then print_interval_event arguments should be also defined"
|
|
||||||
)
|
|
||||||
|
|
||||||
@trainer.on(print_interval_event)
|
@trainer.on(print_interval_event)
|
||||||
def print_interval(engine):
|
def print_interval(engine):
|
||||||
print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
|
print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
|
||||||
@ -77,19 +68,44 @@ def setup_common_handlers(
|
|||||||
if m not in engine.state.metrics:
|
if m not in engine.state.metrics:
|
||||||
continue
|
continue
|
||||||
print_str += f"{m}={engine.state.metrics[m]:.3f} "
|
print_str += f"{m}={engine.state.metrics[m]:.3f} "
|
||||||
engine.logger.info(print_str)
|
engine.logger.debug(print_str)
|
||||||
|
|
||||||
if to_save is not None:
|
if to_save is not None:
|
||||||
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir, require_empty=False),
|
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False),
|
||||||
**checkpoint_kwargs)
|
n_saved=config.checkpoint.n_saved, filename_prefix=config.name)
|
||||||
if resume_from is not None:
|
if config.resume_from is not None:
|
||||||
@trainer.on(Events.STARTED)
|
@trainer.on(Events.STARTED)
|
||||||
def resume(engine):
|
def resume(engine):
|
||||||
checkpoint_path = Path(resume_from)
|
checkpoint_path = Path(config.resume_from)
|
||||||
if not checkpoint_path.exists():
|
if not checkpoint_path.exists():
|
||||||
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
|
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
|
||||||
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
|
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
|
||||||
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
|
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
|
||||||
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
|
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
|
||||||
if save_interval_event is not None:
|
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
|
||||||
trainer.add_event_handler(save_interval_event, checkpoint_handler)
|
checkpoint_handler)
|
||||||
|
if end_event is not None:
|
||||||
|
@trainer.on(end_event)
|
||||||
|
def terminate(engine):
|
||||||
|
engine.terminate()
|
||||||
|
|
||||||
|
|
||||||
|
def setup_tensorboard_handler(trainer: Engine, config, output_transform):
|
||||||
|
if config.interval.tensorboard is None:
|
||||||
|
return None
|
||||||
|
if idist.get_rank() == 0:
|
||||||
|
# Create a logger
|
||||||
|
tb_logger = TensorboardLogger(log_dir=config.output_dir)
|
||||||
|
tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"),
|
||||||
|
event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
|
||||||
|
tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform),
|
||||||
|
event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
|
||||||
|
|
||||||
|
@trainer.on(Events.COMPLETED)
|
||||||
|
@idist.one_rank_only()
|
||||||
|
def _():
|
||||||
|
# We need to close the logger with we are done
|
||||||
|
tb_logger.close()
|
||||||
|
|
||||||
|
return tb_logger
|
||||||
|
return None
|
||||||
|
|||||||
@ -69,7 +69,7 @@ def setup_logger(
|
|||||||
if distributed_rank > 0:
|
if distributed_rank > 0:
|
||||||
logger.addHandler(logging.NullHandler())
|
logger.addHandler(logging.NullHandler())
|
||||||
else:
|
else:
|
||||||
logger.setLevel(level)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
ch = logging.StreamHandler()
|
ch = logging.StreamHandler()
|
||||||
ch.setLevel(level)
|
ch.setLevel(level)
|
||||||
@ -78,7 +78,7 @@ def setup_logger(
|
|||||||
|
|
||||||
if filepath is not None:
|
if filepath is not None:
|
||||||
fh = logging.FileHandler(filepath)
|
fh = logging.FileHandler(filepath)
|
||||||
fh.setLevel(file_level)
|
fh.setLevel(logging.DEBUG)
|
||||||
fh.setFormatter(formatter)
|
fh.setFormatter(formatter)
|
||||||
logger.addHandler(fh)
|
logger.addHandler(fh)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user