diff --git a/configs/synthesizers/UGATIT.yml b/configs/synthesizers/UGATIT.yml index 05d5311..07160fe 100644 --- a/configs/synthesizers/UGATIT.yml +++ b/configs/synthesizers/UGATIT.yml @@ -103,10 +103,9 @@ data: pin_memory: False drop_last: False dataset: - _type: GenerationUnpairedDataset - root_a: "/data/i2i/selfie2anime/testA" - root_b: "/data/i2i/selfie2anime/testB" - random_pair: False + _type: SingleFolderDataset + root: "path/to/images/" + with_path: True pipeline: - Load - Resize: diff --git a/data/dataset.py b/data/dataset.py index f8a57bc..37f19a9 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -128,7 +128,7 @@ class EpisodicDataset(Dataset): @DATASET.register_module() class SingleFolderDataset(Dataset): - def __init__(self, root, pipeline): + def __init__(self, root, pipeline, with_path=False): assert os.path.isdir(root) self.root = root samples = [] @@ -139,12 +139,16 @@ class SingleFolderDataset(Dataset): samples.append(path) self.samples = samples self.pipeline = transform_pipeline(pipeline) + self.with_path = with_path def __len__(self): return len(self.samples) def __getitem__(self, idx): - return self.pipeline(self.samples[idx]) + if not self.with_path: + return self.pipeline(self.samples[idx]) + else: + return self.pipeline(self.samples[idx]), self.samples[idx] def __repr__(self): return f"" diff --git a/engine/UGATIT.py b/engine/UGATIT.py index 6eb96e7..51956ad 100644 --- a/engine/UGATIT.py +++ b/engine/UGATIT.py @@ -1,9 +1,11 @@ from itertools import chain from math import ceil +from pathlib import Path import torch import torch.nn as nn import torch.nn.functional as F +import torchvision import ignite.distributed as idist from ignite.engine import Events, Engine @@ -175,7 +177,7 @@ def get_trainer(config, logger): to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers}) to_save.update({f"generator_{k}": generators[k] for k in generators}) to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators}) - setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, + setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration)) def output_transform(output): @@ -247,6 +249,43 @@ def get_trainer(config, logger): return trainer +def get_tester(config, logger): + generator_a2b = build_model(config.model.generator, config.distributed.model) + + def _step(engine, batch): + real_a, path = convert_tensor(batch, idist.device()) + with torch.no_grad(): + fake_b = generator_a2b(real_a)[0] + return {"path": path, "img": [real_a.detach(), fake_b.detach()]} + + tester = Engine(_step) + tester.logger = logger + + to_load = dict(generator_a2b=generator_a2b) + setup_common_handlers(tester, config, use_profiler=False, to_save=to_load) + + @tester.on(Events.STARTED) + def mkdir(engine): + img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images" + engine.state.img_output_dir = Path(img_output_dir) + if idist.get_rank() == 0 and not engine.state.img_output_dir.exists(): + engine.logger.info(f"mkdir {engine.state.img_output_dir}") + engine.state.img_output_dir.mkdir() + + @tester.on(Events.ITERATION_COMPLETED) + def save_images(engine): + img_tensors = engine.state.output["img"] + paths = engine.state.output["path"] + batch_size = img_tensors[0].size(0) + for i in range(batch_size): + # image_name = f"{engine.state.iteration * batch_size - batch_size + i + 1}.png" + image_name = Path(paths[i]).name + torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name, + nrow=len(img_tensors)) + + return tester + + def run(task, config, logger): assert torch.backends.cudnn.enabled torch.backends.cudnn.benchmark = True @@ -267,5 +306,16 @@ def run(task, config, logger): except Exception: import traceback print(traceback.format_exc()) + elif task == "test": + assert config.resume_from is not None + test_dataset = data.DATASET.build_with(config.data.test.dataset) + logger.info(f"test with dataset:\n{test_dataset}") + test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader) + tester = get_tester(config, logger) + try: + tester.run(test_data_loader, max_epochs=1) + except Exception: + import traceback + print(traceback.format_exc()) else: return NotImplemented(f"invalid task: {task}") diff --git a/util/handler.py b/util/handler.py index fc3df75..510390f 100644 --- a/util/handler.py +++ b/util/handler.py @@ -17,7 +17,7 @@ def empty_cuda_cache(_): def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True, - to_save=None, end_event=None, set_epoch_for_dist_sampler=True): + to_save=None, end_event=None, set_epoch_for_dist_sampler=False): """ Helper method to setup trainer with common handlers. 1. TerminateOnNan diff --git a/util/misc.py b/util/misc.py index 8462271..f9669db 100644 --- a/util/misc.py +++ b/util/misc.py @@ -1,5 +1,6 @@ import logging from typing import Optional +from pathlib import Path def setup_logger( @@ -76,10 +77,12 @@ def setup_logger( ch.setFormatter(formatter) logger.addHandler(ch) - if filepath is not None: + if filepath is not None and Path(filepath).parent.exists(): fh = logging.FileHandler(filepath) fh.setLevel(logging.DEBUG) fh.setFormatter(formatter) logger.addHandler(fh) + else: + logger.warning("not set file logger") return logger