Compare commits
2 Commits
31aafb3470
...
09db0a413f
| Author | SHA1 | Date | |
|---|---|---|---|
| 09db0a413f | |||
| 56b355737f |
@ -128,7 +128,7 @@ class EpisodicDataset(Dataset):
|
||||
|
||||
@DATASET.register_module()
|
||||
class SingleFolderDataset(Dataset):
|
||||
def __init__(self, root, pipeline):
|
||||
def __init__(self, root, pipeline, with_path=False):
|
||||
assert os.path.isdir(root)
|
||||
self.root = root
|
||||
samples = []
|
||||
@ -139,12 +139,16 @@ class SingleFolderDataset(Dataset):
|
||||
samples.append(path)
|
||||
self.samples = samples
|
||||
self.pipeline = transform_pipeline(pipeline)
|
||||
self.with_path = with_path
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.pipeline(self.samples[idx])
|
||||
if not self.with_path:
|
||||
return self.pipeline(self.samples[idx])
|
||||
else:
|
||||
return self.pipeline(self.samples[idx]), self.samples[idx]
|
||||
|
||||
def __repr__(self):
|
||||
return f"<SingleFolderDataset root={self.root} len={len(self)}>"
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from itertools import chain
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events, Engine
|
||||
@ -175,7 +177,7 @@ def get_trainer(config, logger):
|
||||
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
|
||||
to_save.update({f"generator_{k}": generators[k] for k in generators})
|
||||
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
|
||||
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True,
|
||||
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
|
||||
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
|
||||
|
||||
def output_transform(output):
|
||||
@ -247,6 +249,43 @@ def get_trainer(config, logger):
|
||||
return trainer
|
||||
|
||||
|
||||
def get_tester(config, logger):
|
||||
generator_a2b = build_model(config.model.generator, config.distributed.model)
|
||||
|
||||
def _step(engine, batch):
|
||||
real_a, path = convert_tensor(batch, idist.device())
|
||||
with torch.no_grad():
|
||||
fake_b = generator_a2b(real_a)[0]
|
||||
return {"path": path, "img": [real_a.detach(), fake_b.detach()]}
|
||||
|
||||
tester = Engine(_step)
|
||||
tester.logger = logger
|
||||
|
||||
to_load = dict(generator_a2b=generator_a2b)
|
||||
setup_common_handlers(tester, config, use_profiler=False, to_save=to_load)
|
||||
|
||||
@tester.on(Events.STARTED)
|
||||
def mkdir(engine):
|
||||
img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
|
||||
engine.state.img_output_dir = Path(img_output_dir)
|
||||
if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
|
||||
engine.logger.info(f"mkdir {engine.state.img_output_dir}")
|
||||
engine.state.img_output_dir.mkdir()
|
||||
|
||||
@tester.on(Events.ITERATION_COMPLETED)
|
||||
def save_images(engine):
|
||||
img_tensors = engine.state.output["img"]
|
||||
paths = engine.state.output["path"]
|
||||
batch_size = img_tensors[0].size(0)
|
||||
for i in range(batch_size):
|
||||
# image_name = f"{engine.state.iteration * batch_size - batch_size + i + 1}.png"
|
||||
image_name = Path(paths[i]).name
|
||||
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
|
||||
nrow=len(img_tensors))
|
||||
|
||||
return tester
|
||||
|
||||
|
||||
def run(task, config, logger):
|
||||
assert torch.backends.cudnn.enabled
|
||||
torch.backends.cudnn.benchmark = True
|
||||
@ -267,5 +306,16 @@ def run(task, config, logger):
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
elif task == "test":
|
||||
assert config.resume_from is not None
|
||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||
logger.info(f"test with dataset:\n{test_dataset}")
|
||||
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
|
||||
tester = get_tester(config, logger)
|
||||
try:
|
||||
tester.run(test_data_loader, max_epochs=1)
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
else:
|
||||
return NotImplemented(f"invalid task: {task}")
|
||||
|
||||
@ -17,7 +17,7 @@ def empty_cuda_cache(_):
|
||||
|
||||
|
||||
def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
|
||||
to_save=None, end_event=None, set_epoch_for_dist_sampler=True):
|
||||
to_save=None, end_event=None, set_epoch_for_dist_sampler=False):
|
||||
"""
|
||||
Helper method to setup trainer with common handlers.
|
||||
1. TerminateOnNan
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def setup_logger(
|
||||
@ -76,10 +77,12 @@ def setup_logger(
|
||||
ch.setFormatter(formatter)
|
||||
logger.addHandler(ch)
|
||||
|
||||
if filepath is not None:
|
||||
if filepath is not None and Path(filepath).parent.exists():
|
||||
fh = logging.FileHandler(filepath)
|
||||
fh.setLevel(logging.DEBUG)
|
||||
fh.setFormatter(formatter)
|
||||
logger.addHandler(fh)
|
||||
else:
|
||||
logger.warning("not set file logger")
|
||||
|
||||
return logger
|
||||
|
||||
Loading…
Reference in New Issue
Block a user