almost the same as mmedit
parent 7cf235781d
commit a5133e6795
@@ -1,14 +1,14 @@
 name: horse2zebra
 engine: cyclegan
 result_dir: ./result
-max_iteration: 18000
+max_iteration: 16600

 distributed:
   model:
     # broadcast_buffers: False

 misc:
-  random_seed: 1004
+  random_seed: 324

 checkpoints:
   interval: 2000
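For reference, a minimal sketch of how a config like this is loaded and read, assuming OmegaConf (which the new util/build.py imports); the YAML path is hypothetical:

# sketch only: load the config with OmegaConf and read the fields touched above.
# "configs/horse2zebra.yaml" is an assumed path, not taken from this commit.
from omegaconf import OmegaConf

config = OmegaConf.load("configs/horse2zebra.yaml")
print(config.name)                  # horse2zebra
print(config.max_iteration)         # 16600 after this change
print(config.misc.random_seed)      # 324 after this change
print(config.checkpoints.interval)  # 2000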
@@ -29,12 +29,12 @@ model:
     use_dropout: False
   discriminator:
     _type: PatchDiscriminator
-    _distributed:
-      bn_to_syncbn: True
+    # _distributed:
+    #   bn_to_syncbn: False
     in_channels: 3
     base_channels: 64
     num_conv: 3
-    norm_type: BN
+    norm_type: IN

 loss:
   gan:
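Switching norm_type from BN to IN (and dropping the SyncBatchNorm override) matches the usual CycleGAN discriminator setup. As an illustration of how such a flag is commonly mapped to a layer; the helper below is hypothetical, not this repo's PatchDiscriminator code:

# illustrative only: map a norm_type string from the config to a PyTorch norm layer.
import torch.nn as nn

def make_norm(norm_type: str, num_features: int) -> nn.Module:
    if norm_type == "BN":
        return nn.BatchNorm2d(num_features)
    if norm_type == "IN":
        # instance norm keeps no cross-sample statistics, so it needs no SyncBatchNorm handling
        return nn.InstanceNorm2d(num_features)
    raise ValueError(f"unknown norm_type: {norm_type}")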
@@ -82,7 +82,7 @@ data:
         - RandomHorizontalFlip
         - ToTensor
     scheduler:
-      start: 9000
+      start: 8300
       target_lr: 0
   test:
     dataloader:
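The train pipeline above is a list of transform names (RandomHorizontalFlip, ToTensor). One common way such names are resolved into a pipeline; the build_transforms helper is hypothetical, not this repo's data loader:

# hypothetical helper: resolve transform names against torchvision.transforms.
import torchvision.transforms as T

def build_transforms(names):
    return T.Compose([getattr(T, name)() for name in names])

transform = build_transforms(["RandomHorizontalFlip", "ToTensor"])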
@@ -40,14 +40,16 @@ def get_trainer(config, logger):
                                   config.optimizers.discriminator)

     milestones_values = [
-        (config.data.train.scheduler.start, config.optimizers.generator.lr),
-        (config.max_iteration, config.data.train.scheduler.target_lr),
+        (0, config.optimizers.generator.lr),
+        (100, config.optimizers.generator.lr),
+        (200, config.data.train.scheduler.target_lr)
     ]
     lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)

     milestones_values = [
-        (config.data.train.scheduler.start, config.optimizers.discriminator.lr),
-        (config.max_iteration, config.data.train.scheduler.target_lr),
+        (0, config.optimizers.discriminator.lr),
+        (100, config.optimizers.discriminator.lr),
+        (200, config.data.train.scheduler.target_lr)
     ]
     lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)

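With the new milestones the learning rate stays at its base value up to event 100 and then decays linearly to target_lr by event 200 (events are whatever the scheduler is attached to; see the EPOCH_COMPLETED change further down). A self-contained sketch with assumed values; the import path varies between ignite versions:

# sketch with assumed values: lr is flat between (0, 2e-4) and (100, 2e-4),
# then interpolated linearly down to (200, 0.0).
import torch
from ignite.contrib.handlers import PiecewiseLinear  # ignite.handlers in newer versions

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
scheduler = PiecewiseLinear(
    optimizer, param_name="lr",
    milestones_values=[(0, 2e-4), (100, 2e-4), (200, 0.0)],
)
# attach to a trainer, e.g. trainer.add_event_handler(Events.EPOCH_COMPLETED, scheduler)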
@@ -73,13 +75,14 @@ def get_trainer(config, logger):
         discriminator_a.requires_grad_(False)
         discriminator_b.requires_grad_(False)
         loss_g = dict(
-            id_a=config.loss.id.weight * id_loss(generator_a(real_b), real_b),  # G_A(B)
-            id_b=config.loss.id.weight * id_loss(generator_b(real_a), real_a),  # G_B(A)
             cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a),
             cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b),
             gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True),
             gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True)
         )
+        if config.loss.id.weight > 0:
+            loss_g["id_a"] = config.loss.id.weight * id_loss(generator_a(real_b), real_b)  # G_A(B)
+            loss_g["id_b"] = config.loss.id.weight * id_loss(generator_b(real_a), real_a)  # G_B(A)
         sum(loss_g.values()).backward()
         optimizer_g.step()

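The generator losses are collected in a dict and summed into one scalar before backward(); the identity terms are only added when config.loss.id.weight > 0, which also avoids the two extra generator forward passes. A toy version of that pattern with made-up tensors and weights:

# toy example of the "dict of weighted loss terms" pattern used in _step above.
import torch
import torch.nn.functional as F

pred = torch.randn(4, requires_grad=True)
target = torch.zeros(4)
id_weight = 0.5  # stands in for config.loss.id.weight

loss_g = dict(cycle_a=10.0 * F.l1_loss(pred, target))
if id_weight > 0:
    # note: no trailing comma here, otherwise the value becomes a 1-tuple and sum() fails
    loss_g["id_a"] = id_weight * F.l1_loss(pred, target)

sum(loss_g.values()).backward()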
@@ -116,8 +119,8 @@ def get_trainer(config, logger):

     trainer = Engine(_step)
     trainer.logger = logger
-    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
-    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
+    trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_g)
+    trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_d)

     RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
     RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
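RunningAverage smooths the per-iteration outputs of _step into the loss_g / loss_d_a / loss_d_b metrics used for printing. A minimal sketch with a dummy engine whose output mimics the trainer's nested loss dict (values made up):

# minimal sketch: attach a RunningAverage metric to a dummy engine.
import torch
from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage

def _dummy_step(engine, batch):
    return {"loss": {"g": {"gan_a": torch.tensor(0.7), "cycle_a": torch.tensor(0.3)}}}

trainer = Engine(_dummy_step)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")

@trainer.on(Events.ITERATION_COMPLETED(every=10))
def log_metrics(engine):
    print(engine.state.metrics["loss_g"])

trainer.run(range(20), max_epochs=1)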
@@ -129,10 +132,12 @@ def get_trainer(config, logger):
         lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
     )

-    setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.ITERATION_COMPLETED(every=10),
-                          metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"], to_save=to_save,
-                          resume_from=config.resume_from, n_saved=5, filename_prefix=config.name,
-                          save_interval_event=Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
+    setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
+                          filename_prefix=config.name, to_save=to_save,
+                          print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
+                          metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"],
+                          save_interval_event=Events.ITERATION_COMPLETED(
+                              every=config.checkpoints.interval) | Events.COMPLETED)

     @trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
     def terminate(engine):
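The rewritten setup_common_handlers call relies on ignite's event union: Events.ITERATION_COMPLETED(every=N) | Events.COMPLETED registers the same handler for both, so printing and checkpointing also fire once at the very end of the run. A small standalone illustration (ignite 0.4+ behaviour):

# standalone illustration: one handler registered for a union of filtered events.
from ignite.engine import Engine, Events

trainer = Engine(lambda engine, batch: None)

@trainer.on(Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED)
def report(engine):
    print(f"fired at iteration {engine.state.iteration}")

trainer.run(range(25), max_epochs=1)  # fires at 10, 20 and once more at 25 (COMPLETED)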
@@ -147,12 +152,23 @@ def get_trainer(config, logger):
     def global_step_transform(*args, **kwargs):
         return trainer.state.iteration

+    def output_transform(output):
+        loss = dict()
+        for tl in output["loss"]:
+            if isinstance(output["loss"][tl], dict):
+                for l in output["loss"][tl]:
+                    loss[f"{tl}_{l}"] = output["loss"][tl][l]
+            else:
+                loss[tl] = output["loss"][tl]
+        return loss
+
     tb_logger.attach(
         trainer,
         log_handler=OutputHandler(
             tag="loss",
             metric_names=["loss_g", "loss_d_a", "loss_d_b"],
             global_step_transform=global_step_transform,
+            output_transform=output_transform
         ),
         event_name=Events.ITERATION_COMPLETED(every=50)
     )
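The added output_transform flattens the nested loss dict so every term becomes its own scalar in TensorBoard. For example (numbers made up):

# made-up example of what output_transform produces for one trainer output
output = {"loss": {"g": {"gan_a": 0.7, "cycle_a": 0.3}, "d_a": 0.5}}
# output_transform(output) ->
# {"g_gan_a": 0.7, "g_cycle_a": 0.3, "d_a": 0.5}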
util/build.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+import torch
+import torch.optim as optim
+import ignite.distributed as idist
+
+from omegaconf import OmegaConf
+
+from model import MODEL
+from util.distributed import auto_model
+
+
+def build_model(cfg, distributed_args=None):
+    cfg = OmegaConf.to_container(cfg)
+    model_distributed_config = cfg.pop("_distributed", dict())
+    model = MODEL.build_with(cfg)
+
+    if model_distributed_config.get("bn_to_syncbn"):
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
+    distributed_args = {} if distributed_args is None or idist.get_world_size() == 1 else distributed_args
+    return auto_model(model, **distributed_args)
+
+
+def build_optimizer(params, cfg):
+    assert "_type" in cfg
+    cfg = OmegaConf.to_container(cfg)
+    optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg)
+    return idist.auto_optim(optimizer)
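A hedged usage sketch for the new helpers: the discriminator fields mirror the YAML above, "_type" selects the class from the MODEL registry, and an uncommented "_distributed: bn_to_syncbn: True" would trigger the SyncBatchNorm conversion; the optimizer hyper-parameters are made up:

# usage sketch for util/build.py; assumes "PatchDiscriminator" is registered in MODEL.
from omegaconf import OmegaConf

from util.build import build_model, build_optimizer

cfg = OmegaConf.create({
    "discriminator": {
        "_type": "PatchDiscriminator",
        # "_distributed": {"bn_to_syncbn": True},  # only useful for BN layers under DDP
        "in_channels": 3, "base_channels": 64, "num_conv": 3, "norm_type": "IN",
    },
    "optimizer": {"_type": "Adam", "lr": 2e-4, "betas": [0.5, 0.999]},  # made-up values
})

discriminator = build_model(cfg.discriminator)
optimizer_d = build_optimizer(discriminator.parameters(), cfg.optimizer)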