update tester
This commit is contained in:
parent
7ea9c6d0d5
commit
72d09aa483
@ -114,6 +114,7 @@ data:
|
|||||||
mean: [ 0.5, 0.5, 0.5 ]
|
mean: [ 0.5, 0.5, 0.5 ]
|
||||||
std: [ 0.5, 0.5, 0.5 ]
|
std: [ 0.5, 0.5, 0.5 ]
|
||||||
test:
|
test:
|
||||||
|
which: video_dataset
|
||||||
dataloader:
|
dataloader:
|
||||||
batch_size: 8
|
batch_size: 8
|
||||||
shuffle: False
|
shuffle: False
|
||||||
|
|||||||
@ -109,6 +109,7 @@ data:
|
|||||||
mean: [ 0.5, 0.5, 0.5 ]
|
mean: [ 0.5, 0.5, 0.5 ]
|
||||||
std: [ 0.5, 0.5, 0.5 ]
|
std: [ 0.5, 0.5, 0.5 ]
|
||||||
test:
|
test:
|
||||||
|
which: video_dataset
|
||||||
dataloader:
|
dataloader:
|
||||||
batch_size: 8
|
batch_size: 8
|
||||||
shuffle: False
|
shuffle: False
|
||||||
@ -116,14 +117,11 @@ data:
|
|||||||
pin_memory: False
|
pin_memory: False
|
||||||
drop_last: False
|
drop_last: False
|
||||||
dataset:
|
dataset:
|
||||||
_type: GenerationUnpairedDatasetWithEdge
|
_type: GenerationUnpairedDataset
|
||||||
root_a: "/data/i2i/VoxCeleb2Anime/testA"
|
root_a: "/data/i2i/VoxCeleb2Anime/testA"
|
||||||
root_b: "/data/i2i/VoxCeleb2Anime/testB"
|
root_b: "/data/i2i/VoxCeleb2Anime/testB"
|
||||||
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
|
with_path: True
|
||||||
landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
|
|
||||||
edge_type: "landmark_hed"
|
|
||||||
random_pair: False
|
random_pair: False
|
||||||
size: [ 128, 128 ]
|
|
||||||
pipeline:
|
pipeline:
|
||||||
- Load
|
- Load
|
||||||
- Resize:
|
- Resize:
|
||||||
|
|||||||
@ -1,20 +1,19 @@
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from pathlib import Path
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from PIL import Image
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
from torchvision.datasets import ImageFolder
|
|
||||||
from torchvision.transforms import functional as F
|
|
||||||
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader
|
|
||||||
|
|
||||||
import lmdb
|
import lmdb
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision.datasets import ImageFolder
|
||||||
|
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
|
||||||
|
from torchvision.transforms import functional as F
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .transform import transform_pipeline
|
|
||||||
from .registry import DATASET
|
from .registry import DATASET
|
||||||
|
from .transform import transform_pipeline
|
||||||
from .util import dlib_landmark
|
from .util import dlib_landmark
|
||||||
|
|
||||||
|
|
||||||
@ -160,9 +159,9 @@ class SingleFolderDataset(Dataset):
|
|||||||
|
|
||||||
@DATASET.register_module()
|
@DATASET.register_module()
|
||||||
class GenerationUnpairedDataset(Dataset):
|
class GenerationUnpairedDataset(Dataset):
|
||||||
def __init__(self, root_a, root_b, random_pair, pipeline):
|
def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False):
|
||||||
self.A = SingleFolderDataset(root_a, pipeline)
|
self.A = SingleFolderDataset(root_a, pipeline, with_path)
|
||||||
self.B = SingleFolderDataset(root_b, pipeline)
|
self.B = SingleFolderDataset(root_b, pipeline, with_path)
|
||||||
self.random_pair = random_pair
|
self.random_pair = random_pair
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
@ -186,7 +185,8 @@ def normalize_tensor(tensor):
|
|||||||
|
|
||||||
@DATASET.register_module()
|
@DATASET.register_module()
|
||||||
class GenerationUnpairedDatasetWithEdge(Dataset):
|
class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||||
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256)):
|
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256),
|
||||||
|
with_path=False):
|
||||||
assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
|
assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
|
||||||
self.edge_type = edge_type
|
self.edge_type = edge_type
|
||||||
self.size = size
|
self.size = size
|
||||||
@ -197,6 +197,7 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
|
|||||||
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
|
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
|
||||||
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
|
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
|
||||||
self.random_pair = random_pair
|
self.random_pair = random_pair
|
||||||
|
self.with_path = with_path
|
||||||
|
|
||||||
def get_edge(self, origin_path):
|
def get_edge(self, origin_path):
|
||||||
op = Path(origin_path)
|
op = Path(origin_path)
|
||||||
@ -224,6 +225,10 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
|
|||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
a_idx = idx % len(self.A)
|
a_idx = idx % len(self.A)
|
||||||
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
||||||
|
if self.with_path:
|
||||||
|
output = {"a": self.A[a_idx], "b": self.B[b_idx]}
|
||||||
|
output["edge_a"] = output["a"][1]
|
||||||
|
return output
|
||||||
output = dict()
|
output = dict()
|
||||||
output["a"], path_a = self.A[a_idx]
|
output["a"], path_a = self.A[a_idx]
|
||||||
output["b"], path_b = self.B[b_idx]
|
output["b"], path_b = self.B[b_idx]
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
from engine.base.i2i import EngineKernel, run_kernel
|
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
|
||||||
from engine.util.build import build_model
|
from engine.util.build import build_model
|
||||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||||
from loss.gan import GANLoss
|
from loss.gan import GANLoss
|
||||||
@ -101,6 +101,31 @@ class TSITEngineKernel(EngineKernel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TSITTestEngineKernel(TestEngineKernel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
def build_generators(self) -> dict:
|
||||||
|
generators = dict(
|
||||||
|
main=build_model(self.config.model.generator)
|
||||||
|
)
|
||||||
|
return generators
|
||||||
|
|
||||||
|
def to_load(self):
|
||||||
|
return {f"generator_{k}": self.generators[k] for k in self.generators}
|
||||||
|
|
||||||
|
def inference(self, batch):
|
||||||
|
with torch.no_grad():
|
||||||
|
fake = self.generators["main"](content_img=batch["a"][0], style_img=batch["b"][0])
|
||||||
|
return {"a": fake.detach()}
|
||||||
|
|
||||||
|
|
||||||
def run(task, config, _):
|
def run(task, config, _):
|
||||||
kernel = TSITEngineKernel(config)
|
if task == "train":
|
||||||
run_kernel(task, config, kernel)
|
kernel = TSITEngineKernel(config)
|
||||||
|
run_kernel(task, config, kernel)
|
||||||
|
elif task == "test":
|
||||||
|
kernel = TSITTestEngineKernel(config)
|
||||||
|
run_kernel(task, config, kernel)
|
||||||
|
else:
|
||||||
|
raise NotImplemented
|
||||||
|
|||||||
@ -204,13 +204,21 @@ def get_trainer(config, kernel: EngineKernel):
|
|||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
|
|
||||||
|
def save_images_helper(output_dir, paths, images_list):
|
||||||
|
batch_size = len(paths)
|
||||||
|
for i in range(batch_size):
|
||||||
|
image_name = Path(paths[i]).name
|
||||||
|
img_list = [img[i] for img in images_list]
|
||||||
|
torchvision.utils.save_image(img_list, Path(output_dir) / image_name, nrow=len(img_list), padding=0,
|
||||||
|
normalize=True, range=(-1, 1))
|
||||||
|
|
||||||
|
|
||||||
def get_tester(config, kernel: TestEngineKernel):
|
def get_tester(config, kernel: TestEngineKernel):
|
||||||
logger = logging.getLogger(config.name)
|
logger = logging.getLogger(config.name)
|
||||||
|
|
||||||
def _step(engine, batch):
|
def _step(engine, batch):
|
||||||
real_a, path = convert_tensor(batch, idist.device())
|
batch = convert_tensor(batch, idist.device())
|
||||||
fake = kernel.inference({"a": real_a})["a"]
|
return {"batch": batch, "generated": kernel.inference(batch)}
|
||||||
return {"path": path, "img": [real_a.detach(), fake.detach()]}
|
|
||||||
|
|
||||||
tester = Engine(_step)
|
tester = Engine(_step)
|
||||||
tester.logger = logger
|
tester.logger = logger
|
||||||
@ -227,13 +235,14 @@ def get_tester(config, kernel: TestEngineKernel):
|
|||||||
|
|
||||||
@tester.on(Events.ITERATION_COMPLETED)
|
@tester.on(Events.ITERATION_COMPLETED)
|
||||||
def save_images(engine):
|
def save_images(engine):
|
||||||
img_tensors = engine.state.output["img"]
|
if engine.state.dataloader.dataset.__class__.__name__ == "SingleFolderDataset":
|
||||||
paths = engine.state.output["path"]
|
images, paths = engine.state.output["batch"]
|
||||||
batch_size = img_tensors[0].size(0)
|
save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"]])
|
||||||
for i in range(batch_size):
|
|
||||||
image_name = Path(paths[i]).name
|
else:
|
||||||
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
|
for k in engine.state.output['generated']:
|
||||||
nrow=len(img_tensors), padding=0, normalize=True, range=(-1, 1))
|
images, paths = engine.state.output["batch"][k]
|
||||||
|
save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"][k]])
|
||||||
|
|
||||||
return tester
|
return tester
|
||||||
|
|
||||||
@ -264,7 +273,7 @@ def run_kernel(task, config, kernel):
|
|||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
elif task == "test":
|
elif task == "test":
|
||||||
assert config.resume_from is not None
|
assert config.resume_from is not None
|
||||||
test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
|
test_dataset = data.DATASET.build_with(config.data.test[config.data.test.which])
|
||||||
logger.info(f"test with dataset:\n{test_dataset}")
|
logger.info(f"test with dataset:\n{test_dataset}")
|
||||||
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
|
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
|
||||||
tester = get_tester(config, kernel)
|
tester = get_tester(config, kernel)
|
||||||
|
|||||||
17
main.py
17
main.py
@ -1,16 +1,15 @@
|
|||||||
from pathlib import Path
|
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
from pathlib import Path
|
||||||
import torch
|
|
||||||
|
|
||||||
import ignite
|
|
||||||
import ignite.distributed as idist
|
|
||||||
from ignite.utils import manual_seed
|
|
||||||
from util.misc import setup_logger
|
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
|
import ignite
|
||||||
|
import ignite.distributed as idist
|
||||||
|
import torch
|
||||||
|
from ignite.utils import manual_seed
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
from util.misc import setup_logger
|
||||||
|
|
||||||
|
|
||||||
def log_basic_info(logger, config):
|
def log_basic_info(logger, config):
|
||||||
logger.info(f"Train {config.name}")
|
logger.info(f"Train {config.name}")
|
||||||
@ -28,7 +27,7 @@ def log_basic_info(logger, config):
|
|||||||
def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False):
|
def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False):
|
||||||
if setup_random_seed:
|
if setup_random_seed:
|
||||||
manual_seed(config.misc.random_seed + idist.get_rank())
|
manual_seed(config.misc.random_seed + idist.get_rank())
|
||||||
output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
|
output_dir = Path(config.result_dir) / config.name if config.output_dir is None else Path(config.output_dir)
|
||||||
config.output_dir = str(output_dir)
|
config.output_dir = str(output_dir)
|
||||||
|
|
||||||
if setup_output_dir and config.resume_from is None:
|
if setup_output_dir and config.resume_from is None:
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
def setup_logger(
|
def setup_logger(
|
||||||
@ -8,7 +8,6 @@ def setup_logger(
|
|||||||
level: int = logging.INFO,
|
level: int = logging.INFO,
|
||||||
logger_format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
|
logger_format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
|
||||||
filepath: Optional[str] = None,
|
filepath: Optional[str] = None,
|
||||||
file_level: int = logging.DEBUG,
|
|
||||||
distributed_rank: Optional[int] = None,
|
distributed_rank: Optional[int] = None,
|
||||||
) -> logging.Logger:
|
) -> logging.Logger:
|
||||||
"""Setups logger: name, level, format etc.
|
"""Setups logger: name, level, format etc.
|
||||||
@ -18,7 +17,6 @@ def setup_logger(
|
|||||||
level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG
|
level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG
|
||||||
logger_format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`
|
logger_format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`
|
||||||
filepath (str, optional): Optional logging file path. If not None, logs are written to the file.
|
filepath (str, optional): Optional logging file path. If not None, logs are written to the file.
|
||||||
file_level (int): Optional logging level for logging file.
|
|
||||||
distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger setup for workers.
|
distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger setup for workers.
|
||||||
If None, distributed_rank is initialized to the rank of process.
|
If None, distributed_rank is initialized to the rank of process.
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user