diff --git a/configs/synthesizers/cyclegan.yml b/configs/synthesizers/cyclegan.yml
index 9647f63..1798ebe 100644
--- a/configs/synthesizers/cyclegan.yml
+++ b/configs/synthesizers/cyclegan.yml
@@ -1,14 +1,14 @@
 name: horse2zebra
 engine: cyclegan
 result_dir: ./result
-max_iteration: 18000
+max_iteration: 16600

 distributed:
   model:
 #    broadcast_buffers: False

 misc:
-  random_seed: 1004
+  random_seed: 324

 checkpoints:
   interval: 2000
@@ -29,12 +29,12 @@ model:
     use_dropout: False
   discriminator:
     _type: PatchDiscriminator
-    _distributed:
-      bn_to_syncbn: True
+#    _distributed:
+#      bn_to_syncbn: False
     in_channels: 3
     base_channels: 64
     num_conv: 3
-    norm_type: BN
+    norm_type: IN

 loss:
   gan:
@@ -82,7 +82,7 @@ data:
       - RandomHorizontalFlip
       - ToTensor
     scheduler:
-      start: 9000
+      start: 8300
       target_lr: 0
   test:
     dataloader:
diff --git a/engine/cyclegan.py b/engine/cyclegan.py
index 80e38aa..aac8aee 100644
--- a/engine/cyclegan.py
+++ b/engine/cyclegan.py
@@ -40,14 +40,16 @@ def get_trainer(config, logger):
                                     config.optimizers.discriminator)

     milestones_values = [
-        (config.data.train.scheduler.start, config.optimizers.generator.lr),
-        (config.max_iteration, config.data.train.scheduler.target_lr),
+        (0, config.optimizers.generator.lr),
+        (100, config.optimizers.generator.lr),
+        (200, config.data.train.scheduler.target_lr)
     ]
     lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)

     milestones_values = [
-        (config.data.train.scheduler.start, config.optimizers.discriminator.lr),
-        (config.max_iteration, config.data.train.scheduler.target_lr),
+        (0, config.optimizers.discriminator.lr),
+        (100, config.optimizers.discriminator.lr),
+        (200, config.data.train.scheduler.target_lr)
     ]
     lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)

@@ -73,13 +75,14 @@ def get_trainer(config, logger):
         discriminator_a.requires_grad_(False)
         discriminator_b.requires_grad_(False)
         loss_g = dict(
-            id_a=config.loss.id.weight * id_loss(generator_a(real_b), real_b),  # G_A(B)
-            id_b=config.loss.id.weight * id_loss(generator_b(real_a), real_a),  # G_B(A)
             cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a),
             cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b),
             gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True),
             gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True)
         )
+        if config.loss.id.weight > 0:
+            loss_g["id_a"] = config.loss.id.weight * id_loss(generator_a(real_b), real_b)  # G_A(B)
+            loss_g["id_b"] = config.loss.id.weight * id_loss(generator_b(real_a), real_a)  # G_B(A)
         sum(loss_g.values()).backward()
         optimizer_g.step()

@@ -116,8 +119,8 @@
     trainer = Engine(_step)
     trainer.logger = logger

-    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
-    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
+    trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_g)
+    trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_d)

     RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
     RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
@@ -129,10 +132,12 @@
         lr_scheduler_d=lr_scheduler_d,
         lr_scheduler_g=lr_scheduler_g
     )
-    setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.ITERATION_COMPLETED(every=10),
-                          metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"], to_save=to_save,
-                          resume_from=config.resume_from, n_saved=5, filename_prefix=config.name,
-                          save_interval_event=Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
+    setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
+                          filename_prefix=config.name, to_save=to_save,
+                          print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
+                          metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"],
+                          save_interval_event=Events.ITERATION_COMPLETED(
+                              every=config.checkpoints.interval) | Events.COMPLETED)

     @trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
     def terminate(engine):
@@ -147,12 +152,23 @@ def get_trainer(config, logger):
     def global_step_transform(*args, **kwargs):
         return trainer.state.iteration

+    def output_transform(output):
+        loss = dict()
+        for tl in output["loss"]:
+            if isinstance(output["loss"][tl], dict):
+                for l in output["loss"][tl]:
+                    loss[f"{tl}_{l}"] = output["loss"][tl][l]
+            else:
+                loss[tl] = output["loss"][tl]
+        return loss
+
     tb_logger.attach(
         trainer,
         log_handler=OutputHandler(
             tag="loss",
             metric_names=["loss_g", "loss_d_a", "loss_d_b"],
             global_step_transform=global_step_transform,
+            output_transform=output_transform
         ),
         event_name=Events.ITERATION_COMPLETED(every=50)
     )
diff --git a/util/build.py b/util/build.py
new file mode 100644
index 0000000..0e53b98
--- /dev/null
+++ b/util/build.py
@@ -0,0 +1,27 @@
+import torch
+import torch.optim as optim
+import ignite.distributed as idist
+
+from omegaconf import OmegaConf
+
+from model import MODEL
+from util.distributed import auto_model
+
+
+def build_model(cfg, distributed_args=None):
+    cfg = OmegaConf.to_container(cfg)
+    model_distributed_config = cfg.pop("_distributed", dict())
+    model = MODEL.build_with(cfg)
+
+    if model_distributed_config.get("bn_to_syncbn"):
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
+    distributed_args = {} if distributed_args is None or idist.get_world_size() == 1 else distributed_args
+    return auto_model(model, **distributed_args)
+
+
+def build_optimizer(params, cfg):
+    assert "_type" in cfg
+    cfg = OmegaConf.to_container(cfg)
+    optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg)
+    return idist.auto_optim(optimizer)
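
For context on the scheduler change above: since both PiecewiseLinear schedulers are now attached on Events.EPOCH_COMPLETED, the milestones (0, lr), (100, lr), (200, target_lr) mean a constant learning rate for the first 100 epochs, then a linear decay to target_lr by epoch 200. A minimal standalone sketch of that behaviour follows; it assumes a recent pytorch-ignite (PiecewiseLinear importable from ignite.handlers; older releases expose it under ignite.contrib.handlers), and the lr value of 2e-4 plus the dummy model/step are illustrative, not code from this repository.

# Sketch only: lr value and dummy engine are assumptions for illustration.
import torch
from ignite.engine import Engine, Events
from ignite.handlers import PiecewiseLinear  # ignite.contrib.handlers on older releases

model = torch.nn.Linear(1, 1)                      # stand-in for the generator
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

# Constant lr for the first 100 scheduler calls, then linear ramp to 0 by call 200.
# Attached on EPOCH_COMPLETED, one scheduler call == one epoch.
scheduler = PiecewiseLinear(optimizer, param_name="lr",
                            milestones_values=[(0, 2e-4), (100, 2e-4), (200, 0.0)])

trainer = Engine(lambda engine, batch: None)       # dummy training step
trainer.add_event_handler(Events.EPOCH_COMPLETED, scheduler)

@trainer.on(Events.EPOCH_COMPLETED(every=50))
def log_lr(engine):
    print(engine.state.epoch, optimizer.param_groups[0]["lr"])

trainer.run(range(1), max_epochs=200)              # one dummy iteration per epoch

The printed lr should stay at 2e-4 through epoch 100 and reach roughly 1e-4 at epoch 150 and 0 at epoch 200, which is the schedule the config's data.train.scheduler block is meant to express.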
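
And a hedged usage sketch for the new util/build.py helpers. The discriminator keys mirror configs/synthesizers/cyclegan.yml; the Adam hyperparameters are assumed typical CycleGAN values rather than values from this diff, and MODEL.build_with / auto_model are taken to behave as the repository defines them.

# Sketch only: optimizer hyperparameters are assumptions; discriminator keys
# come from configs/synthesizers/cyclegan.yml.
from omegaconf import OmegaConf

from util.build import build_model, build_optimizer

cfg = OmegaConf.create("""
discriminator:
  _type: PatchDiscriminator
  in_channels: 3
  base_channels: 64
  num_conv: 3
  norm_type: IN
optimizer:
  _type: Adam            # resolved via getattr(torch.optim, "_type")
  lr: 0.0002             # assumed value
  betas: [0.5, 0.999]    # assumed value
""")

# No "_distributed" key here, so build_model skips the SyncBatchNorm
# conversion and only wraps the model via auto_model / idist.
discriminator = build_model(cfg.discriminator)
optimizer_d = build_optimizer(discriminator.parameters(), cfg.optimizer)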