add test image handler
This commit is contained in:
parent
35ab7ecd51
commit
1e7f63cf85
@ -97,7 +97,7 @@ data:
|
|||||||
std: [0.5, 0.5, 0.5]
|
std: [0.5, 0.5, 0.5]
|
||||||
test:
|
test:
|
||||||
dataloader:
|
dataloader:
|
||||||
batch_size: 4
|
batch_size: 8
|
||||||
shuffle: False
|
shuffle: False
|
||||||
num_workers: 1
|
num_workers: 1
|
||||||
pin_memory: False
|
pin_memory: False
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from math import ceil
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
import ignite.distributed as idist
|
import ignite.distributed as idist
|
||||||
from ignite.engine import Events, Engine
|
from ignite.engine import Events, Engine
|
||||||
@ -168,7 +169,7 @@ def get_trainer(config, logger):
|
|||||||
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
|
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
|
||||||
|
|
||||||
setup_common_handlers(trainer, config, to_save=to_save, metrics_to_print=["loss_g", "loss_d"],
|
setup_common_handlers(trainer, config, to_save=to_save, metrics_to_print=["loss_g", "loss_d"],
|
||||||
clear_cuda_cache=True, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
|
clear_cuda_cache=False, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
|
||||||
|
|
||||||
def output_transform(output):
|
def output_transform(output):
|
||||||
loss = dict()
|
loss = dict()
|
||||||
@ -210,6 +211,51 @@ def get_trainer(config, logger):
|
|||||||
engine.state.iteration
|
engine.state.iteration
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
g = torch.Generator()
|
||||||
|
g.manual_seed(config.misc.random_seed)
|
||||||
|
indices = torch.randperm(len(engine.state.test_dataset), generator=g).tolist()[:10]
|
||||||
|
|
||||||
|
empty_grid = torch.zeros(0, config.model.generator.in_channels, config.model.generator.img_size,
|
||||||
|
config.model.generator.img_size)
|
||||||
|
fake = dict(a=empty_grid.clone(), b=empty_grid.clone())
|
||||||
|
rec = dict(a=empty_grid.clone(), b=empty_grid.clone())
|
||||||
|
heatmap = dict(a2b=torch.zeros(0, 1, config.model.generator.img_size,
|
||||||
|
config.model.generator.img_size),
|
||||||
|
b2a=torch.zeros(0, 1, config.model.generator.img_size,
|
||||||
|
config.model.generator.img_size))
|
||||||
|
real = dict(a=empty_grid.clone(), b=empty_grid.clone())
|
||||||
|
for i in indices:
|
||||||
|
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||||
|
|
||||||
|
real_a, real_b = batch["a"].view(1, *batch["a"].size()), batch["b"].view(1, *batch["a"].size())
|
||||||
|
fake_b, _, heatmap_a2b = generators["a2b"](real_a)
|
||||||
|
fake_a, _, heatmap_b2a = generators["b2a"](real_b)
|
||||||
|
rec_a = generators["b2a"](fake_b)[0]
|
||||||
|
rec_b = generators["a2b"](fake_a)[0]
|
||||||
|
|
||||||
|
fake["a"] = torch.cat([fake["a"], fake_a.cpu()])
|
||||||
|
fake["b"] = torch.cat([fake["b"], fake_b.cpu()])
|
||||||
|
real["a"] = torch.cat([real["a"], real_a.cpu()])
|
||||||
|
real["b"] = torch.cat([real["b"], real_b.cpu()])
|
||||||
|
rec["a"] = torch.cat([rec["a"], rec_a.cpu()])
|
||||||
|
rec["b"] = torch.cat([rec["b"], rec_b.cpu()])
|
||||||
|
|
||||||
|
heatmap["a2b"] = torch.cat(
|
||||||
|
[heatmap["a2b"], torch.nn.functional.interpolate(heatmap_a2b, real_a.size()[-2:]).cpu()])
|
||||||
|
heatmap["b2a"] = torch.cat(
|
||||||
|
[heatmap["b2a"], torch.nn.functional.interpolate(heatmap_b2a, real_a.size()[-2:]).cpu()])
|
||||||
|
tensorboard_handler.writer.add_image(
|
||||||
|
"test/a",
|
||||||
|
make_2d_grid([heatmap["a2b"].expand_as(real["a"]), real["a"], fake["b"], rec["a"]]),
|
||||||
|
engine.state.iteration
|
||||||
|
)
|
||||||
|
tensorboard_handler.writer.add_image(
|
||||||
|
"test/b",
|
||||||
|
make_2d_grid([heatmap["b2a"].expand_as(real["a"]), real["b"], fake["a"], rec["b"]]),
|
||||||
|
engine.state.iteration
|
||||||
|
)
|
||||||
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
|
|
||||||
@ -225,6 +271,9 @@ def run(task, config, logger):
|
|||||||
logger.info(f"train with dataset:\n{train_dataset}")
|
logger.info(f"train with dataset:\n{train_dataset}")
|
||||||
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
||||||
trainer = get_trainer(config, logger)
|
trainer = get_trainer(config, logger)
|
||||||
|
if idist.get_rank() == 0:
|
||||||
|
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||||
|
trainer.state.test_dataset = test_dataset
|
||||||
try:
|
try:
|
||||||
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
|
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
3
main.py
3
main.py
@ -16,6 +16,9 @@ def log_basic_info(logger, config):
|
|||||||
logger.info(f"Train {config.name}")
|
logger.info(f"Train {config.name}")
|
||||||
logger.info(f"- PyTorch version: {torch.__version__}")
|
logger.info(f"- PyTorch version: {torch.__version__}")
|
||||||
logger.info(f"- Ignite version: {ignite.__version__}")
|
logger.info(f"- Ignite version: {ignite.__version__}")
|
||||||
|
logger.info(f"- CUDA version: {torch.version.cuda}")
|
||||||
|
logger.info(f"- cuDNN version: {torch.backends.cudnn.version()}")
|
||||||
|
logger.info(f"- GPU type: {torch.cuda.get_device_name(0)}")
|
||||||
if idist.get_world_size() > 1:
|
if idist.get_world_size() > 1:
|
||||||
logger.info("Distributed setting:\n")
|
logger.info("Distributed setting:\n")
|
||||||
idist.show_config()
|
idist.show_config()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user