update tester

commit 72d09aa483 (parent 7ea9c6d0d5)
budui, 2020-09-10 18:34:52 +08:00

7 changed files with 79 additions and 44 deletions


@@ -114,6 +114,7 @@ data:
       mean: [ 0.5, 0.5, 0.5 ]
       std: [ 0.5, 0.5, 0.5 ]
   test:
+    which: video_dataset
     dataloader:
       batch_size: 8
       shuffle: False
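The added which key names the active entry among the dataset sub-configs under data.test; at test time run_kernel resolves it as config.data.test[config.data.test.which] (see the engine diff further down, with a short sketch after it).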


@@ -109,6 +109,7 @@ data:
       mean: [ 0.5, 0.5, 0.5 ]
       std: [ 0.5, 0.5, 0.5 ]
   test:
+    which: video_dataset
     dataloader:
       batch_size: 8
       shuffle: False
@@ -116,14 +117,11 @@ data:
       pin_memory: False
       drop_last: False
     dataset:
-      _type: GenerationUnpairedDatasetWithEdge
+      _type: GenerationUnpairedDataset
       root_a: "/data/i2i/VoxCeleb2Anime/testA"
       root_b: "/data/i2i/VoxCeleb2Anime/testB"
-      edges_path: "/data/i2i/VoxCeleb2Anime/edges"
-      landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
-      edge_type: "landmark_hed"
+      with_path: True
       random_pair: False
-      size: [ 128, 128 ]
       pipeline:
         - Load
         - Resize:
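For testing, the edge-annotated dataset is replaced by the plain GenerationUnpairedDataset, and the new with_path: True flag makes the dataset return each image together with its file path, which the tester later uses to name the saved result images. The edge- and size-related keys are dropped accordingly.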


@@ -1,20 +1,19 @@
 import os
 import pickle
-from pathlib import Path
 from collections import defaultdict
-from PIL import Image
-import torch
-from torch.utils.data import Dataset
-from torchvision.datasets import ImageFolder
-from torchvision.transforms import functional as F
-from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader
+from pathlib import Path
+
 import lmdb
+import torch
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision.datasets import ImageFolder
+from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
+from torchvision.transforms import functional as F
 from tqdm import tqdm
-from .transform import transform_pipeline
+
 from .registry import DATASET
+from .transform import transform_pipeline
 from .util import dlib_landmark
@@ -160,9 +159,9 @@ class SingleFolderDataset(Dataset):
 @DATASET.register_module()
 class GenerationUnpairedDataset(Dataset):
-    def __init__(self, root_a, root_b, random_pair, pipeline):
-        self.A = SingleFolderDataset(root_a, pipeline)
-        self.B = SingleFolderDataset(root_b, pipeline)
+    def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False):
+        self.A = SingleFolderDataset(root_a, pipeline, with_path)
+        self.B = SingleFolderDataset(root_b, pipeline, with_path)
         self.random_pair = random_pair

     def __getitem__(self, idx):
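with_path is simply threaded through to the two SingleFolderDataset halves, so each sample drawn from A or B becomes an (image, path) tuple instead of a bare tensor.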
@@ -186,7 +185,8 @@ def normalize_tensor(tensor):
 @DATASET.register_module()
 class GenerationUnpairedDatasetWithEdge(Dataset):
-    def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256)):
+    def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256),
+                 with_path=False):
         assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
         self.edge_type = edge_type
         self.size = size
@@ -197,6 +197,7 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
         self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
         self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
         self.random_pair = random_pair
+        self.with_path = with_path

     def get_edge(self, origin_path):
         op = Path(origin_path)
@@ -224,6 +225,10 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
     def __getitem__(self, idx):
         a_idx = idx % len(self.A)
         b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
+        if self.with_path:
+            output = {"a": self.A[a_idx], "b": self.B[b_idx]}
+            output["edge_a"] = output["a"][1]
+            return output
         output = dict()
         output["a"], path_a = self.A[a_idx]
         output["b"], path_b = self.B[b_idx]
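A small illustration of the with_path branch above; tensors and file names are made-up stand-ins. Because self.A[a_idx] now yields an (image, path) tuple, edge_a ends up holding the path of image a rather than an edge map:

# Toy stand-ins; real samples come from SingleFolderDataset(with_path=True).
import torch

img_a, path_a = torch.rand(3, 128, 128), "/data/i2i/VoxCeleb2Anime/testA/0001.png"
img_b, path_b = torch.rand(3, 128, 128), "/data/i2i/VoxCeleb2Anime/testB/0001.png"

output = {"a": (img_a, path_a), "b": (img_b, path_b)}
output["edge_a"] = output["a"][1]   # the *path* of image a, not an edge map
assert output["edge_a"] == path_a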


@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 from omegaconf import OmegaConf

-from engine.base.i2i import EngineKernel, run_kernel
+from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
 from engine.util.build import build_model
 from loss.I2I.perceptual_loss import PerceptualLoss
 from loss.gan import GANLoss
@ -101,6 +101,31 @@ class TSITEngineKernel(EngineKernel):
) )
class TSITTestEngineKernel(TestEngineKernel):
def __init__(self, config):
super().__init__(config)
def build_generators(self) -> dict:
generators = dict(
main=build_model(self.config.model.generator)
)
return generators
def to_load(self):
return {f"generator_{k}": self.generators[k] for k in self.generators}
def inference(self, batch):
with torch.no_grad():
fake = self.generators["main"](content_img=batch["a"][0], style_img=batch["b"][0])
return {"a": fake.detach()}
def run(task, config, _): def run(task, config, _):
kernel = TSITEngineKernel(config) if task == "train":
run_kernel(task, config, kernel) kernel = TSITEngineKernel(config)
run_kernel(task, config, kernel)
elif task == "test":
kernel = TSITTestEngineKernel(config)
run_kernel(task, config, kernel)
else:
raise NotImplemented
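TestEngineKernel's definition is not part of this commit; judging from the methods TSITTestEngineKernel overrides, a minimal base class would look roughly like the following sketch (an assumption, not the repository's actual code):

class TestEngineKernel:
    def __init__(self, config):
        self.config = config
        self.generators = self.build_generators()

    def build_generators(self) -> dict:
        # subclasses return {name: nn.Module}
        raise NotImplementedError

    def to_load(self):
        # {checkpoint key: module} mapping used when restoring config.resume_from
        raise NotImplementedError

    def inference(self, batch) -> dict:
        raise NotImplementedError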


@@ -204,13 +204,21 @@ def get_trainer(config, kernel: EngineKernel):
     return trainer


+def save_images_helper(output_dir, paths, images_list):
+    batch_size = len(paths)
+    for i in range(batch_size):
+        image_name = Path(paths[i]).name
+        img_list = [img[i] for img in images_list]
+        torchvision.utils.save_image(img_list, Path(output_dir) / image_name, nrow=len(img_list), padding=0,
+                                     normalize=True, range=(-1, 1))
+
+
 def get_tester(config, kernel: TestEngineKernel):
     logger = logging.getLogger(config.name)

     def _step(engine, batch):
-        real_a, path = convert_tensor(batch, idist.device())
-        fake = kernel.inference({"a": real_a})["a"]
-        return {"path": path, "img": [real_a.detach(), fake.detach()]}
+        batch = convert_tensor(batch, idist.device())
+        return {"batch": batch, "generated": kernel.inference(batch)}

     tester = Engine(_step)
     tester.logger = logger
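Hypothetical usage of the new helper, assuming inputs normalized to [-1, 1] as the range=(-1, 1) argument implies; note that later torchvision releases rename save_image's range= parameter to value_range=. File names and directories here are made up:

import torch

paths = ["/data/i2i/VoxCeleb2Anime/testA/0001.png"]  # made-up file name
real = torch.rand(1, 3, 128, 128) * 2 - 1            # batch of one, in [-1, 1]
fake = torch.rand(1, 3, 128, 128) * 2 - 1
save_images_helper("/tmp/out", paths, [real, fake])  # /tmp/out must already exist
# writes /tmp/out/0001.png with real and fake side by side, rescaled to [0, 1]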
@@ -227,13 +235,14 @@ def get_tester(config, kernel: TestEngineKernel):
     @tester.on(Events.ITERATION_COMPLETED)
     def save_images(engine):
-        img_tensors = engine.state.output["img"]
-        paths = engine.state.output["path"]
-        batch_size = img_tensors[0].size(0)
-        for i in range(batch_size):
-            image_name = Path(paths[i]).name
-            torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
-                                         nrow=len(img_tensors), padding=0, normalize=True, range=(-1, 1))
+        if engine.state.dataloader.dataset.__class__.__name__ == "SingleFolderDataset":
+            images, paths = engine.state.output["batch"]
+            save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"]])
+        else:
+            for k in engine.state.output['generated']:
+                images, paths = engine.state.output["batch"][k]
+                save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"][k]])

     return tester
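Dispatching on engine.state.dataloader.dataset.__class__.__name__ separates the flat (images, paths) batches produced by SingleFolderDataset from the dict-shaped batches of the paired datasets, where each key ("a", "b") carries its own (images, paths) tuple and gets its own saved grid.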
@@ -264,7 +273,7 @@ def run_kernel(task, config, kernel):
             print(traceback.format_exc())
     elif task == "test":
         assert config.resume_from is not None
-        test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
+        test_dataset = data.DATASET.build_with(config.data.test[config.data.test.which])
         logger.info(f"test with dataset:\n{test_dataset}")
         test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
         tester = get_tester(config, kernel)
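A minimal sketch of the new which-based lookup (the config below is a hypothetical reconstruction; image_dataset is an invented second entry that shows why the indirection helps):

from omegaconf import OmegaConf

config = OmegaConf.create({
    "data": {"test": {
        "which": "video_dataset",
        "video_dataset": {"_type": "GenerationUnpairedDataset"},
        "image_dataset": {"_type": "SingleFolderDataset"},
    }}
})

# run_kernel picks the active test dataset config by name:
selected = config.data.test[config.data.test.which]
print(selected["_type"])  # GenerationUnpairedDataset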

main.py (17 lines changed)

@@ -1,16 +1,15 @@
-from pathlib import Path
 from importlib import import_module
-import torch
-import ignite
-import ignite.distributed as idist
-from ignite.utils import manual_seed
-from util.misc import setup_logger
+from pathlib import Path
+
 import fire
+import ignite
+import ignite.distributed as idist
+import torch
+from ignite.utils import manual_seed
 from omegaconf import OmegaConf
+
+from util.misc import setup_logger
def log_basic_info(logger, config): def log_basic_info(logger, config):
logger.info(f"Train {config.name}") logger.info(f"Train {config.name}")
@@ -28,7 +27,7 @@ def log_basic_info(logger, config):
 def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False):
     if setup_random_seed:
         manual_seed(config.misc.random_seed + idist.get_rank())
-    output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
+    output_dir = Path(config.result_dir) / config.name if config.output_dir is None else Path(config.output_dir)
     config.output_dir = str(output_dir)
     if setup_output_dir and config.resume_from is None:
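The wrap in Path() guards the branch where config.output_dir arrives as a plain string; without it, later path joins would fail. A small illustration (paths are made up):

from pathlib import Path

output_dir = "results/run1"          # old behavior: could remain a str
# output_dir / "checkpoints"         # TypeError: unsupported operand type(s)
output_dir = Path("results/run1")    # new behavior: always a Path
print(output_dir / "checkpoints")    # results/run1/checkpoints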


@@ -1,6 +1,6 @@
 import logging
-from typing import Optional
 from pathlib import Path
+from typing import Optional


 def setup_logger(
@@ -8,7 +8,6 @@ def setup_logger(
     level: int = logging.INFO,
     logger_format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
     filepath: Optional[str] = None,
-    file_level: int = logging.DEBUG,
     distributed_rank: Optional[int] = None,
 ) -> logging.Logger:
     """Setups logger: name, level, format etc.
@@ -18,7 +17,6 @@ def setup_logger(
         level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG
         logger_format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`
         filepath (str, optional): Optional logging file path. If not None, logs are written to the file.
-        file_level (int): Optional logging level for logging file.
         distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger setup for workers.
             If None, distributed_rank is initialized to the rank of process.
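With file_level gone, the file handler can no longer filter at a different threshold than the console; the handler wiring sits outside this hunk, but presumably it now reuses the shared level, roughly like this hypothetical sketch:

import logging

def attach_file_handler(logger: logging.Logger, filepath: str, level: int,
                        logger_format: str) -> None:
    # hypothetical helper: the file handler shares the logger-wide level
    handler = logging.FileHandler(filepath)
    handler.setLevel(level)
    handler.setFormatter(logging.Formatter(logger_format))
    logger.addHandler(handler)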