Compare commits
5 Commits
f7843de45d
...
7cf235781d
| Author | SHA1 | Date | |
|---|---|---|---|
| 7cf235781d | |||
| c520ce9501 | |||
| 8abd35467c | |||
| 888a052f05 | |||
| 206d9343cd |
@ -61,6 +61,7 @@ optimizers:
|
|||||||
|
|
||||||
data:
|
data:
|
||||||
train:
|
train:
|
||||||
|
buffer_size: 50
|
||||||
dataloader:
|
dataloader:
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
shuffle: True
|
shuffle: True
|
||||||
|
|||||||
@ -3,59 +3,41 @@ from pathlib import Path
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.optim as optim
|
|
||||||
import torchvision.utils
|
import torchvision.utils
|
||||||
|
|
||||||
import ignite.distributed as idist
|
import ignite.distributed as idist
|
||||||
from ignite.engine import Events, Engine
|
from ignite.engine import Events, Engine
|
||||||
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
|
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
|
||||||
from ignite.metrics import RunningAverage
|
from ignite.metrics import RunningAverage
|
||||||
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
|
from ignite.contrib.handlers import ProgressBar
|
||||||
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
|
|
||||||
from ignite.utils import convert_tensor
|
from ignite.utils import convert_tensor
|
||||||
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
|
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
|
||||||
|
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
import data
|
import data
|
||||||
from model import MODEL
|
|
||||||
from loss.gan import GANLoss
|
from loss.gan import GANLoss
|
||||||
from util.distributed import auto_model
|
from model.weight_init import generation_init_weights
|
||||||
|
from model.residual_generator import GANImageBuffer
|
||||||
from util.image import make_2d_grid
|
from util.image import make_2d_grid
|
||||||
from util.handler import Resumer
|
from util.handler import setup_common_handlers
|
||||||
|
from util.build import build_model, build_optimizer
|
||||||
|
|
||||||
def _build_model(cfg, distributed_args=None):
    """Build a model from its config node and wrap it via ``auto_model``.

    :param cfg: OmegaConf node describing the model. An optional
        ``_distributed`` entry is removed before the remaining options are
        handed to ``MODEL.build_with``.
    :param distributed_args: keyword arguments forwarded to ``auto_model``.
        They are dropped when ``None`` or when the world size is 1.
    :return: whatever ``auto_model`` returns for the built model.
    """
    options = OmegaConf.to_container(cfg)
    dist_options = options.pop("_distributed", {})
    net = MODEL.build_with(options)

    # Optional switch: convert BatchNorm layers to SyncBatchNorm before wrapping.
    if dist_options.get("bn_to_syncbn"):
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)

    # Extra wrapping arguments are only passed in a multi-process run.
    if distributed_args is None or idist.get_world_size() == 1:
        wrap_kwargs = {}
    else:
        wrap_kwargs = distributed_args
    return auto_model(net, **wrap_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_optimizer(params, cfg):
    """Create the optimizer named by ``cfg._type`` and pass it to ``idist.auto_optim``.

    :param params: parameters handed to the optimizer constructor as ``params``.
    :param cfg: OmegaConf node; ``_type`` names an attribute of ``optim`` and
        every other entry is forwarded as a constructor keyword argument.
    :return: the result of ``idist.auto_optim`` on the new optimizer.
    """
    assert "_type" in cfg
    options = OmegaConf.to_container(cfg)
    optimizer_cls = getattr(optim, options.pop("_type"))
    return idist.auto_optim(optimizer_cls(params=params, **options))
|
|
||||||
|
|
||||||
|
|
||||||
def get_trainer(config, logger):
|
def get_trainer(config, logger):
|
||||||
generator_a = _build_model(config.model.generator, config.distributed.model)
|
generator_a = build_model(config.model.generator, config.distributed.model)
|
||||||
generator_b = _build_model(config.model.generator, config.distributed.model)
|
generator_b = build_model(config.model.generator, config.distributed.model)
|
||||||
discriminator_a = _build_model(config.model.discriminator, config.distributed.model)
|
discriminator_a = build_model(config.model.discriminator, config.distributed.model)
|
||||||
discriminator_b = _build_model(config.model.discriminator, config.distributed.model)
|
discriminator_b = build_model(config.model.discriminator, config.distributed.model)
|
||||||
|
for m in [generator_b, generator_a, discriminator_b, discriminator_a]:
|
||||||
|
generation_init_weights(m)
|
||||||
logger.debug(discriminator_a)
|
logger.debug(discriminator_a)
|
||||||
logger.debug(generator_a)
|
logger.debug(generator_a)
|
||||||
|
|
||||||
optimizer_g = _build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
|
optimizer_g = build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
|
||||||
config.optimizers.generator)
|
config.optimizers.generator)
|
||||||
optimizer_d = _build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
|
optimizer_d = build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
|
||||||
config.optimizers.discriminator)
|
config.optimizers.discriminator)
|
||||||
|
|
||||||
milestones_values = [
|
milestones_values = [
|
||||||
(config.data.train.scheduler.start, config.optimizers.generator.lr),
|
(config.data.train.scheduler.start, config.optimizers.generator.lr),
|
||||||
@ -75,16 +57,21 @@ def get_trainer(config, logger):
|
|||||||
cycle_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()
|
cycle_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()
|
||||||
id_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()
|
id_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()
|
||||||
|
|
||||||
|
image_buffer_a = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
|
||||||
|
image_buffer_b = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
|
||||||
|
|
||||||
def _step(engine, batch):
|
def _step(engine, batch):
|
||||||
batch = convert_tensor(batch, idist.device())
|
batch = convert_tensor(batch, idist.device())
|
||||||
real_a, real_b = batch["a"], batch["b"]
|
real_a, real_b = batch["a"], batch["b"]
|
||||||
|
|
||||||
optimizer_g.zero_grad()
|
|
||||||
fake_b = generator_a(real_a) # G_A(A)
|
fake_b = generator_a(real_a) # G_A(A)
|
||||||
rec_a = generator_b(fake_b) # G_B(G_A(A))
|
rec_a = generator_b(fake_b) # G_B(G_A(A))
|
||||||
fake_a = generator_b(real_b) # G_B(B)
|
fake_a = generator_b(real_b) # G_B(B)
|
||||||
rec_b = generator_a(fake_a) # G_A(G_B(B))
|
rec_b = generator_a(fake_a) # G_A(G_B(B))
|
||||||
|
|
||||||
|
optimizer_g.zero_grad()
|
||||||
|
discriminator_a.requires_grad_(False)
|
||||||
|
discriminator_b.requires_grad_(False)
|
||||||
loss_g = dict(
|
loss_g = dict(
|
||||||
id_a=config.loss.id.weight * id_loss(generator_a(real_b), real_b), # G_A(B)
|
id_a=config.loss.id.weight * id_loss(generator_a(real_b), real_b), # G_A(B)
|
||||||
id_b=config.loss.id.weight * id_loss(generator_b(real_a), real_a), # G_B(A)
|
id_b=config.loss.id.weight * id_loss(generator_b(real_a), real_a), # G_B(A)
|
||||||
@ -96,17 +83,19 @@ def get_trainer(config, logger):
|
|||||||
sum(loss_g.values()).backward()
|
sum(loss_g.values()).backward()
|
||||||
optimizer_g.step()
|
optimizer_g.step()
|
||||||
|
|
||||||
|
discriminator_a.requires_grad_(True)
|
||||||
|
discriminator_b.requires_grad_(True)
|
||||||
optimizer_d.zero_grad()
|
optimizer_d.zero_grad()
|
||||||
loss_d_a = dict(
|
loss_d_a = dict(
|
||||||
real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
|
real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
|
||||||
fake=gan_loss(discriminator_a(fake_b.detach()), False, is_discriminator=True),
|
fake=gan_loss(discriminator_a(image_buffer_a.query(fake_b.detach())), False, is_discriminator=True),
|
||||||
)
|
)
|
||||||
|
(sum(loss_d_a.values()) * 0.5).backward()
|
||||||
loss_d_b = dict(
|
loss_d_b = dict(
|
||||||
real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
|
real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
|
||||||
fake=gan_loss(discriminator_b(fake_a.detach()), False, is_discriminator=True),
|
fake=gan_loss(discriminator_b(image_buffer_b.query(fake_a.detach())), False, is_discriminator=True),
|
||||||
)
|
)
|
||||||
loss_d = sum(loss_d_a.values()) / 2 + sum(loss_d_b.values()) / 2
|
(sum(loss_d_b.values()) * 0.5).backward()
|
||||||
loss_d.backward()
|
|
||||||
optimizer_d.step()
|
optimizer_d.step()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -129,27 +118,25 @@ def get_trainer(config, logger):
|
|||||||
trainer.logger = logger
|
trainer.logger = logger
|
||||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
|
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
|
||||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
|
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
|
||||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
|
|
||||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
|
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
|
||||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
|
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
|
||||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b")
|
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b")
|
||||||
|
|
||||||
@trainer.on(Events.ITERATION_COMPLETED(every=10))
|
|
||||||
def print_log(engine):
|
|
||||||
engine.logger.info(f"iter:[{engine.state.iteration}/{config.max_iteration}]"
|
|
||||||
f"loss_g={engine.state.metrics['loss_g']:.3f} "
|
|
||||||
f"loss_d_a={engine.state.metrics['loss_d_a']:.3f} "
|
|
||||||
f"loss_d_b={engine.state.metrics['loss_d_b']:.3f} ")
|
|
||||||
|
|
||||||
to_save = dict(
|
to_save = dict(
|
||||||
generator_a=generator_a, generator_b=generator_b, discriminator_a=discriminator_a,
|
generator_a=generator_a, generator_b=generator_b, discriminator_a=discriminator_a,
|
||||||
discriminator_b=discriminator_b, optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer,
|
discriminator_b=discriminator_b, optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer,
|
||||||
lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
|
lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer.add_event_handler(Events.STARTED, Resumer(to_save, config.resume_from))
|
setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.ITERATION_COMPLETED(every=10),
|
||||||
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir), n_saved=None)
|
metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"], to_save=to_save,
|
||||||
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.checkpoints.interval), checkpoint_handler)
|
resume_from=config.resume_from, n_saved=5, filename_prefix=config.name,
|
||||||
|
save_interval_event=Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
|
||||||
|
|
||||||
|
@trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
|
||||||
|
def terminate(engine):
|
||||||
|
engine.terminate()
|
||||||
|
|
||||||
if idist.get_rank() == 0:
|
if idist.get_rank() == 0:
|
||||||
# Create a logger
|
# Create a logger
|
||||||
@ -169,7 +156,6 @@ def get_trainer(config, logger):
|
|||||||
),
|
),
|
||||||
event_name=Events.ITERATION_COMPLETED(every=50)
|
event_name=Events.ITERATION_COMPLETED(every=50)
|
||||||
)
|
)
|
||||||
# Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
|
|
||||||
tb_logger.attach(
|
tb_logger.attach(
|
||||||
trainer,
|
trainer,
|
||||||
log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
|
log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
|
||||||
@ -180,28 +166,18 @@ def get_trainer(config, logger):
|
|||||||
def show_images(engine):
|
def show_images(engine):
|
||||||
tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration)
|
tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration)
|
||||||
|
|
||||||
# Create an object of the profiler and attach an engine to it
|
@trainer.on(Events.COMPLETED)
|
||||||
profiler = BasicTimeProfiler()
|
@idist.one_rank_only()
|
||||||
profiler.attach(trainer)
|
def _():
|
||||||
|
# We need to close the logger when we are done
|
||||||
@trainer.on(Events.EPOCH_COMPLETED(once=1))
|
tb_logger.close()
|
||||||
@idist.one_rank_only()
|
|
||||||
def log_intermediate_results():
|
|
||||||
profiler.print_results(profiler.get_results())
|
|
||||||
|
|
||||||
@trainer.on(Events.COMPLETED)
|
|
||||||
@idist.one_rank_only()
|
|
||||||
def _():
|
|
||||||
profiler.write_results(f"{config.output_dir}/time_profiling.csv")
|
|
||||||
# We need to close the logger when we are done
|
|
||||||
tb_logger.close()
|
|
||||||
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
|
|
||||||
def get_tester(config, logger):
|
def get_tester(config, logger):
|
||||||
generator_a = _build_model(config.model.generator, config.distributed.model)
|
generator_a = build_model(config.model.generator, config.distributed.model)
|
||||||
generator_b = _build_model(config.model.generator, config.distributed.model)
|
generator_b = build_model(config.model.generator, config.distributed.model)
|
||||||
|
|
||||||
def _step(engine, batch):
|
def _step(engine, batch):
|
||||||
batch = convert_tensor(batch, idist.device())
|
batch = convert_tensor(batch, idist.device())
|
||||||
@ -225,7 +201,7 @@ def get_tester(config, logger):
|
|||||||
if idist.get_rank == 0:
|
if idist.get_rank == 0:
|
||||||
ProgressBar(ncols=0).attach(tester)
|
ProgressBar(ncols=0).attach(tester)
|
||||||
to_load = dict(generator_a=generator_a, generator_b=generator_b)
|
to_load = dict(generator_a=generator_a, generator_b=generator_b)
|
||||||
tester.add_event_handler(Events.STARTED, Resumer(to_load, config.resume_from))
|
setup_common_handlers(tester, use_profiler=False, to_save=to_load, resume_from=config.resume_from)
|
||||||
|
|
||||||
@tester.on(Events.STARTED)
|
@tester.on(Events.STARTED)
|
||||||
@idist.one_rank_only()
|
@idist.one_rank_only()
|
||||||
@ -248,19 +224,21 @@ def get_tester(config, logger):
|
|||||||
|
|
||||||
|
|
||||||
def run(task, config, logger):
|
def run(task, config, logger):
|
||||||
|
assert torch.backends.cudnn.enabled
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
logger.info(f"start task {task}")
|
logger.info(f"start task {task}")
|
||||||
if task == "train":
|
if task == "train":
|
||||||
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
||||||
logger.info(f"train with dataset:\n{train_dataset}")
|
logger.info(f"train with dataset:\n{train_dataset}")
|
||||||
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
||||||
trainer = get_trainer(config, logger)
|
trainer = get_trainer(config, logger)
|
||||||
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
|
|
||||||
try:
|
try:
|
||||||
trainer.run(train_data_loader, max_epochs=1)
|
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
|
||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
import traceback
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
elif task == "test":
|
elif task == "test":
|
||||||
|
assert config.resume_from is not None
|
||||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||||
logger.info(f"test with dataset:\n{test_dataset}")
|
logger.info(f"test with dataset:\n{test_dataset}")
|
||||||
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
|
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
|
||||||
|
|||||||
19
main.py
19
main.py
@ -29,14 +29,16 @@ def running(local_rank, config, task, backup_config=False, setup_output_dir=Fals
|
|||||||
if setup_output_dir:
|
if setup_output_dir:
|
||||||
output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
|
output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
|
||||||
config.output_dir = str(output_dir)
|
config.output_dir = str(output_dir)
|
||||||
if idist.get_rank() == 0:
|
if output_dir.exists():
|
||||||
if not output_dir.exists():
|
assert not any(output_dir.iterdir()), "output_dir must be empty"
|
||||||
|
else:
|
||||||
|
if idist.get_rank() == 0:
|
||||||
output_dir.mkdir(parents=True)
|
output_dir.mkdir(parents=True)
|
||||||
logger.info(f"mkdir -p {output_dir}")
|
logger.info(f"mkdir -p {output_dir}")
|
||||||
logger.info(f"output path: {config.output_dir}")
|
logger.info(f"output path: {config.output_dir}")
|
||||||
if backup_config:
|
if backup_config and idist.get_rank() == 0:
|
||||||
with open(output_dir / "config.yml", "w+") as f:
|
with open(output_dir / "config.yml", "w+") as f:
|
||||||
print(config.pretty(), file=f)
|
print(config.pretty(), file=f)
|
||||||
|
|
||||||
OmegaConf.set_readonly(config, True)
|
OmegaConf.set_readonly(config, True)
|
||||||
|
|
||||||
@ -46,7 +48,10 @@ def running(local_rank, config, task, backup_config=False, setup_output_dir=Fals
|
|||||||
|
|
||||||
def run(task, config: str, *omega_options, **kwargs):
|
def run(task, config: str, *omega_options, **kwargs):
|
||||||
omega_options = [str(o) for o in omega_options]
|
omega_options = [str(o) for o in omega_options]
|
||||||
conf = OmegaConf.merge(OmegaConf.load(config), OmegaConf.from_cli(omega_options))
|
cli_conf = OmegaConf.from_cli(omega_options)
|
||||||
|
if len(cli_conf) > 0:
|
||||||
|
print(cli_conf.pretty())
|
||||||
|
conf = OmegaConf.merge(OmegaConf.load(config), cli_conf)
|
||||||
backend = kwargs.get("backend", "nccl")
|
backend = kwargs.get("backend", "nccl")
|
||||||
backup_config = kwargs.get("backup_config", False)
|
backup_config = kwargs.get("backup_config", False)
|
||||||
setup_output_dir = kwargs.get("setup_output_dir", False)
|
setup_output_dir = kwargs.get("setup_output_dir", False)
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import functools
|
import functools
|
||||||
from .registry import MODEL
|
from .registry import MODEL
|
||||||
@ -14,6 +15,60 @@ def _select_norm_layer(norm_type):
|
|||||||
raise NotImplemented(f'normalization layer {norm_type} is not found')
|
raise NotImplemented(f'normalization layer {norm_type} is not found')
|
||||||
|
|
||||||
|
|
||||||
|
class GANImageBuffer(object):
    """This class implements an image buffer that stores previously
    generated images.

    This buffer allows us to update the discriminator using a history of
    generated images rather than the ones produced by the latest generator
    to reduce model oscillation.

    Args:
        buffer_size (int): The size of image buffer. If buffer_size <= 0,
            no buffering is performed and ``query`` returns its input.
        buffer_ratio (float): The chance / possibility to use the images
            previously stored in the buffer.
    """

    def __init__(self, buffer_size, buffer_ratio=0.5):
        self.buffer_size = buffer_size
        self.buffer_ratio = buffer_ratio
        # Always create the (possibly unused) buffer state so that ``query``
        # never hits a missing attribute, whatever ``buffer_size`` is.
        self.img_num = 0
        self.image_buffer = []

    def query(self, images):
        """Query current image batch using a history of generated images.

        Args:
            images (Tensor): Current image batch without history information.
                It is iterated along its first dimension.

        Returns:
            Tensor: ``images`` itself when buffering is disabled or the batch
            is empty; otherwise a new tensor with one entry per input image,
            each being either the (detached) current image or a clone of a
            previously stored one.
        """
        # Buffering disabled (zero or negative size): pass the batch through.
        if self.buffer_size <= 0:
            return images

        return_images = []
        for image in images:
            # Detach so stored history never keeps an autograd graph alive.
            image = torch.unsqueeze(image.detach(), 0)
            if self.img_num < self.buffer_size:
                # The buffer is not full yet: keep inserting current images.
                self.img_num = self.img_num + 1
                self.image_buffer.append(image)
                return_images.append(image)
                continue

            use_buffer = torch.rand(1) < self.buffer_ratio
            if use_buffer:
                # With probability buffer_ratio, return a previously stored
                # image and put the current image in its slot.
                random_id = torch.randint(0, self.buffer_size, (1,)).item()
                image_tmp = self.image_buffer[random_id].clone()
                self.image_buffer[random_id] = image
                return_images.append(image_tmp)
            else:
                # With probability (1 - buffer_ratio), return the current image.
                return_images.append(image)

        # An empty batch yields nothing to concatenate; torch.cat rejects an
        # empty list, so hand the (empty) input back unchanged.
        if not return_images:
            return images

        # Collect all the images and return them as one batch.
        return torch.cat(return_images, 0)
|
||||||
|
|
||||||
|
|
||||||
@MODEL.register_module()
|
@MODEL.register_module()
|
||||||
class ResidualBlock(nn.Module):
|
class ResidualBlock(nn.Module):
|
||||||
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False):
|
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False):
|
||||||
|
|||||||
@ -67,4 +67,6 @@ def generation_init_weights(module, init_type='normal', init_gain=0.02):
|
|||||||
# only normal distribution applies.
|
# only normal distribution applies.
|
||||||
normal_init(m, 1.0, init_gain)
|
normal_init(m, 1.0, init_gain)
|
||||||
|
|
||||||
|
assert isinstance(module, nn.Module)
|
||||||
module.apply(init_func)
|
module.apply(init_func)
|
||||||
|
|
||||||
|
|||||||
@ -1,21 +1,86 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from ignite.engine import Engine
|
|
||||||
from ignite.handlers import Checkpoint
|
import ignite.distributed as idist
|
||||||
|
from ignite.engine import Events
|
||||||
|
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
|
||||||
|
from ignite.contrib.handlers import BasicTimeProfiler
|
||||||
|
|
||||||
|
|
||||||
class Resumer:
|
def setup_common_handlers(
|
||||||
def __init__(self, to_load, checkpoint_path):
|
trainer,
|
||||||
self.to_load = to_load
|
output_dir=None,
|
||||||
if checkpoint_path is not None:
|
stop_on_nan=True,
|
||||||
checkpoint_path = Path(checkpoint_path)
|
use_profiler=True,
|
||||||
if not checkpoint_path.exists():
|
print_interval_event=None,
|
||||||
raise ValueError(f"Checkpoint '{checkpoint_path}' is not found")
|
metrics_to_print=None,
|
||||||
self.checkpoint_path = checkpoint_path
|
to_save=None,
|
||||||
|
resume_from=None,
|
||||||
|
save_interval_event=None,
|
||||||
|
**checkpoint_kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Helper method to setup trainer with common handlers.
|
||||||
|
1. TerminateOnNan
|
||||||
|
2. BasicTimeProfiler
|
||||||
|
3. Print
|
||||||
|
4. Checkpoint
|
||||||
|
:param trainer: trainer engine. Output of trainer's `update_function` should be a dictionary
|
||||||
|
or sequence or a single tensor.
|
||||||
|
:param output_dir: output path to indicate where `to_save` objects are stored.
|
||||||
|
:param stop_on_nan: if True, :class:`~ignite.handlers.TerminateOnNan` handler is added to the trainer.
|
||||||
|
:param use_profiler:
|
||||||
|
:param print_interval_event:
|
||||||
|
:param metrics_to_print:
|
||||||
|
:param to_save:
|
||||||
|
:param resume_from:
|
||||||
|
:param save_interval_event:
|
||||||
|
:param checkpoint_kwargs:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if stop_on_nan:
|
||||||
|
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
|
||||||
|
|
||||||
def __call__(self, engine: Engine):
|
if use_profiler:
|
||||||
if self.checkpoint_path is not None:
|
# Create an object of the profiler and attach an engine to it
|
||||||
ckp = torch.load(self.checkpoint_path.as_posix(), map_location="cpu")
|
profiler = BasicTimeProfiler()
|
||||||
Checkpoint.load_objects(to_load=self.to_load, checkpoint=ckp)
|
profiler.attach(trainer)
|
||||||
engine.logger.info(f"resume from a checkpoint {self.checkpoint_path}")
|
|
||||||
|
@trainer.on(Events.EPOCH_COMPLETED(once=1))
|
||||||
|
@idist.one_rank_only()
|
||||||
|
def log_intermediate_results():
|
||||||
|
profiler.print_results(profiler.get_results())
|
||||||
|
|
||||||
|
@trainer.on(Events.COMPLETED)
|
||||||
|
@idist.one_rank_only()
|
||||||
|
def _():
|
||||||
|
profiler.print_results(profiler.get_results())
|
||||||
|
# profiler.write_results(f"{output_dir}/time_profiling.csv")
|
||||||
|
|
||||||
|
if metrics_to_print is not None:
|
||||||
|
if print_interval_event is None:
|
||||||
|
raise ValueError(
|
||||||
|
"If metrics_to_print argument is provided then print_interval_event arguments should be also defined"
|
||||||
|
)
|
||||||
|
|
||||||
|
@trainer.on(print_interval_event)
|
||||||
|
def print_interval(engine):
|
||||||
|
print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
|
||||||
|
for m in metrics_to_print:
|
||||||
|
print_str += f"{m}={engine.state.metrics[m]:.3f} "
|
||||||
|
engine.logger.info(print_str)
|
||||||
|
|
||||||
|
if to_save is not None:
|
||||||
|
if resume_from is not None:
|
||||||
|
@trainer.on(Events.STARTED)
|
||||||
|
def resume(engine):
|
||||||
|
checkpoint_path = Path(resume_from)
|
||||||
|
if not checkpoint_path.exists():
|
||||||
|
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
|
||||||
|
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
|
||||||
|
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
|
||||||
|
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
|
||||||
|
if save_interval_event is not None:
|
||||||
|
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir), **checkpoint_kwargs)
|
||||||
|
trainer.add_event_handler(save_interval_event, checkpoint_handler)
|
||||||
|
|||||||
@ -119,9 +119,9 @@ class Registry(_Registry):
|
|||||||
return self._module_dict.get(key, None)
|
return self._module_dict.get(key, None)
|
||||||
|
|
||||||
def _register_module(self, module_class, module_name=None, force=False):
|
def _register_module(self, module_class, module_name=None, force=False):
|
||||||
if not inspect.isclass(module_class):
|
# if not inspect.isclass(module_class):
|
||||||
raise TypeError('module must be a class, '
|
# raise TypeError('module must be a class, '
|
||||||
f'but got {type(module_class)}')
|
# f'but got {type(module_class)}')
|
||||||
|
|
||||||
if module_name is None:
|
if module_name is None:
|
||||||
module_name = module_class.__name__
|
module_name = module_class.__name__
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user