Compare commits

...

2 Commits

Author SHA1 Message Date
09db0a413f restore UGATIT.yml 2020-08-25 09:33:42 +08:00
56b355737f add test handler 2020-08-24 22:46:36 +08:00
4 changed files with 62 additions and 5 deletions

View File

@@ -128,7 +128,7 @@ class EpisodicDataset(Dataset):
@DATASET.register_module()
class SingleFolderDataset(Dataset):
-    def __init__(self, root, pipeline):
+    def __init__(self, root, pipeline, with_path=False):
        assert os.path.isdir(root)
        self.root = root
        samples = []
@@ -139,12 +139,16 @@ class SingleFolderDataset(Dataset):
            samples.append(path)
        self.samples = samples
        self.pipeline = transform_pipeline(pipeline)
+        self.with_path = with_path

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
-        return self.pipeline(self.samples[idx])
+        if not self.with_path:
+            return self.pipeline(self.samples[idx])
+        else:
+            return self.pipeline(self.samples[idx]), self.samples[idx]

    def __repr__(self):
        return f"<SingleFolderDataset root={self.root} len={len(self)}>"

View File

@@ -1,9 +1,11 @@
from itertools import chain
from math import ceil
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import ignite.distributed as idist
from ignite.engine import Events, Engine
@@ -175,7 +177,7 @@ def get_trainer(config, logger):
    to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
    to_save.update({f"generator_{k}": generators[k] for k in generators})
    to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
-    setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True,
+    setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
                          end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))

    def output_transform(output):
@@ -247,6 +249,43 @@ def get_trainer(config, logger):
    return trainer


+def get_tester(config, logger):
+    generator_a2b = build_model(config.model.generator, config.distributed.model)
+
+    def _step(engine, batch):
+        real_a, path = convert_tensor(batch, idist.device())
+        with torch.no_grad():
+            fake_b = generator_a2b(real_a)[0]
+        return {"path": path, "img": [real_a.detach(), fake_b.detach()]}
+
+    tester = Engine(_step)
+    tester.logger = logger
+    to_load = dict(generator_a2b=generator_a2b)
+    setup_common_handlers(tester, config, use_profiler=False, to_save=to_load)
+
+    @tester.on(Events.STARTED)
+    def mkdir(engine):
+        img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
+        engine.state.img_output_dir = Path(img_output_dir)
+        if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
+            engine.logger.info(f"mkdir {engine.state.img_output_dir}")
+            engine.state.img_output_dir.mkdir()
+
+    @tester.on(Events.ITERATION_COMPLETED)
+    def save_images(engine):
+        img_tensors = engine.state.output["img"]
+        paths = engine.state.output["path"]
+        batch_size = img_tensors[0].size(0)
+        for i in range(batch_size):
+            # image_name = f"{engine.state.iteration * batch_size - batch_size + i + 1}.png"
+            image_name = Path(paths[i]).name
+            torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
+                                          nrow=len(img_tensors))
+
+    return tester


def run(task, config, logger):
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
@@ -267,5 +306,16 @@ def run(task, config, logger):
        except Exception:
            import traceback
            print(traceback.format_exc())
+    elif task == "test":
+        assert config.resume_from is not None
+        test_dataset = data.DATASET.build_with(config.data.test.dataset)
+        logger.info(f"test with dataset:\n{test_dataset}")
+        test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
+        tester = get_tester(config, logger)
+        try:
+            tester.run(test_data_loader, max_epochs=1)
+        except Exception:
+            import traceback
+            print(traceback.format_exc())
    else:
        raise NotImplementedError(f"invalid task: {task}")
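
The save_images handler writes each real/fake pair into a single image file. A minimal sketch of the underlying torchvision call (the tensors and output filename here are made up for illustration): passing a list of tensors with nrow equal to its length lays the input and the translated result out side by side in one row.

# Illustrative only: random tensors stand in for real_a[i] and fake_b[i].
import torch
import torchvision

real = torch.rand(3, 64, 64)   # input image for one sample in the batch
fake = torch.rand(3, 64, 64)   # generator output for the same sample
# nrow=2 puts both images in a single row of the saved grid.
torchvision.utils.save_image([real, fake], "pair.png", nrow=2)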

View File

@@ -17,7 +17,7 @@ def empty_cuda_cache(_):
def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
-                          to_save=None, end_event=None, set_epoch_for_dist_sampler=True):
+                          to_save=None, end_event=None, set_epoch_for_dist_sampler=False):
    """
    Helper method to setup trainer with common handlers.
    1. TerminateOnNan
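
The set_epoch_for_dist_sampler default flips from True to False here, so the UGATIT trainer now opts in explicitly (see the setup_common_handlers call in the trainer diff above) while the single-epoch tester does not. As a hedged sketch of what such a flag typically wires up (the function name and wiring below are assumptions, not this repo's implementation), the handler calls DistributedSampler.set_epoch at each epoch start so every epoch reshuffles differently per rank.

# Hedged sketch; attach_set_epoch is a hypothetical helper, not repo code.
from ignite.engine import Engine, Events
from torch.utils.data.distributed import DistributedSampler

def attach_set_epoch(trainer: Engine, data_loader):
    sampler = getattr(data_loader, "sampler", None)
    if isinstance(sampler, DistributedSampler):
        @trainer.on(Events.EPOCH_STARTED)
        def _set_epoch(engine):
            # Without this, DistributedSampler yields the same shuffle every epoch.
            sampler.set_epoch(engine.state.epoch)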

View File

@@ -1,5 +1,6 @@
import logging
from typing import Optional
+from pathlib import Path


def setup_logger(
@@ -76,10 +77,12 @@ def setup_logger(
    ch.setFormatter(formatter)
    logger.addHandler(ch)

-    if filepath is not None:
+    if filepath is not None and Path(filepath).parent.exists():
        fh = logging.FileHandler(filepath)
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)
+    else:
+        logger.warning("not set file logger")

    return logger
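
The added guard matters because logging.FileHandler raises FileNotFoundError when the parent directory of filepath does not exist; checking Path(filepath).parent.exists() first downgrades that failure to a warning. A small illustrative snippet (the path below is hypothetical):

# Illustrative only; the filepath is made up.
import logging
from pathlib import Path

filepath = "runs/experiment_1/train.log"
if Path(filepath).parent.exists():
    fh = logging.FileHandler(filepath)   # safe: the directory exists
else:
    logging.getLogger(__name__).warning("not set file logger: %s", filepath)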