Compare commits


No commits in common. "09db0a413f8a1c10564571583c07e0b1da3b4dd4" and "31aafb347041707f59ce52e198712c06749d7a92" have entirely different histories.

4 changed files with 5 additions and 62 deletions

View File

@@ -128,7 +128,7 @@ class EpisodicDataset(Dataset):
 @DATASET.register_module()
 class SingleFolderDataset(Dataset):
-    def __init__(self, root, pipeline, with_path=False):
+    def __init__(self, root, pipeline):
         assert os.path.isdir(root)
         self.root = root
         samples = []
@@ -139,16 +139,12 @@ class SingleFolderDataset(Dataset):
             samples.append(path)
         self.samples = samples
         self.pipeline = transform_pipeline(pipeline)
-        self.with_path = with_path

     def __len__(self):
         return len(self.samples)

     def __getitem__(self, idx):
-        if not self.with_path:
-            return self.pipeline(self.samples[idx])
-        else:
-            return self.pipeline(self.samples[idx]), self.samples[idx]
+        return self.pipeline(self.samples[idx])

     def __repr__(self):
         return f"<SingleFolderDataset root={self.root} len={len(self)}>"

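The hunk above removes the with_path option from SingleFolderDataset, so __getitem__ now always returns only the transformed sample. A minimal sketch of how a caller that previously relied on with_path=True could still recover the file path; dataset and idx are hypothetical names for a built SingleFolderDataset and an integer index, not code from this repository.

# Sketch under the assumptions above, not part of the diff.
sample = dataset[idx]          # pipeline(samples[idx]) only, after this change
path = dataset.samples[idx]    # the raw file path is still kept on the instance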
View File

@@ -1,11 +1,9 @@
 from itertools import chain
 from math import ceil
-from pathlib import Path
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torchvision
 import ignite.distributed as idist
 from ignite.engine import Events, Engine
@@ -177,7 +175,7 @@ def get_trainer(config, logger):
     to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
     to_save.update({f"generator_{k}": generators[k] for k in generators})
     to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
-    setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
+    setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True,
                           end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))

     def output_transform(output):
@@ -249,43 +247,6 @@ def get_trainer(config, logger):
     return trainer


-def get_tester(config, logger):
-    generator_a2b = build_model(config.model.generator, config.distributed.model)
-
-    def _step(engine, batch):
-        real_a, path = convert_tensor(batch, idist.device())
-        with torch.no_grad():
-            fake_b = generator_a2b(real_a)[0]
-        return {"path": path, "img": [real_a.detach(), fake_b.detach()]}
-
-    tester = Engine(_step)
-    tester.logger = logger
-    to_load = dict(generator_a2b=generator_a2b)
-    setup_common_handlers(tester, config, use_profiler=False, to_save=to_load)
-
-    @tester.on(Events.STARTED)
-    def mkdir(engine):
-        img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
-        engine.state.img_output_dir = Path(img_output_dir)
-        if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
-            engine.logger.info(f"mkdir {engine.state.img_output_dir}")
-            engine.state.img_output_dir.mkdir()
-
-    @tester.on(Events.ITERATION_COMPLETED)
-    def save_images(engine):
-        img_tensors = engine.state.output["img"]
-        paths = engine.state.output["path"]
-        batch_size = img_tensors[0].size(0)
-        for i in range(batch_size):
-            # image_name = f"{engine.state.iteration * batch_size - batch_size + i + 1}.png"
-            image_name = Path(paths[i]).name
-            torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
-                                          nrow=len(img_tensors))
-
-    return tester
-
-
 def run(task, config, logger):
     assert torch.backends.cudnn.enabled
     torch.backends.cudnn.benchmark = True
@@ -306,16 +267,5 @@ def run(task, config, logger):
         except Exception:
             import traceback
             print(traceback.format_exc())
-    elif task == "test":
-        assert config.resume_from is not None
-        test_dataset = data.DATASET.build_with(config.data.test.dataset)
-        logger.info(f"test with dataset:\n{test_dataset}")
-        test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
-        tester = get_tester(config, logger)
-        try:
-            tester.run(test_data_loader, max_epochs=1)
-        except Exception:
-            import traceback
-            print(traceback.format_exc())
     else:
         return NotImplemented(f"invalid task: {task}")

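This change removes get_tester and the "test" branch of run, leaving only the training path in this module. For orientation, the removed code amounted to the inference loop sketched below. The sketch is a rough reconstruction, not repository code: it assumes generator_a2b is the built generator with its checkpoint already restored, and that test_data_loader still yields (image, path) pairs, which relied on the with_path behaviour deleted in the first file of this comparison.

from pathlib import Path
import torch
import torchvision

out_dir = Path("test_images")                      # hypothetical output directory
out_dir.mkdir(parents=True, exist_ok=True)

generator_a2b.eval()                               # assumed: model built and checkpoint loaded
with torch.no_grad():
    for real_a, paths in test_data_loader:         # assumed: loader still yields (image, path)
        fake_b = generator_a2b(real_a)[0]
        for i in range(real_a.size(0)):
            # save input and translation side by side, named after the source file,
            # mirroring the removed save_images handler
            torchvision.utils.save_image([real_a[i], fake_b[i]],
                                          out_dir / Path(paths[i]).name, nrow=2)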
View File

@@ -17,7 +17,7 @@ def empty_cuda_cache(_):
 def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
-                          to_save=None, end_event=None, set_epoch_for_dist_sampler=False):
+                          to_save=None, end_event=None, set_epoch_for_dist_sampler=True):
     """
     Helper method to setup trainer with common handlers.
     1. TerminateOnNan

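The only change in this file flips the default of set_epoch_for_dist_sampler from False to True, which is why the explicit keyword argument disappeared from the setup_common_handlers call in get_trainer above. The handler itself is outside this diff; the sketch below shows what such a handler typically does, assuming it targets torch's DistributedSampler through an ignite event (an assumption based only on the parameter name).

from ignite.engine import Engine, Events
from torch.utils.data.distributed import DistributedSampler

def attach_set_epoch(trainer: Engine, data_loader):
    # Re-seed the distributed sampler at the start of every epoch so each epoch
    # shuffles differently across processes; otherwise every epoch replays the
    # same ordering.
    sampler = getattr(data_loader, "sampler", None)
    if isinstance(sampler, DistributedSampler):
        @trainer.on(Events.EPOCH_STARTED)
        def _set_epoch(engine):
            sampler.set_epoch(engine.state.epoch)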
View File

@@ -1,6 +1,5 @@
 import logging
 from typing import Optional
-from pathlib import Path


 def setup_logger(
@@ -77,12 +76,10 @@ def setup_logger(
     ch.setFormatter(formatter)
     logger.addHandler(ch)

-    if filepath is not None and Path(filepath).parent.exists():
+    if filepath is not None:
         fh = logging.FileHandler(filepath)
         fh.setLevel(logging.DEBUG)
         fh.setFormatter(formatter)
         logger.addHandler(fh)
-    else:
-        logger.warning("not set file logger")

     return logger
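With the parent-directory check removed, setup_logger now attaches a FileHandler whenever filepath is given, and the warning branch is gone. logging.FileHandler opens the file at construction time (delay=False by default) and raises FileNotFoundError when the directory does not exist, so creating it becomes the caller's job. A small caller-side sketch; the path is hypothetical.

import logging
from pathlib import Path

log_file = Path("outputs/run1/train.log")           # hypothetical location
log_file.parent.mkdir(parents=True, exist_ok=True)  # FileHandler fails without this
handler = logging.FileHandler(log_file)             # opens the file eagerly
handler.setLevel(logging.DEBUG)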