add test image handler
This commit is contained in:
parent
35ab7ecd51
commit
1e7f63cf85
@ -97,7 +97,7 @@ data:
|
||||
std: [0.5, 0.5, 0.5]
|
||||
test:
|
||||
dataloader:
|
||||
batch_size: 4
|
||||
batch_size: 8
|
||||
shuffle: False
|
||||
num_workers: 1
|
||||
pin_memory: False
|
||||
|
||||
@ -4,6 +4,7 @@ from math import ceil
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events, Engine
|
||||
@ -168,7 +169,7 @@ def get_trainer(config, logger):
|
||||
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
|
||||
|
||||
setup_common_handlers(trainer, config, to_save=to_save, metrics_to_print=["loss_g", "loss_d"],
|
||||
clear_cuda_cache=True, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
|
||||
clear_cuda_cache=False, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
|
||||
|
||||
def output_transform(output):
|
||||
loss = dict()
|
||||
@ -210,6 +211,51 @@ def get_trainer(config, logger):
|
||||
engine.state.iteration
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
g = torch.Generator()
|
||||
g.manual_seed(config.misc.random_seed)
|
||||
indices = torch.randperm(len(engine.state.test_dataset), generator=g).tolist()[:10]
|
||||
|
||||
empty_grid = torch.zeros(0, config.model.generator.in_channels, config.model.generator.img_size,
|
||||
config.model.generator.img_size)
|
||||
fake = dict(a=empty_grid.clone(), b=empty_grid.clone())
|
||||
rec = dict(a=empty_grid.clone(), b=empty_grid.clone())
|
||||
heatmap = dict(a2b=torch.zeros(0, 1, config.model.generator.img_size,
|
||||
config.model.generator.img_size),
|
||||
b2a=torch.zeros(0, 1, config.model.generator.img_size,
|
||||
config.model.generator.img_size))
|
||||
real = dict(a=empty_grid.clone(), b=empty_grid.clone())
|
||||
for i in indices:
|
||||
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||
|
||||
real_a, real_b = batch["a"].view(1, *batch["a"].size()), batch["b"].view(1, *batch["a"].size())
|
||||
fake_b, _, heatmap_a2b = generators["a2b"](real_a)
|
||||
fake_a, _, heatmap_b2a = generators["b2a"](real_b)
|
||||
rec_a = generators["b2a"](fake_b)[0]
|
||||
rec_b = generators["a2b"](fake_a)[0]
|
||||
|
||||
fake["a"] = torch.cat([fake["a"], fake_a.cpu()])
|
||||
fake["b"] = torch.cat([fake["b"], fake_b.cpu()])
|
||||
real["a"] = torch.cat([real["a"], real_a.cpu()])
|
||||
real["b"] = torch.cat([real["b"], real_b.cpu()])
|
||||
rec["a"] = torch.cat([rec["a"], rec_a.cpu()])
|
||||
rec["b"] = torch.cat([rec["b"], rec_b.cpu()])
|
||||
|
||||
heatmap["a2b"] = torch.cat(
|
||||
[heatmap["a2b"], torch.nn.functional.interpolate(heatmap_a2b, real_a.size()[-2:]).cpu()])
|
||||
heatmap["b2a"] = torch.cat(
|
||||
[heatmap["b2a"], torch.nn.functional.interpolate(heatmap_b2a, real_a.size()[-2:]).cpu()])
|
||||
tensorboard_handler.writer.add_image(
|
||||
"test/a",
|
||||
make_2d_grid([heatmap["a2b"].expand_as(real["a"]), real["a"], fake["b"], rec["a"]]),
|
||||
engine.state.iteration
|
||||
)
|
||||
tensorboard_handler.writer.add_image(
|
||||
"test/b",
|
||||
make_2d_grid([heatmap["b2a"].expand_as(real["a"]), real["b"], fake["a"], rec["b"]]),
|
||||
engine.state.iteration
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
@ -225,6 +271,9 @@ def run(task, config, logger):
|
||||
logger.info(f"train with dataset:\n{train_dataset}")
|
||||
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
||||
trainer = get_trainer(config, logger)
|
||||
if idist.get_rank() == 0:
|
||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||
trainer.state.test_dataset = test_dataset
|
||||
try:
|
||||
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
|
||||
except Exception:
|
||||
|
||||
3
main.py
3
main.py
@ -16,6 +16,9 @@ def log_basic_info(logger, config):
|
||||
logger.info(f"Train {config.name}")
|
||||
logger.info(f"- PyTorch version: {torch.__version__}")
|
||||
logger.info(f"- Ignite version: {ignite.__version__}")
|
||||
logger.info(f"- CUDA version: {torch.version.cuda}")
|
||||
logger.info(f"- cuDNN version: {torch.backends.cudnn.version()}")
|
||||
logger.info(f"- GPU type: {torch.cuda.get_device_name(0)}")
|
||||
if idist.get_world_size() > 1:
|
||||
logger.info("Distributed setting:\n")
|
||||
idist.show_config()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user