diff --git a/configs/synthesizers/UGATIT.yml b/configs/synthesizers/UGATIT.yml index 55214b1..a92a541 100644 --- a/configs/synthesizers/UGATIT.yml +++ b/configs/synthesizers/UGATIT.yml @@ -97,7 +97,7 @@ data: std: [0.5, 0.5, 0.5] test: dataloader: - batch_size: 4 + batch_size: 8 shuffle: False num_workers: 1 pin_memory: False diff --git a/engine/UGATIT.py b/engine/UGATIT.py index 568758d..518cbfe 100644 --- a/engine/UGATIT.py +++ b/engine/UGATIT.py @@ -4,6 +4,7 @@ from math import ceil import torch import torch.nn as nn import torch.nn.functional as F +from torch.utils.data import DataLoader import ignite.distributed as idist from ignite.engine import Events, Engine @@ -168,7 +169,7 @@ def get_trainer(config, logger): to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators}) setup_common_handlers(trainer, config, to_save=to_save, metrics_to_print=["loss_g", "loss_d"], - clear_cuda_cache=True, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration)) + clear_cuda_cache=False, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration)) def output_transform(output): loss = dict() @@ -210,6 +211,51 @@ def get_trainer(config, logger): engine.state.iteration ) + with torch.no_grad(): + g = torch.Generator() + g.manual_seed(config.misc.random_seed) + indices = torch.randperm(len(engine.state.test_dataset), generator=g).tolist()[:10] + + empty_grid = torch.zeros(0, config.model.generator.in_channels, config.model.generator.img_size, + config.model.generator.img_size) + fake = dict(a=empty_grid.clone(), b=empty_grid.clone()) + rec = dict(a=empty_grid.clone(), b=empty_grid.clone()) + heatmap = dict(a2b=torch.zeros(0, 1, config.model.generator.img_size, + config.model.generator.img_size), + b2a=torch.zeros(0, 1, config.model.generator.img_size, + config.model.generator.img_size)) + real = dict(a=empty_grid.clone(), b=empty_grid.clone()) + for i in indices: + batch = convert_tensor(engine.state.test_dataset[i], idist.device()) + + real_a, 
real_b = batch["a"].view(1, *batch["a"].size()), batch["b"].view(1, *batch["b"].size())
+                fake_b, _, heatmap_a2b = generators["a2b"](real_a)
+                fake_a, _, heatmap_b2a = generators["b2a"](real_b)
+                rec_a = generators["b2a"](fake_b)[0]
+                rec_b = generators["a2b"](fake_a)[0]
+
+                fake["a"] = torch.cat([fake["a"], fake_a.cpu()])
+                fake["b"] = torch.cat([fake["b"], fake_b.cpu()])
+                real["a"] = torch.cat([real["a"], real_a.cpu()])
+                real["b"] = torch.cat([real["b"], real_b.cpu()])
+                rec["a"] = torch.cat([rec["a"], rec_a.cpu()])
+                rec["b"] = torch.cat([rec["b"], rec_b.cpu()])
+
+                heatmap["a2b"] = torch.cat(
+                    [heatmap["a2b"], torch.nn.functional.interpolate(heatmap_a2b, real_a.size()[-2:]).cpu()])
+                heatmap["b2a"] = torch.cat(
+                    [heatmap["b2a"], torch.nn.functional.interpolate(heatmap_b2a, real_b.size()[-2:]).cpu()])
+            tensorboard_handler.writer.add_image(
+                "test/a",
+                make_2d_grid([heatmap["a2b"].expand_as(real["a"]), real["a"], fake["b"], rec["a"]]),
+                engine.state.iteration
+            )
+            tensorboard_handler.writer.add_image(
+                "test/b",
+                make_2d_grid([heatmap["b2a"].expand_as(real["b"]), real["b"], fake["a"], rec["b"]]),
+                engine.state.iteration
+            )
+
     return trainer
@@ -225,6 +271,9 @@ def run(task, config, logger):
     logger.info(f"train with dataset:\n{train_dataset}")
     train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
     trainer = get_trainer(config, logger)
+    if idist.get_rank() == 0:
+        test_dataset = data.DATASET.build_with(config.data.test.dataset)
+        trainer.state.test_dataset = test_dataset
     try:
         trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
     except Exception:
diff --git a/main.py b/main.py
index 9dff274..eeba793 100644
--- a/main.py
+++ b/main.py
@@ -16,6 +16,9 @@ def log_basic_info(logger, config):
     logger.info(f"Train {config.name}")
     logger.info(f"- PyTorch version: {torch.__version__}")
    logger.info(f"- Ignite version: {ignite.__version__}")
+    logger.info(f"- CUDA version: {torch.version.cuda}")
+    logger.info(f"- cuDNN version: {torch.backends.cudnn.version()}")
+    logger.info(f"- GPU type: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'n/a'}")
     if idist.get_world_size() > 1:
         logger.info("Distributed setting:\n")
         idist.show_config()