From 72d09aa48316506b0d2559633cac819d017984d2 Mon Sep 17 00:00:00 2001 From: budui Date: Thu, 10 Sep 2020 18:34:52 +0800 Subject: [PATCH] update tester --- configs/synthesizers/TAFG.yml | 1 + configs/synthesizers/TSIT.yml | 8 +++----- data/dataset.py | 31 ++++++++++++++++++------------- engine/TSIT.py | 31 ++++++++++++++++++++++++++++--- engine/base/i2i.py | 31 ++++++++++++++++++++----------- main.py | 17 ++++++++--------- util/misc.py | 4 +--- 7 files changed, 79 insertions(+), 44 deletions(-) diff --git a/configs/synthesizers/TAFG.yml b/configs/synthesizers/TAFG.yml index 2ad4419..2d85a03 100644 --- a/configs/synthesizers/TAFG.yml +++ b/configs/synthesizers/TAFG.yml @@ -114,6 +114,7 @@ data: mean: [ 0.5, 0.5, 0.5 ] std: [ 0.5, 0.5, 0.5 ] test: + which: video_dataset dataloader: batch_size: 8 shuffle: False diff --git a/configs/synthesizers/TSIT.yml b/configs/synthesizers/TSIT.yml index a6edfbf..b2192a3 100644 --- a/configs/synthesizers/TSIT.yml +++ b/configs/synthesizers/TSIT.yml @@ -109,6 +109,7 @@ data: mean: [ 0.5, 0.5, 0.5 ] std: [ 0.5, 0.5, 0.5 ] test: + which: video_dataset dataloader: batch_size: 8 shuffle: False @@ -116,14 +117,11 @@ data: pin_memory: False drop_last: False dataset: - _type: GenerationUnpairedDatasetWithEdge + _type: GenerationUnpairedDataset root_a: "/data/i2i/VoxCeleb2Anime/testA" root_b: "/data/i2i/VoxCeleb2Anime/testB" - edges_path: "/data/i2i/VoxCeleb2Anime/edges" - landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks" - edge_type: "landmark_hed" + with_path: True random_pair: False - size: [ 128, 128 ] pipeline: - Load - Resize: diff --git a/data/dataset.py b/data/dataset.py index a26db32..23f074f 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -1,20 +1,19 @@ import os import pickle -from pathlib import Path from collections import defaultdict -from PIL import Image - -import torch -from torch.utils.data import Dataset -from torchvision.datasets import ImageFolder -from torchvision.transforms import functional as F -from 
torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader +from pathlib import Path import lmdb +import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision.datasets import ImageFolder +from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS +from torchvision.transforms import functional as F from tqdm import tqdm -from .transform import transform_pipeline from .registry import DATASET +from .transform import transform_pipeline from .util import dlib_landmark @@ -160,9 +159,9 @@ class SingleFolderDataset(Dataset): @DATASET.register_module() class GenerationUnpairedDataset(Dataset): - def __init__(self, root_a, root_b, random_pair, pipeline): - self.A = SingleFolderDataset(root_a, pipeline) - self.B = SingleFolderDataset(root_b, pipeline) + def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False): + self.A = SingleFolderDataset(root_a, pipeline, with_path) + self.B = SingleFolderDataset(root_b, pipeline, with_path) self.random_pair = random_pair def __getitem__(self, idx): @@ -186,7 +185,8 @@ def normalize_tensor(tensor): @DATASET.register_module() class GenerationUnpairedDatasetWithEdge(Dataset): - def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256)): + def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256), + with_path=False): assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"] self.edge_type = edge_type self.size = size @@ -197,6 +197,7 @@ class GenerationUnpairedDatasetWithEdge(Dataset): self.A = SingleFolderDataset(root_a, pipeline, with_path=True) self.B = SingleFolderDataset(root_b, pipeline, with_path=True) self.random_pair = random_pair + self.with_path = with_path def get_edge(self, origin_path): op = Path(origin_path) @@ -224,6 +225,10 @@ class GenerationUnpairedDatasetWithEdge(Dataset): def __getitem__(self, 
idx): a_idx = idx % len(self.A) b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item() + if self.with_path: + output = {"a": self.A[a_idx], "b": self.B[b_idx]} + output["edge_a"] = output["a"][1] + return output output = dict() output["a"], path_a = self.A[a_idx] output["b"], path_b = self.B[b_idx] diff --git a/engine/TSIT.py b/engine/TSIT.py index 6ed2eeb..f93838c 100644 --- a/engine/TSIT.py +++ b/engine/TSIT.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn from omegaconf import OmegaConf -from engine.base.i2i import EngineKernel, run_kernel +from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel from engine.util.build import build_model from loss.I2I.perceptual_loss import PerceptualLoss from loss.gan import GANLoss @@ -101,6 +101,31 @@ class TSITEngineKernel(EngineKernel): ) +class TSITTestEngineKernel(TestEngineKernel): + def __init__(self, config): + super().__init__(config) + + def build_generators(self) -> dict: + generators = dict( + main=build_model(self.config.model.generator) + ) + return generators + + def to_load(self): + return {f"generator_{k}": self.generators[k] for k in self.generators} + + def inference(self, batch): + with torch.no_grad(): + fake = self.generators["main"](content_img=batch["a"][0], style_img=batch["b"][0]) + return {"a": fake.detach()} + + def run(task, config, _): - kernel = TSITEngineKernel(config) - run_kernel(task, config, kernel) + if task == "train": + kernel = TSITEngineKernel(config) + run_kernel(task, config, kernel) + elif task == "test": + kernel = TSITTestEngineKernel(config) + run_kernel(task, config, kernel) + else: + raise NotImplementedError diff --git a/engine/base/i2i.py b/engine/base/i2i.py index 9dbba6e..414db76 100644 --- a/engine/base/i2i.py +++ b/engine/base/i2i.py @@ -204,13 +204,21 @@ def get_trainer(config, kernel: EngineKernel): return trainer +def save_images_helper(output_dir, paths, images_list): + batch_size = len(paths) + for i in 
range(batch_size): + image_name = Path(paths[i]).name + img_list = [img[i] for img in images_list] + torchvision.utils.save_image(img_list, Path(output_dir) / image_name, nrow=len(img_list), padding=0, + normalize=True, range=(-1, 1)) + + def get_tester(config, kernel: TestEngineKernel): logger = logging.getLogger(config.name) def _step(engine, batch): - real_a, path = convert_tensor(batch, idist.device()) - fake = kernel.inference({"a": real_a})["a"] - return {"path": path, "img": [real_a.detach(), fake.detach()]} + batch = convert_tensor(batch, idist.device()) + return {"batch": batch, "generated": kernel.inference(batch)} tester = Engine(_step) tester.logger = logger @@ -227,13 +235,14 @@ def get_tester(config, kernel: TestEngineKernel): @tester.on(Events.ITERATION_COMPLETED) def save_images(engine): - img_tensors = engine.state.output["img"] - paths = engine.state.output["path"] - batch_size = img_tensors[0].size(0) - for i in range(batch_size): - image_name = Path(paths[i]).name - torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name, - nrow=len(img_tensors), padding=0, normalize=True, range=(-1, 1)) + if engine.state.dataloader.dataset.__class__.__name__ == "SingleFolderDataset": + images, paths = engine.state.output["batch"] + save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"]]) + + else: + for k in engine.state.output['generated']: + images, paths = engine.state.output["batch"][k] + save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"][k]]) return tester @@ -264,7 +273,7 @@ def run_kernel(task, config, kernel): print(traceback.format_exc()) elif task == "test": assert config.resume_from is not None - test_dataset = data.DATASET.build_with(config.data.test.video_dataset) + test_dataset = data.DATASET.build_with(config.data.test[config.data.test.which]) logger.info(f"test with dataset:\n{test_dataset}") test_data_loader = 
idist.auto_dataloader(test_dataset, **config.data.test.dataloader) tester = get_tester(config, kernel) diff --git a/main.py b/main.py index a515813..c4a010e 100644 --- a/main.py +++ b/main.py @@ -1,16 +1,15 @@ -from pathlib import Path from importlib import import_module - -import torch - -import ignite -import ignite.distributed as idist -from ignite.utils import manual_seed -from util.misc import setup_logger +from pathlib import Path import fire +import ignite +import ignite.distributed as idist +import torch +from ignite.utils import manual_seed from omegaconf import OmegaConf +from util.misc import setup_logger + def log_basic_info(logger, config): logger.info(f"Train {config.name}") @@ -28,7 +27,7 @@ def log_basic_info(logger, config): def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False): if setup_random_seed: manual_seed(config.misc.random_seed + idist.get_rank()) - output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir + output_dir = Path(config.result_dir) / config.name if config.output_dir is None else Path(config.output_dir) config.output_dir = str(output_dir) if setup_output_dir and config.resume_from is None: diff --git a/util/misc.py b/util/misc.py index f9669db..2bd6d98 100644 --- a/util/misc.py +++ b/util/misc.py @@ -1,6 +1,6 @@ import logging -from typing import Optional from pathlib import Path +from typing import Optional def setup_logger( @@ -8,7 +8,6 @@ def setup_logger( level: int = logging.INFO, logger_format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s", filepath: Optional[str] = None, - file_level: int = logging.DEBUG, distributed_rank: Optional[int] = None, ) -> logging.Logger: """Setups logger: name, level, format etc. @@ -18,7 +17,6 @@ def setup_logger( level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG logger_format (str): logging format. 
By default, `%(asctime)s %(name)s %(levelname)s: %(message)s` filepath (str, optional): Optional logging file path. If not None, logs are written to the file. - file_level (int): Optional logging level for logging file. distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger setup for workers. If None, distributed_rank is initialized to the rank of process.