update tester

This commit is contained in:
budui 2020-09-10 18:34:52 +08:00
parent 7ea9c6d0d5
commit 72d09aa483
7 changed files with 79 additions and 44 deletions

View File

@ -114,6 +114,7 @@ data:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: video_dataset
dataloader:
batch_size: 8
shuffle: False

View File

@ -109,6 +109,7 @@ data:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: video_dataset
dataloader:
batch_size: 8
shuffle: False
@ -116,14 +117,11 @@ data:
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDatasetWithEdge
_type: GenerationUnpairedDataset
root_a: "/data/i2i/VoxCeleb2Anime/testA"
root_b: "/data/i2i/VoxCeleb2Anime/testB"
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
edge_type: "landmark_hed"
with_path: True
random_pair: False
size: [ 128, 128 ]
pipeline:
- Load
- Resize:

View File

@ -1,20 +1,19 @@
import os
import pickle
from pathlib import Path
from collections import defaultdict
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.transforms import functional as F
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader
from pathlib import Path
import lmdb
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
from torchvision.transforms import functional as F
from tqdm import tqdm
from .transform import transform_pipeline
from .registry import DATASET
from .transform import transform_pipeline
from .util import dlib_landmark
@ -160,9 +159,9 @@ class SingleFolderDataset(Dataset):
@DATASET.register_module()
class GenerationUnpairedDataset(Dataset):
def __init__(self, root_a, root_b, random_pair, pipeline):
self.A = SingleFolderDataset(root_a, pipeline)
self.B = SingleFolderDataset(root_b, pipeline)
def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False):
self.A = SingleFolderDataset(root_a, pipeline, with_path)
self.B = SingleFolderDataset(root_b, pipeline, with_path)
self.random_pair = random_pair
def __getitem__(self, idx):
@ -186,7 +185,8 @@ def normalize_tensor(tensor):
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256)):
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256),
with_path=False):
assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
self.edge_type = edge_type
self.size = size
@ -197,6 +197,7 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
self.random_pair = random_pair
self.with_path = with_path
def get_edge(self, origin_path):
op = Path(origin_path)
@ -224,6 +225,10 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
def __getitem__(self, idx):
a_idx = idx % len(self.A)
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
if self.with_path:
output = {"a": self.A[a_idx], "b": self.B[b_idx]}
output["edge_a"] = output["a"][1]
return output
output = dict()
output["a"], path_a = self.A[a_idx]
output["b"], path_b = self.B[b_idx]

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
@ -101,6 +101,31 @@ class TSITEngineKernel(EngineKernel):
)
class TSITTestEngineKernel(TestEngineKernel):
def __init__(self, config):
super().__init__(config)
def build_generators(self) -> dict:
generators = dict(
main=build_model(self.config.model.generator)
)
return generators
def to_load(self):
return {f"generator_{k}": self.generators[k] for k in self.generators}
def inference(self, batch):
with torch.no_grad():
fake = self.generators["main"](content_img=batch["a"][0], style_img=batch["b"][0])
return {"a": fake.detach()}
def run(task, config, _):
kernel = TSITEngineKernel(config)
run_kernel(task, config, kernel)
if task == "train":
kernel = TSITEngineKernel(config)
run_kernel(task, config, kernel)
elif task == "test":
kernel = TSITTestEngineKernel(config)
run_kernel(task, config, kernel)
else:
raise NotImplemented

View File

@ -204,13 +204,21 @@ def get_trainer(config, kernel: EngineKernel):
return trainer
def save_images_helper(output_dir, paths, images_list):
batch_size = len(paths)
for i in range(batch_size):
image_name = Path(paths[i]).name
img_list = [img[i] for img in images_list]
torchvision.utils.save_image(img_list, Path(output_dir) / image_name, nrow=len(img_list), padding=0,
normalize=True, range=(-1, 1))
def get_tester(config, kernel: TestEngineKernel):
logger = logging.getLogger(config.name)
def _step(engine, batch):
real_a, path = convert_tensor(batch, idist.device())
fake = kernel.inference({"a": real_a})["a"]
return {"path": path, "img": [real_a.detach(), fake.detach()]}
batch = convert_tensor(batch, idist.device())
return {"batch": batch, "generated": kernel.inference(batch)}
tester = Engine(_step)
tester.logger = logger
@ -227,13 +235,14 @@ def get_tester(config, kernel: TestEngineKernel):
@tester.on(Events.ITERATION_COMPLETED)
def save_images(engine):
img_tensors = engine.state.output["img"]
paths = engine.state.output["path"]
batch_size = img_tensors[0].size(0)
for i in range(batch_size):
image_name = Path(paths[i]).name
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
nrow=len(img_tensors), padding=0, normalize=True, range=(-1, 1))
if engine.state.dataloader.dataset.__class__.__name__ == "SingleFolderDataset":
images, paths = engine.state.output["batch"]
save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"]])
else:
for k in engine.state.output['generated']:
images, paths = engine.state.output["batch"][k]
save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"][k]])
return tester
@ -264,7 +273,7 @@ def run_kernel(task, config, kernel):
print(traceback.format_exc())
elif task == "test":
assert config.resume_from is not None
test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
test_dataset = data.DATASET.build_with(config.data.test[config.data.test.which])
logger.info(f"test with dataset:\n{test_dataset}")
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
tester = get_tester(config, kernel)

17
main.py
View File

@ -1,16 +1,15 @@
from pathlib import Path
from importlib import import_module
import torch
import ignite
import ignite.distributed as idist
from ignite.utils import manual_seed
from util.misc import setup_logger
from pathlib import Path
import fire
import ignite
import ignite.distributed as idist
import torch
from ignite.utils import manual_seed
from omegaconf import OmegaConf
from util.misc import setup_logger
def log_basic_info(logger, config):
logger.info(f"Train {config.name}")
@ -28,7 +27,7 @@ def log_basic_info(logger, config):
def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False):
if setup_random_seed:
manual_seed(config.misc.random_seed + idist.get_rank())
output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
output_dir = Path(config.result_dir) / config.name if config.output_dir is None else Path(config.output_dir)
config.output_dir = str(output_dir)
if setup_output_dir and config.resume_from is None:

View File

@ -1,6 +1,6 @@
import logging
from typing import Optional
from pathlib import Path
from typing import Optional
def setup_logger(
@ -8,7 +8,6 @@ def setup_logger(
level: int = logging.INFO,
logger_format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
filepath: Optional[str] = None,
file_level: int = logging.DEBUG,
distributed_rank: Optional[int] = None,
) -> logging.Logger:
"""Setups logger: name, level, format etc.
@ -18,7 +17,6 @@ def setup_logger(
level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG
logger_format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`
filepath (str, optional): Optional logging file path. If not None, logs are written to the file.
file_level (int): Optional logging level for logging file.
distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger setup for workers.
If None, distributed_rank is initialized to the rank of process.