almost same as mmedit
This commit is contained in:
parent
7cf235781d
commit
a5133e6795
@ -1,14 +1,14 @@
|
|||||||
name: horse2zebra
|
name: horse2zebra
|
||||||
engine: cyclegan
|
engine: cyclegan
|
||||||
result_dir: ./result
|
result_dir: ./result
|
||||||
max_iteration: 18000
|
max_iteration: 16600
|
||||||
|
|
||||||
distributed:
|
distributed:
|
||||||
model:
|
model:
|
||||||
# broadcast_buffers: False
|
# broadcast_buffers: False
|
||||||
|
|
||||||
misc:
|
misc:
|
||||||
random_seed: 1004
|
random_seed: 324
|
||||||
|
|
||||||
checkpoints:
|
checkpoints:
|
||||||
interval: 2000
|
interval: 2000
|
||||||
@ -29,12 +29,12 @@ model:
|
|||||||
use_dropout: False
|
use_dropout: False
|
||||||
discriminator:
|
discriminator:
|
||||||
_type: PatchDiscriminator
|
_type: PatchDiscriminator
|
||||||
_distributed:
|
# _distributed:
|
||||||
bn_to_syncbn: True
|
# bn_to_syncbn: False
|
||||||
in_channels: 3
|
in_channels: 3
|
||||||
base_channels: 64
|
base_channels: 64
|
||||||
num_conv: 3
|
num_conv: 3
|
||||||
norm_type: BN
|
norm_type: IN
|
||||||
|
|
||||||
loss:
|
loss:
|
||||||
gan:
|
gan:
|
||||||
@ -82,7 +82,7 @@ data:
|
|||||||
- RandomHorizontalFlip
|
- RandomHorizontalFlip
|
||||||
- ToTensor
|
- ToTensor
|
||||||
scheduler:
|
scheduler:
|
||||||
start: 9000
|
start: 8300
|
||||||
target_lr: 0
|
target_lr: 0
|
||||||
test:
|
test:
|
||||||
dataloader:
|
dataloader:
|
||||||
|
|||||||
@ -40,14 +40,16 @@ def get_trainer(config, logger):
|
|||||||
config.optimizers.discriminator)
|
config.optimizers.discriminator)
|
||||||
|
|
||||||
milestones_values = [
|
milestones_values = [
|
||||||
(config.data.train.scheduler.start, config.optimizers.generator.lr),
|
(0, config.optimizers.generator.lr),
|
||||||
(config.max_iteration, config.data.train.scheduler.target_lr),
|
(100, config.optimizers.generator.lr),
|
||||||
|
(200, config.data.train.scheduler.target_lr)
|
||||||
]
|
]
|
||||||
lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
|
lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
|
||||||
|
|
||||||
milestones_values = [
|
milestones_values = [
|
||||||
(config.data.train.scheduler.start, config.optimizers.discriminator.lr),
|
(0, config.optimizers.discriminator.lr),
|
||||||
(config.max_iteration, config.data.train.scheduler.target_lr),
|
(100, config.optimizers.discriminator.lr),
|
||||||
|
(200, config.data.train.scheduler.target_lr)
|
||||||
]
|
]
|
||||||
lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
|
lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
|
||||||
|
|
||||||
@ -73,13 +75,14 @@ def get_trainer(config, logger):
|
|||||||
discriminator_a.requires_grad_(False)
|
discriminator_a.requires_grad_(False)
|
||||||
discriminator_b.requires_grad_(False)
|
discriminator_b.requires_grad_(False)
|
||||||
loss_g = dict(
|
loss_g = dict(
|
||||||
id_a=config.loss.id.weight * id_loss(generator_a(real_b), real_b), # G_A(B)
|
|
||||||
id_b=config.loss.id.weight * id_loss(generator_b(real_a), real_a), # G_B(A)
|
|
||||||
cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a),
|
cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a),
|
||||||
cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b),
|
cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b),
|
||||||
gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True),
|
gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True),
|
||||||
gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True)
|
gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True)
|
||||||
)
|
)
|
||||||
|
if config.loss.id.weight > 0:
|
||||||
|
loss_g["id_a"] = config.loss.id.weight * id_loss(generator_a(real_b), real_b), # G_A(B)
|
||||||
|
loss_g["id_b"] = config.loss.id.weight * id_loss(generator_b(real_a), real_a), # G_B(A)
|
||||||
sum(loss_g.values()).backward()
|
sum(loss_g.values()).backward()
|
||||||
optimizer_g.step()
|
optimizer_g.step()
|
||||||
|
|
||||||
@ -116,8 +119,8 @@ def get_trainer(config, logger):
|
|||||||
|
|
||||||
trainer = Engine(_step)
|
trainer = Engine(_step)
|
||||||
trainer.logger = logger
|
trainer.logger = logger
|
||||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
|
trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_g)
|
||||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
|
trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_d)
|
||||||
|
|
||||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
|
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
|
||||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
|
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
|
||||||
@ -129,10 +132,12 @@ def get_trainer(config, logger):
|
|||||||
lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
|
lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
|
||||||
)
|
)
|
||||||
|
|
||||||
setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.ITERATION_COMPLETED(every=10),
|
setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
|
||||||
metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"], to_save=to_save,
|
filename_prefix=config.name, to_save=to_save,
|
||||||
resume_from=config.resume_from, n_saved=5, filename_prefix=config.name,
|
print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
|
||||||
save_interval_event=Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
|
metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"],
|
||||||
|
save_interval_event=Events.ITERATION_COMPLETED(
|
||||||
|
every=config.checkpoints.interval) | Events.COMPLETED)
|
||||||
|
|
||||||
@trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
|
@trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
|
||||||
def terminate(engine):
|
def terminate(engine):
|
||||||
@ -147,12 +152,23 @@ def get_trainer(config, logger):
|
|||||||
def global_step_transform(*args, **kwargs):
|
def global_step_transform(*args, **kwargs):
|
||||||
return trainer.state.iteration
|
return trainer.state.iteration
|
||||||
|
|
||||||
|
def output_transform(output):
|
||||||
|
loss = dict()
|
||||||
|
for tl in output["loss"]:
|
||||||
|
if isinstance(output["loss"][tl], dict):
|
||||||
|
for l in output["loss"][tl]:
|
||||||
|
loss[f"{tl}_{l}"] = output["loss"][tl][l]
|
||||||
|
else:
|
||||||
|
loss[tl] = output["loss"][tl]
|
||||||
|
return loss
|
||||||
|
|
||||||
tb_logger.attach(
|
tb_logger.attach(
|
||||||
trainer,
|
trainer,
|
||||||
log_handler=OutputHandler(
|
log_handler=OutputHandler(
|
||||||
tag="loss",
|
tag="loss",
|
||||||
metric_names=["loss_g", "loss_d_a", "loss_d_b"],
|
metric_names=["loss_g", "loss_d_a", "loss_d_b"],
|
||||||
global_step_transform=global_step_transform,
|
global_step_transform=global_step_transform,
|
||||||
|
output_transform=output_transform
|
||||||
),
|
),
|
||||||
event_name=Events.ITERATION_COMPLETED(every=50)
|
event_name=Events.ITERATION_COMPLETED(every=50)
|
||||||
)
|
)
|
||||||
|
|||||||
27
util/build.py
Normal file
27
util/build.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
import torch
|
||||||
|
import torch.optim as optim
|
||||||
|
import ignite.distributed as idist
|
||||||
|
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
from model import MODEL
|
||||||
|
from util.distributed import auto_model
|
||||||
|
|
||||||
|
|
||||||
|
def build_model(cfg, distributed_args=None):
|
||||||
|
cfg = OmegaConf.to_container(cfg)
|
||||||
|
model_distributed_config = cfg.pop("_distributed", dict())
|
||||||
|
model = MODEL.build_with(cfg)
|
||||||
|
|
||||||
|
if model_distributed_config.get("bn_to_syncbn"):
|
||||||
|
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||||
|
|
||||||
|
distributed_args = {} if distributed_args is None or idist.get_world_size() == 1 else distributed_args
|
||||||
|
return auto_model(model, **distributed_args)
|
||||||
|
|
||||||
|
|
||||||
|
def build_optimizer(params, cfg):
|
||||||
|
assert "_type" in cfg
|
||||||
|
cfg = OmegaConf.to_container(cfg)
|
||||||
|
optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg)
|
||||||
|
return idist.auto_optim(optimizer)
|
||||||
Loading…
Reference in New Issue
Block a user