almost same as mmedit

Ray Wong 2020-08-08 13:17:26 +08:00
parent 7cf235781d
commit a5133e6795
3 changed files with 61 additions and 18 deletions

Changed file 1 of 3: the horse2zebra CycleGAN config (YAML)

@@ -1,14 +1,14 @@
 name: horse2zebra
 engine: cyclegan
 result_dir: ./result
-max_iteration: 18000
+max_iteration: 16600
 distributed:
   model:
     # broadcast_buffers: False
 misc:
-  random_seed: 1004
+  random_seed: 324
 checkpoints:
   interval: 2000
@@ -29,12 +29,12 @@ model:
     use_dropout: False
   discriminator:
     _type: PatchDiscriminator
-    _distributed:
-      bn_to_syncbn: True
+    # _distributed:
+    #   bn_to_syncbn: False
     in_channels: 3
     base_channels: 64
     num_conv: 3
-    norm_type: BN
+    norm_type: IN
 loss:
   gan:
@@ -82,7 +82,7 @@ data:
       - RandomHorizontalFlip
       - ToTensor
     scheduler:
-      start: 9000
+      start: 8300
       target_lr: 0
   test:
     dataloader:
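Note (not part of the commit): switching the discriminator's norm_type from BN to IN goes hand in hand with commenting out _distributed.bn_to_syncbn, since instance norm keeps per-sample statistics and has nothing to synchronize across GPUs. A quick standalone check in plain PyTorch (not code from this repo) showing that convert_sync_batchnorm only rewrites BatchNorm layers and leaves InstanceNorm alone:

import torch.nn as nn
from torch.nn import SyncBatchNorm

bn_block = nn.Sequential(nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.BatchNorm2d(64))
in_block = nn.Sequential(nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.InstanceNorm2d(64))

print(type(SyncBatchNorm.convert_sync_batchnorm(bn_block)[1]).__name__)  # SyncBatchNorm
print(type(SyncBatchNorm.convert_sync_batchnorm(in_block)[1]).__name__)  # InstanceNorm2d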

Changed file 2 of 3: the CycleGAN trainer (get_trainer)

@@ -40,14 +40,16 @@ def get_trainer(config, logger):
                                          config.optimizers.discriminator)
     milestones_values = [
-        (config.data.train.scheduler.start, config.optimizers.generator.lr),
-        (config.max_iteration, config.data.train.scheduler.target_lr),
+        (0, config.optimizers.generator.lr),
+        (100, config.optimizers.generator.lr),
+        (200, config.data.train.scheduler.target_lr)
     ]
     lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
     milestones_values = [
-        (config.data.train.scheduler.start, config.optimizers.discriminator.lr),
-        (config.max_iteration, config.data.train.scheduler.target_lr),
+        (0, config.optimizers.discriminator.lr),
+        (100, config.optimizers.discriminator.lr),
+        (200, config.data.train.scheduler.target_lr)
     ]
     lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
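Note (not part of the commit): the new milestones hold the learning rate flat for the first 100 scheduler events and then decay it linearly to the target over the next 100; because the schedulers are attached to EPOCH_COMPLETED further down in this diff, those counts are epochs rather than iterations. A minimal sketch of the resulting curve, assuming ignite's PiecewiseLinear (in ignite.contrib.handlers at the time of this commit) and a placeholder base lr of 2e-4:

from ignite.contrib.handlers import PiecewiseLinear

# simulate_values builds the scheduler around a dummy optimizer and replays it
values = PiecewiseLinear.simulate_values(
    num_events=201,
    param_name="lr",
    milestones_values=[(0, 2e-4), (100, 2e-4), (200, 0.0)],  # mirrors the new milestones
)
print(values[0], values[100], values[150], values[200])
# approximately [0, 0.0002] [100, 0.0002] [150, 0.0001] [200, 0.0]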
@@ -73,13 +75,14 @@ def get_trainer(config, logger):
         discriminator_a.requires_grad_(False)
         discriminator_b.requires_grad_(False)
         loss_g = dict(
-            id_a=config.loss.id.weight * id_loss(generator_a(real_b), real_b),  # G_A(B)
-            id_b=config.loss.id.weight * id_loss(generator_b(real_a), real_a),  # G_B(A)
             cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a),
             cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b),
             gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True),
             gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True)
         )
+        if config.loss.id.weight > 0:
+            loss_g["id_a"] = config.loss.id.weight * id_loss(generator_a(real_b), real_b)  # G_A(B)
+            loss_g["id_b"] = config.loss.id.weight * id_loss(generator_b(real_a), real_a)  # G_B(A)
         sum(loss_g.values()).backward()
         optimizer_g.step()
@@ -116,8 +119,8 @@ def get_trainer(config, logger):
     trainer = Engine(_step)
     trainer.logger = logger
-    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
-    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
+    trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_g)
+    trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_d)
     RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
     RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
@@ -129,10 +132,12 @@ def get_trainer(config, logger):
         lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
     )
-    setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.ITERATION_COMPLETED(every=10),
-                          metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"], to_save=to_save,
-                          resume_from=config.resume_from, n_saved=5, filename_prefix=config.name,
-                          save_interval_event=Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
+    setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
+                          filename_prefix=config.name, to_save=to_save,
+                          print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
+                          metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"],
+                          save_interval_event=Events.ITERATION_COMPLETED(
+                              every=config.checkpoints.interval) | Events.COMPLETED)

     @trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
     def terminate(engine):
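Note (not part of the commit): both the print and the save events are now OR-ed with Events.COMPLETED, so the final state is still printed and checkpointed even when max_iteration is not a multiple of the configured interval. A small standalone sketch of ignite's event composition (assuming ignite >= 0.4, where | returns an EventsList):

from ignite.engine import Engine, Events

engine = Engine(lambda e, batch: None)  # trivial step function

@engine.on(Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED)
def report(e):
    print("fired at iteration", e.state.iteration)

engine.run(range(25), max_epochs=1)
# fires at iterations 10 and 20, then once more at iteration 25 when the run completes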
@@ -147,12 +152,23 @@ def get_trainer(config, logger):
     def global_step_transform(*args, **kwargs):
         return trainer.state.iteration

+    def output_transform(output):
+        loss = dict()
+        for tl in output["loss"]:
+            if isinstance(output["loss"][tl], dict):
+                for l in output["loss"][tl]:
+                    loss[f"{tl}_{l}"] = output["loss"][tl][l]
+            else:
+                loss[tl] = output["loss"][tl]
+        return loss
+
     tb_logger.attach(
         trainer,
         log_handler=OutputHandler(
             tag="loss",
             metric_names=["loss_g", "loss_d_a", "loss_d_b"],
             global_step_transform=global_step_transform,
+            output_transform=output_transform
         ),
         event_name=Events.ITERATION_COMPLETED(every=50)
     )
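Note (not part of the commit): output_transform flattens the trainer's nested loss output so every individual term gets its own TensorBoard scalar next to the three running averages. A quick illustration with made-up numbers standing in for the real tensors (the d_a keys below are hypothetical; only the generator keys appear in this diff):

def output_transform(output):
    loss = dict()
    for tl in output["loss"]:
        if isinstance(output["loss"][tl], dict):
            for l in output["loss"][tl]:
                loss[f"{tl}_{l}"] = output["loss"][tl][l]
        else:
            loss[tl] = output["loss"][tl]
    return loss

sample = {"loss": {"g": {"cycle_a": 1.2, "gan_a": 0.7}, "d_a": {"real": 0.4, "fake": 0.3}}}
print(output_transform(sample))
# {'g_cycle_a': 1.2, 'g_gan_a': 0.7, 'd_a_real': 0.4, 'd_a_fake': 0.3}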

Changed file 3 of 3: util/build.py (new file, +27 lines)

@@ -0,0 +1,27 @@
+import torch
+import torch.optim as optim
+import ignite.distributed as idist
+from omegaconf import OmegaConf
+
+from model import MODEL
+from util.distributed import auto_model
+
+
+def build_model(cfg, distributed_args=None):
+    cfg = OmegaConf.to_container(cfg)
+    model_distributed_config = cfg.pop("_distributed", dict())
+    model = MODEL.build_with(cfg)
+    if model_distributed_config.get("bn_to_syncbn"):
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+    distributed_args = {} if distributed_args is None or idist.get_world_size() == 1 else distributed_args
+    return auto_model(model, **distributed_args)
+
+
+def build_optimizer(params, cfg):
+    assert "_type" in cfg
+    cfg = OmegaConf.to_container(cfg)
+    optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg)
+    return idist.auto_optim(optimizer)
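Note (not part of the commit): a sketch of how these helpers line up with the config above. MODEL.build_with and util.distributed.auto_model come from elsewhere in this repo, and the optimizer settings shown here (Adam with lr/betas) are placeholders, not values taken from this diff:

from omegaconf import OmegaConf
from util.build import build_model, build_optimizer

cfg = OmegaConf.create({
    "discriminator": {"_type": "PatchDiscriminator", "in_channels": 3,
                      "base_channels": 64, "num_conv": 3, "norm_type": "IN"},
    "optimizer": {"_type": "Adam", "lr": 2e-4, "betas": [0.5, 0.999]},
})

discriminator = build_model(cfg.discriminator)  # no "_distributed" key, so no SyncBatchNorm conversion
optimizer_d = build_optimizer(discriminator.parameters(), cfg.optimizer)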