Compare commits

...

2 Commits

Author SHA1 Message Date
09db0a413f restore UGATIT.yml 2020-08-25 09:33:42 +08:00
56b355737f add test handler 2020-08-24 22:46:36 +08:00
4 changed files with 62 additions and 5 deletions

View File

@ -128,7 +128,7 @@ class EpisodicDataset(Dataset):
@DATASET.register_module() @DATASET.register_module()
class SingleFolderDataset(Dataset): class SingleFolderDataset(Dataset):
def __init__(self, root, pipeline): def __init__(self, root, pipeline, with_path=False):
assert os.path.isdir(root) assert os.path.isdir(root)
self.root = root self.root = root
samples = [] samples = []
@ -139,12 +139,16 @@ class SingleFolderDataset(Dataset):
samples.append(path) samples.append(path)
self.samples = samples self.samples = samples
self.pipeline = transform_pipeline(pipeline) self.pipeline = transform_pipeline(pipeline)
self.with_path = with_path
def __len__(self): def __len__(self):
return len(self.samples) return len(self.samples)
def __getitem__(self, idx): def __getitem__(self, idx):
return self.pipeline(self.samples[idx]) if not self.with_path:
return self.pipeline(self.samples[idx])
else:
return self.pipeline(self.samples[idx]), self.samples[idx]
def __repr__(self): def __repr__(self):
return f"<SingleFolderDataset root={self.root} len={len(self)}>" return f"<SingleFolderDataset root={self.root} len={len(self)}>"

View File

@ -1,9 +1,11 @@
from itertools import chain from itertools import chain
from math import ceil from math import ceil
from pathlib import Path
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchvision
import ignite.distributed as idist import ignite.distributed as idist
from ignite.engine import Events, Engine from ignite.engine import Events, Engine
@ -175,7 +177,7 @@ def get_trainer(config, logger):
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers}) to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
to_save.update({f"generator_{k}": generators[k] for k in generators}) to_save.update({f"generator_{k}": generators[k] for k in generators})
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators}) to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration)) end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
def output_transform(output): def output_transform(output):
@ -247,6 +249,43 @@ def get_trainer(config, logger):
return trainer return trainer
def get_tester(config, logger):
generator_a2b = build_model(config.model.generator, config.distributed.model)
def _step(engine, batch):
real_a, path = convert_tensor(batch, idist.device())
with torch.no_grad():
fake_b = generator_a2b(real_a)[0]
return {"path": path, "img": [real_a.detach(), fake_b.detach()]}
tester = Engine(_step)
tester.logger = logger
to_load = dict(generator_a2b=generator_a2b)
setup_common_handlers(tester, config, use_profiler=False, to_save=to_load)
@tester.on(Events.STARTED)
def mkdir(engine):
img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
engine.state.img_output_dir = Path(img_output_dir)
if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
engine.logger.info(f"mkdir {engine.state.img_output_dir}")
engine.state.img_output_dir.mkdir()
@tester.on(Events.ITERATION_COMPLETED)
def save_images(engine):
img_tensors = engine.state.output["img"]
paths = engine.state.output["path"]
batch_size = img_tensors[0].size(0)
for i in range(batch_size):
# image_name = f"{engine.state.iteration * batch_size - batch_size + i + 1}.png"
image_name = Path(paths[i]).name
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
nrow=len(img_tensors))
return tester
def run(task, config, logger): def run(task, config, logger):
assert torch.backends.cudnn.enabled assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
@ -267,5 +306,16 @@ def run(task, config, logger):
except Exception: except Exception:
import traceback import traceback
print(traceback.format_exc()) print(traceback.format_exc())
elif task == "test":
assert config.resume_from is not None
test_dataset = data.DATASET.build_with(config.data.test.dataset)
logger.info(f"test with dataset:\n{test_dataset}")
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
tester = get_tester(config, logger)
try:
tester.run(test_data_loader, max_epochs=1)
except Exception:
import traceback
print(traceback.format_exc())
else: else:
return NotImplemented(f"invalid task: {task}") return NotImplemented(f"invalid task: {task}")

View File

@ -17,7 +17,7 @@ def empty_cuda_cache(_):
def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True, def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
to_save=None, end_event=None, set_epoch_for_dist_sampler=True): to_save=None, end_event=None, set_epoch_for_dist_sampler=False):
""" """
Helper method to setup trainer with common handlers. Helper method to setup trainer with common handlers.
1. TerminateOnNan 1. TerminateOnNan

View File

@ -1,5 +1,6 @@
import logging import logging
from typing import Optional from typing import Optional
from pathlib import Path
def setup_logger( def setup_logger(
@ -76,10 +77,12 @@ def setup_logger(
ch.setFormatter(formatter) ch.setFormatter(formatter)
logger.addHandler(ch) logger.addHandler(ch)
if filepath is not None: if filepath is not None and Path(filepath).parent.exists():
fh = logging.FileHandler(filepath) fh = logging.FileHandler(filepath)
fh.setLevel(logging.DEBUG) fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter) fh.setFormatter(formatter)
logger.addHandler(fh) logger.addHandler(fh)
else:
logger.warning("not set file logger")
return logger return logger