commit 72d09aa483
parent 7ea9c6d0d5

    update tester
@@ -114,6 +114,7 @@ data:
       mean: [ 0.5, 0.5, 0.5 ]
       std: [ 0.5, 0.5, 0.5 ]
   test:
+    which: video_dataset
     dataloader:
       batch_size: 8
       shuffle: False
@@ -109,6 +109,7 @@ data:
       mean: [ 0.5, 0.5, 0.5 ]
       std: [ 0.5, 0.5, 0.5 ]
   test:
+    which: video_dataset
     dataloader:
       batch_size: 8
       shuffle: False
@@ -116,14 +117,11 @@ data:
       pin_memory: False
       drop_last: False
     dataset:
-      _type: GenerationUnpairedDatasetWithEdge
+      _type: GenerationUnpairedDataset
       root_a: "/data/i2i/VoxCeleb2Anime/testA"
       root_b: "/data/i2i/VoxCeleb2Anime/testB"
-      edges_path: "/data/i2i/VoxCeleb2Anime/edges"
-      landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
-      edge_type: "landmark_hed"
+      with_path: True
       random_pair: False
-      size: [ 128, 128 ]
       pipeline:
         - Load
         - Resize:
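The `which` key added above lets the test entry point select one of several sibling dataset blocks by name; `run_kernel` later in this commit builds whatever block `which` points at. A minimal sketch of the lookup, assuming OmegaConf-style configs and a hypothetical `video_dataset` block (only a `dataset:` block is visible in these hunks):

from omegaconf import OmegaConf

# Hypothetical test config; the `video_dataset` block name is assumed
# from `which: video_dataset` above.
cfg = OmegaConf.create("""
data:
  test:
    which: video_dataset
    video_dataset:
      _type: GenerationUnpairedDataset
      root_a: /data/i2i/VoxCeleb2Anime/testA
      root_b: /data/i2i/VoxCeleb2Anime/testB
""")

# The same indexing expression the engine uses in run_kernel:
selected = cfg.data.test[cfg.data.test.which]
print(selected["_type"])  # GenerationUnpairedDataset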
@@ -1,20 +1,19 @@
 import os
 import pickle
-from pathlib import Path
 from collections import defaultdict
-from PIL import Image
-
-import torch
-from torch.utils.data import Dataset
-from torchvision.datasets import ImageFolder
-from torchvision.transforms import functional as F
-from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader
+from pathlib import Path
 
 import lmdb
+import torch
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision.datasets import ImageFolder
+from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
+from torchvision.transforms import functional as F
+from tqdm import tqdm
 
-from .transform import transform_pipeline
 from .registry import DATASET
+from .transform import transform_pipeline
 from .util import dlib_landmark
@@ -160,9 +159,9 @@ class SingleFolderDataset(Dataset):
 
 @DATASET.register_module()
 class GenerationUnpairedDataset(Dataset):
-    def __init__(self, root_a, root_b, random_pair, pipeline):
-        self.A = SingleFolderDataset(root_a, pipeline)
-        self.B = SingleFolderDataset(root_b, pipeline)
+    def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False):
+        self.A = SingleFolderDataset(root_a, pipeline, with_path)
+        self.B = SingleFolderDataset(root_b, pipeline, with_path)
         self.random_pair = random_pair
 
     def __getitem__(self, idx):
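The new `with_path` flag threads through to both `SingleFolderDataset` instances, so items become `(image, path)` pairs; the tester added later in this commit relies on the paths to name its output files. A hedged sketch of the resulting item shape, assuming `__getitem__` returns a dict keyed `"a"`/`"b"` like the WithEdge variant below, with the pipeline and paths purely illustrative:

# Assumed usage; construction details mirror the test config above.
dataset = GenerationUnpairedDataset(
    root_a="/data/i2i/VoxCeleb2Anime/testA",
    root_b="/data/i2i/VoxCeleb2Anime/testB",
    random_pair=False,
    pipeline=["Load"],
    with_path=True,
)

sample = dataset[0]
# With with_path=True each side is an (image, path) tuple:
img_a, path_a = sample["a"]
img_b, path_b = sample["b"]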
@@ -186,7 +185,8 @@ def normalize_tensor(tensor):
 
 @DATASET.register_module()
 class GenerationUnpairedDatasetWithEdge(Dataset):
-    def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256)):
+    def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256),
+                 with_path=False):
         assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
         self.edge_type = edge_type
         self.size = size
@@ -197,6 +197,7 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
         self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
         self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
         self.random_pair = random_pair
+        self.with_path = with_path
 
     def get_edge(self, origin_path):
         op = Path(origin_path)
@@ -224,6 +225,10 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
     def __getitem__(self, idx):
         a_idx = idx % len(self.A)
         b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
+        if self.with_path:
+            output = {"a": self.A[a_idx], "b": self.B[b_idx]}
+            output["edge_a"] = output["a"][1]
+            return output
         output = dict()
         output["a"], path_a = self.A[a_idx]
         output["b"], path_b = self.B[b_idx]
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 from omegaconf import OmegaConf
 
-from engine.base.i2i import EngineKernel, run_kernel
+from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
 from engine.util.build import build_model
 from loss.I2I.perceptual_loss import PerceptualLoss
 from loss.gan import GANLoss
@@ -101,6 +101,31 @@ class TSITEngineKernel(EngineKernel):
         )
 
 
+class TSITTestEngineKernel(TestEngineKernel):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def build_generators(self) -> dict:
+        generators = dict(
+            main=build_model(self.config.model.generator)
+        )
+        return generators
+
+    def to_load(self):
+        return {f"generator_{k}": self.generators[k] for k in self.generators}
+
+    def inference(self, batch):
+        with torch.no_grad():
+            fake = self.generators["main"](content_img=batch["a"][0], style_img=batch["b"][0])
+        return {"a": fake.detach()}
+
+
 def run(task, config, _):
-    kernel = TSITEngineKernel(config)
-    run_kernel(task, config, kernel)
+    if task == "train":
+        kernel = TSITEngineKernel(config)
+        run_kernel(task, config, kernel)
+    elif task == "test":
+        kernel = TSITTestEngineKernel(config)
+        run_kernel(task, config, kernel)
+    else:
+        raise NotImplementedError
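`run` now dispatches on the task: training keeps the original kernel, testing builds the new one. The test kernel's surface is just the three hooks shown above; a hedged sketch of a kernel for some other, single-input generator under the same contract (base-class behavior beyond these hooks is assumed):

import torch

class MyTestEngineKernel(TestEngineKernel):
    """Hypothetical kernel following the same three-hook contract."""

    def build_generators(self) -> dict:
        # One named generator to restore from the checkpoint.
        return dict(main=build_model(self.config.model.generator))

    def to_load(self):
        # Keys must match the names the checkpoint was saved under.
        return {f"generator_{k}": self.generators[k] for k in self.generators}

    def inference(self, batch):
        with torch.no_grad():
            fake = self.generators["main"](batch["a"][0])
        return {"a": fake.detach()}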
@@ -204,13 +204,21 @@ def get_trainer(config, kernel: EngineKernel):
     return trainer
 
 
+def save_images_helper(output_dir, paths, images_list):
+    batch_size = len(paths)
+    for i in range(batch_size):
+        image_name = Path(paths[i]).name
+        img_list = [img[i] for img in images_list]
+        torchvision.utils.save_image(img_list, Path(output_dir) / image_name, nrow=len(img_list), padding=0,
+                                     normalize=True, range=(-1, 1))
+
+
 def get_tester(config, kernel: TestEngineKernel):
     logger = logging.getLogger(config.name)
 
     def _step(engine, batch):
-        real_a, path = convert_tensor(batch, idist.device())
-        fake = kernel.inference({"a": real_a})["a"]
-        return {"path": path, "img": [real_a.detach(), fake.detach()]}
+        batch = convert_tensor(batch, idist.device())
+        return {"batch": batch, "generated": kernel.inference(batch)}
 
     tester = Engine(_step)
     tester.logger = logger
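`save_images_helper` writes one image strip per input path, reusing the source file's basename. A small usage sketch with dummy tensors; note that newer torchvision versions deprecate `save_image`'s `range=` keyword in favor of `value_range=`, while the helper above uses the older spelling:

from pathlib import Path
import torch

# Dummy batch: two real images in [-1, 1] and their fakes, plus source
# paths whose basenames become the output file names.
real = torch.rand(2, 3, 128, 128) * 2 - 1
fake = torch.rand(2, 3, 128, 128) * 2 - 1
paths = ["/data/i2i/VoxCeleb2Anime/testA/0001.png",
         "/data/i2i/VoxCeleb2Anime/testA/0002.png"]

Path("out").mkdir(exist_ok=True)
# Writes out/0001.png and out/0002.png, each a real|fake strip.
save_images_helper("out", paths, [real, fake])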
@@ -227,13 +235,14 @@ def get_tester(config, kernel: TestEngineKernel):
 
     @tester.on(Events.ITERATION_COMPLETED)
     def save_images(engine):
-        img_tensors = engine.state.output["img"]
-        paths = engine.state.output["path"]
-        batch_size = img_tensors[0].size(0)
-        for i in range(batch_size):
-            image_name = Path(paths[i]).name
-            torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
-                                         nrow=len(img_tensors), padding=0, normalize=True, range=(-1, 1))
+        if engine.state.dataloader.dataset.__class__.__name__ == "SingleFolderDataset":
+            images, paths = engine.state.output["batch"]
+            save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"]])
+
+        else:
+            for k in engine.state.output['generated']:
+                images, paths = engine.state.output["batch"][k]
+                save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"][k]])
 
     return tester
@@ -264,7 +273,7 @@ def run_kernel(task, config, kernel):
         print(traceback.format_exc())
     elif task == "test":
         assert config.resume_from is not None
-        test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
+        test_dataset = data.DATASET.build_with(config.data.test[config.data.test.which])
         logger.info(f"test with dataset:\n{test_dataset}")
         test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
         tester = get_tester(config, kernel)
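`data.DATASET.build_with` pairs with the `@DATASET.register_module()` decorators in the dataset file. The registry implementation is not part of this diff; a minimal sketch of the usual pattern, assuming `_type` names the registered class and the remaining keys are constructor kwargs:

class Registry:
    def __init__(self):
        self._modules = {}

    def register_module(self):
        def _wrap(cls):
            self._modules[cls.__name__] = cls
            return cls
        return _wrap

    def build_with(self, cfg):
        cfg = dict(cfg)                       # copy so pop() is safe
        cls = self._modules[cfg.pop("_type")]
        return cls(**cfg)

DATASET = Registry()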
main.py (17 lines changed)
@@ -1,16 +1,15 @@
-from pathlib import Path
-from importlib import import_module
-
-import torch
-
-import ignite
-import ignite.distributed as idist
-from ignite.utils import manual_seed
-from util.misc import setup_logger
+from pathlib import Path
+
+import fire
+import ignite
+import ignite.distributed as idist
+import torch
+from ignite.utils import manual_seed
+from omegaconf import OmegaConf
+
+from util.misc import setup_logger
 
 
 def log_basic_info(logger, config):
     logger.info(f"Train {config.name}")
@@ -28,7 +27,7 @@ def log_basic_info(logger, config):
 def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False):
     if setup_random_seed:
         manual_seed(config.misc.random_seed + idist.get_rank())
-    output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
+    output_dir = Path(config.result_dir) / config.name if config.output_dir is None else Path(config.output_dir)
     config.output_dir = str(output_dir)
 
     if setup_output_dir and config.resume_from is None:
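Wrapping `config.output_dir` in `Path` matters because downstream code joins paths with the `/` operator, which a plain string does not support:

from pathlib import Path

output_dir = "results/run1"        # str straight from the config
# output_dir / "images"            # TypeError: unsupported operand type(s)

output_dir = Path("results/run1")  # after this fix
print(output_dir / "images")       # results/run1/images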
@@ -1,6 +1,6 @@
 import logging
-from typing import Optional
 from pathlib import Path
+from typing import Optional
 
 
 def setup_logger(
@@ -8,7 +8,6 @@ def setup_logger(
     level: int = logging.INFO,
     logger_format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
     filepath: Optional[str] = None,
-    file_level: int = logging.DEBUG,
     distributed_rank: Optional[int] = None,
 ) -> logging.Logger:
     """Setups logger: name, level, format etc.
@@ -18,7 +17,6 @@ def setup_logger(
     level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG
     logger_format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`
     filepath (str, optional): Optional logging file path. If not None, logs are written to the file.
-    file_level (int): Optional logging level for logging file.
     distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger setup for workers.
         If None, distributed_rank is initialized to the rank of process.
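With `file_level` removed, file output presumably follows the logger's single `level`. A hedged sketch of calling the simplified signature (the first parameter sits above this hunk and is assumed to be the logger name; the file path is illustrative):

import logging

logger = setup_logger(
    "tester",                         # assumed name parameter
    level=logging.INFO,
    filepath="results/run1/test.log", # hypothetical path
)
logger.info("tester ready")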