add test image handler

This commit is contained in:
Ray Wong 2020-08-23 19:49:04 +08:00
parent 35ab7ecd51
commit 1e7f63cf85
3 changed files with 54 additions and 2 deletions

View File

@ -97,7 +97,7 @@ data:
std: [0.5, 0.5, 0.5]
test:
dataloader:
batch_size: 4
batch_size: 8
shuffle: False
num_workers: 1
pin_memory: False

View File

@ -4,6 +4,7 @@ from math import ceil
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import ignite.distributed as idist
from ignite.engine import Events, Engine
@ -168,7 +169,7 @@ def get_trainer(config, logger):
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
setup_common_handlers(trainer, config, to_save=to_save, metrics_to_print=["loss_g", "loss_d"],
clear_cuda_cache=True, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
clear_cuda_cache=False, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
def output_transform(output):
loss = dict()
@ -210,6 +211,51 @@ def get_trainer(config, logger):
engine.state.iteration
)
with torch.no_grad():
g = torch.Generator()
g.manual_seed(config.misc.random_seed)
indices = torch.randperm(len(engine.state.test_dataset), generator=g).tolist()[:10]
empty_grid = torch.zeros(0, config.model.generator.in_channels, config.model.generator.img_size,
config.model.generator.img_size)
fake = dict(a=empty_grid.clone(), b=empty_grid.clone())
rec = dict(a=empty_grid.clone(), b=empty_grid.clone())
heatmap = dict(a2b=torch.zeros(0, 1, config.model.generator.img_size,
config.model.generator.img_size),
b2a=torch.zeros(0, 1, config.model.generator.img_size,
config.model.generator.img_size))
real = dict(a=empty_grid.clone(), b=empty_grid.clone())
for i in indices:
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
real_a, real_b = batch["a"].view(1, *batch["a"].size()), batch["b"].view(1, *batch["a"].size())
fake_b, _, heatmap_a2b = generators["a2b"](real_a)
fake_a, _, heatmap_b2a = generators["b2a"](real_b)
rec_a = generators["b2a"](fake_b)[0]
rec_b = generators["a2b"](fake_a)[0]
fake["a"] = torch.cat([fake["a"], fake_a.cpu()])
fake["b"] = torch.cat([fake["b"], fake_b.cpu()])
real["a"] = torch.cat([real["a"], real_a.cpu()])
real["b"] = torch.cat([real["b"], real_b.cpu()])
rec["a"] = torch.cat([rec["a"], rec_a.cpu()])
rec["b"] = torch.cat([rec["b"], rec_b.cpu()])
heatmap["a2b"] = torch.cat(
[heatmap["a2b"], torch.nn.functional.interpolate(heatmap_a2b, real_a.size()[-2:]).cpu()])
heatmap["b2a"] = torch.cat(
[heatmap["b2a"], torch.nn.functional.interpolate(heatmap_b2a, real_a.size()[-2:]).cpu()])
tensorboard_handler.writer.add_image(
"test/a",
make_2d_grid([heatmap["a2b"].expand_as(real["a"]), real["a"], fake["b"], rec["a"]]),
engine.state.iteration
)
tensorboard_handler.writer.add_image(
"test/b",
make_2d_grid([heatmap["b2a"].expand_as(real["a"]), real["b"], fake["a"], rec["b"]]),
engine.state.iteration
)
return trainer
@ -225,6 +271,9 @@ def run(task, config, logger):
logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
trainer = get_trainer(config, logger)
if idist.get_rank() == 0:
test_dataset = data.DATASET.build_with(config.data.test.dataset)
trainer.state.test_dataset = test_dataset
try:
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
except Exception:

View File

@ -16,6 +16,9 @@ def log_basic_info(logger, config):
logger.info(f"Train {config.name}")
logger.info(f"- PyTorch version: {torch.__version__}")
logger.info(f"- Ignite version: {ignite.__version__}")
logger.info(f"- CUDA version: {torch.version.cuda}")
logger.info(f"- cuDNN version: {torch.backends.cudnn.version()}")
logger.info(f"- GPU type: {torch.cuda.get_device_name(0)}")
if idist.get_world_size() > 1:
logger.info("Distributed setting:\n")
idist.show_config()