TAFG 0.01
This commit is contained in:
parent
14d4247112
commit
2469bf15fe
@ -49,7 +49,7 @@ loss:
|
|||||||
criterion: 'L1'
|
criterion: 'L1'
|
||||||
style_loss: False
|
style_loss: False
|
||||||
perceptual_loss: True
|
perceptual_loss: True
|
||||||
weight: 1
|
weight: 5
|
||||||
style:
|
style:
|
||||||
layer_weights:
|
layer_weights:
|
||||||
"1": 0.03125
|
"1": 0.03125
|
||||||
@ -63,10 +63,10 @@ loss:
|
|||||||
weight: 0
|
weight: 0
|
||||||
fm:
|
fm:
|
||||||
level: 1
|
level: 1
|
||||||
weight: 1
|
weight: 10
|
||||||
recon:
|
recon:
|
||||||
level: 1
|
level: 1
|
||||||
weight: 1
|
weight: 5
|
||||||
|
|
||||||
optimizers:
|
optimizers:
|
||||||
generator:
|
generator:
|
||||||
@ -97,7 +97,8 @@ data:
|
|||||||
root_a: "/data/i2i/VoxCeleb2Anime/trainA"
|
root_a: "/data/i2i/VoxCeleb2Anime/trainA"
|
||||||
root_b: "/data/i2i/VoxCeleb2Anime/trainB"
|
root_b: "/data/i2i/VoxCeleb2Anime/trainB"
|
||||||
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
|
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
|
||||||
edge_type: "hed"
|
landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
|
||||||
|
edge_type: "landmark_canny"
|
||||||
size: [128, 128]
|
size: [128, 128]
|
||||||
random_pair: True
|
random_pair: True
|
||||||
pipeline:
|
pipeline:
|
||||||
|
|||||||
@ -1,129 +0,0 @@
|
|||||||
name: TAHG
|
|
||||||
engine: TAHG
|
|
||||||
result_dir: ./result
|
|
||||||
max_pairs: 1000000
|
|
||||||
|
|
||||||
misc:
|
|
||||||
random_seed: 324
|
|
||||||
|
|
||||||
checkpoint:
|
|
||||||
epoch_interval: 1 # one checkpoint every 1 epoch
|
|
||||||
n_saved: 2
|
|
||||||
|
|
||||||
interval:
|
|
||||||
print_per_iteration: 10 # print once per 10 iteration
|
|
||||||
tensorboard:
|
|
||||||
scalar: 100
|
|
||||||
image: 2
|
|
||||||
|
|
||||||
model:
|
|
||||||
generator:
|
|
||||||
_type: TAHG-Generator
|
|
||||||
_bn_to_sync_bn: False
|
|
||||||
style_in_channels: 3
|
|
||||||
content_in_channels: 1
|
|
||||||
num_blocks: 4
|
|
||||||
discriminator:
|
|
||||||
_type: TAHG-Discriminator
|
|
||||||
in_channels: 3
|
|
||||||
|
|
||||||
loss:
|
|
||||||
gan:
|
|
||||||
loss_type: lsgan
|
|
||||||
real_label_val: 1.0
|
|
||||||
fake_label_val: 0.0
|
|
||||||
weight: 1.0
|
|
||||||
edge:
|
|
||||||
criterion: 'L1'
|
|
||||||
hed_pretrained_model_path: "./network-bsds500.pytorch"
|
|
||||||
weight: 1
|
|
||||||
perceptual:
|
|
||||||
layer_weights:
|
|
||||||
"3": 1.0
|
|
||||||
# "0": 1.0
|
|
||||||
# "5": 1.0
|
|
||||||
# "10": 1.0
|
|
||||||
# "19": 1.0
|
|
||||||
criterion: 'L2'
|
|
||||||
style_loss: True
|
|
||||||
perceptual_loss: False
|
|
||||||
weight: 20
|
|
||||||
recon:
|
|
||||||
level: 1
|
|
||||||
weight: 1
|
|
||||||
|
|
||||||
optimizers:
|
|
||||||
generator:
|
|
||||||
_type: Adam
|
|
||||||
lr: 0.0001
|
|
||||||
betas: [ 0.5, 0.999 ]
|
|
||||||
weight_decay: 0.0001
|
|
||||||
discriminator:
|
|
||||||
_type: Adam
|
|
||||||
lr: 1e-4
|
|
||||||
betas: [ 0.5, 0.999 ]
|
|
||||||
weight_decay: 0.0001
|
|
||||||
|
|
||||||
data:
|
|
||||||
train:
|
|
||||||
scheduler:
|
|
||||||
start_proportion: 0.5
|
|
||||||
target_lr: 0
|
|
||||||
buffer_size: 50
|
|
||||||
dataloader:
|
|
||||||
batch_size: 160
|
|
||||||
shuffle: True
|
|
||||||
num_workers: 2
|
|
||||||
pin_memory: True
|
|
||||||
drop_last: True
|
|
||||||
dataset:
|
|
||||||
_type: GenerationUnpairedDatasetWithEdge
|
|
||||||
root_a: "/data/i2i/VoxCeleb2Anime/trainA"
|
|
||||||
root_b: "/data/i2i/VoxCeleb2Anime/trainB"
|
|
||||||
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
|
|
||||||
edge_type: "hed"
|
|
||||||
size: [128, 128]
|
|
||||||
random_pair: True
|
|
||||||
pipeline:
|
|
||||||
- Load
|
|
||||||
- Resize:
|
|
||||||
size: [128, 128]
|
|
||||||
- ToTensor
|
|
||||||
- Normalize:
|
|
||||||
mean: [ 0.5, 0.5, 0.5 ]
|
|
||||||
std: [ 0.5, 0.5, 0.5 ]
|
|
||||||
test:
|
|
||||||
dataloader:
|
|
||||||
batch_size: 8
|
|
||||||
shuffle: False
|
|
||||||
num_workers: 1
|
|
||||||
pin_memory: False
|
|
||||||
drop_last: False
|
|
||||||
dataset:
|
|
||||||
_type: GenerationUnpairedDatasetWithEdge
|
|
||||||
root_a: "/data/i2i/VoxCeleb2Anime/testA"
|
|
||||||
root_b: "/data/i2i/VoxCeleb2Anime/testB"
|
|
||||||
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
|
|
||||||
edge_type: "hed"
|
|
||||||
random_pair: False
|
|
||||||
size: [128, 128]
|
|
||||||
pipeline:
|
|
||||||
- Load
|
|
||||||
- Resize:
|
|
||||||
size: [128, 128]
|
|
||||||
- ToTensor
|
|
||||||
- Normalize:
|
|
||||||
mean: [ 0.5, 0.5, 0.5 ]
|
|
||||||
std: [ 0.5, 0.5, 0.5 ]
|
|
||||||
video_dataset:
|
|
||||||
_type: SingleFolderDataset
|
|
||||||
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
|
|
||||||
with_path: True
|
|
||||||
pipeline:
|
|
||||||
- Load
|
|
||||||
- Resize:
|
|
||||||
size: [ 256, 256 ]
|
|
||||||
- ToTensor
|
|
||||||
- Normalize:
|
|
||||||
mean: [ 0.5, 0.5, 0.5 ]
|
|
||||||
std: [ 0.5, 0.5, 0.5 ]
|
|
||||||
@ -15,6 +15,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from .transform import transform_pipeline
|
from .transform import transform_pipeline
|
||||||
from .registry import DATASET
|
from .registry import DATASET
|
||||||
|
from .util import dlib_landmark
|
||||||
|
|
||||||
|
|
||||||
def default_transform_way(transform, sample):
|
def default_transform_way(transform, sample):
|
||||||
@ -178,20 +179,38 @@ class GenerationUnpairedDataset(Dataset):
|
|||||||
|
|
||||||
@DATASET.register_module()
|
@DATASET.register_module()
|
||||||
class GenerationUnpairedDatasetWithEdge(Dataset):
|
class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||||
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, size=(256, 256)):
|
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256)):
|
||||||
|
assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
|
||||||
self.edge_type = edge_type
|
self.edge_type = edge_type
|
||||||
self.size = size
|
self.size = size
|
||||||
self.edges_path = Path(edges_path)
|
self.edges_path = Path(edges_path)
|
||||||
|
self.landmarks_path = Path(landmarks_path)
|
||||||
assert self.edges_path.exists()
|
assert self.edges_path.exists()
|
||||||
|
assert self.landmarks_path.exists()
|
||||||
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
|
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
|
||||||
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
|
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
|
||||||
self.random_pair = random_pair
|
self.random_pair = random_pair
|
||||||
|
|
||||||
def get_edge(self, origin_path):
|
def get_edge(self, origin_path):
|
||||||
op = Path(origin_path)
|
op = Path(origin_path)
|
||||||
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{self.edge_type}.png"
|
if self.edge_type.startswith("landmark_"):
|
||||||
img = Image.open(edge_path).resize(self.size)
|
edge_type = self.edge_type.lstrip("landmark_")
|
||||||
return F.to_tensor(img)
|
use_landmark = True
|
||||||
|
else:
|
||||||
|
edge_type = self.edge_type
|
||||||
|
use_landmark = False
|
||||||
|
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
|
||||||
|
origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size))
|
||||||
|
if not use_landmark:
|
||||||
|
return origin_edge
|
||||||
|
else:
|
||||||
|
landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.{edge_type}.txt"
|
||||||
|
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)
|
||||||
|
dist_tensor = torch.from_numpy(dlib_landmark.dist_tensor(key_points))
|
||||||
|
part_labels = torch.from_numpy(part_labels)
|
||||||
|
edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face
|
||||||
|
edges = part_edge + edges
|
||||||
|
return torch.cat([edges, dist_tensor, part_labels], dim=0)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
a_idx = idx % len(self.A)
|
a_idx = idx % len(self.A)
|
||||||
@ -200,7 +219,6 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
|
|||||||
output["a"], path_a = self.A[a_idx]
|
output["a"], path_a = self.A[a_idx]
|
||||||
output["b"], path_b = self.B[b_idx]
|
output["b"], path_b = self.B[b_idx]
|
||||||
output["edge_a"] = self.get_edge(path_a)
|
output["edge_a"] = self.get_edge(path_a)
|
||||||
output["edge_b"] = self.get_edge(path_b)
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
|||||||
@ -65,12 +65,14 @@ class TAFGEngineKernel(EngineKernel):
|
|||||||
|
|
||||||
def criterion_generators(self, batch, generated) -> dict:
|
def criterion_generators(self, batch, generated) -> dict:
|
||||||
loss = dict()
|
loss = dict()
|
||||||
loss["perceptual"], _, = self.perceptual_loss(generated["b"], batch["b"]) * self.config.loss.perceptual.weight
|
loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
|
||||||
|
loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
|
||||||
for phase in "ab":
|
for phase in "ab":
|
||||||
pred_fake = self.discriminators[phase](generated[phase])
|
pred_fake = self.discriminators[phase](generated[phase])
|
||||||
for i, sub_pred_fake in enumerate(pred_fake):
|
loss[f"gan_{phase}"] = 0
|
||||||
|
for sub_pred_fake in pred_fake:
|
||||||
# last output is actual prediction
|
# last output is actual prediction
|
||||||
loss[f"gan_{phase}_sub_{i}"] = self.gan_loss(sub_pred_fake[-1], True)
|
loss[f"gan_{phase}"] += self.gan_loss(sub_pred_fake[-1], True)
|
||||||
|
|
||||||
if self.config.loss.fm.weight > 0 and phase == "b":
|
if self.config.loss.fm.weight > 0 and phase == "b":
|
||||||
pred_real = self.discriminators[phase](batch[phase])
|
pred_real = self.discriminators[phase](batch[phase])
|
||||||
|
|||||||
245
engine/TAHG.py
245
engine/TAHG.py
@ -1,245 +0,0 @@
|
|||||||
from itertools import chain
|
|
||||||
from math import ceil
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torchvision
|
|
||||||
|
|
||||||
import ignite.distributed as idist
|
|
||||||
from ignite.engine import Events, Engine
|
|
||||||
from ignite.metrics import RunningAverage
|
|
||||||
from ignite.utils import convert_tensor
|
|
||||||
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
|
|
||||||
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
|
|
||||||
|
|
||||||
from omegaconf import OmegaConf, read_write
|
|
||||||
|
|
||||||
import data
|
|
||||||
from loss.gan import GANLoss
|
|
||||||
from model.weight_init import generation_init_weights
|
|
||||||
from model.GAN.residual_generator import GANImageBuffer
|
|
||||||
from loss.I2I.edge_loss import EdgeLoss
|
|
||||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
|
||||||
from util.image import make_2d_grid
|
|
||||||
from util.handler import setup_common_handlers, setup_tensorboard_handler
|
|
||||||
from util.build import build_model, build_optimizer
|
|
||||||
|
|
||||||
|
|
||||||
def build_lr_schedulers(optimizers, config):
|
|
||||||
g_milestones_values = [
|
|
||||||
(0, config.optimizers.generator.lr),
|
|
||||||
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
|
|
||||||
(config.max_iteration, config.data.train.scheduler.target_lr)
|
|
||||||
]
|
|
||||||
d_milestones_values = [
|
|
||||||
(0, config.optimizers.discriminator.lr),
|
|
||||||
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
|
|
||||||
(config.max_iteration, config.data.train.scheduler.target_lr)
|
|
||||||
]
|
|
||||||
return dict(
|
|
||||||
g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
|
|
||||||
d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_trainer(config, logger, train_data_loader):
|
|
||||||
generator = build_model(config.model.generator, config.distributed.model)
|
|
||||||
discriminators = dict(
|
|
||||||
a=build_model(config.model.discriminator, config.distributed.model),
|
|
||||||
b=build_model(config.model.discriminator, config.distributed.model),
|
|
||||||
)
|
|
||||||
generation_init_weights(generator)
|
|
||||||
for m in discriminators.values():
|
|
||||||
generation_init_weights(m)
|
|
||||||
|
|
||||||
logger.debug(discriminators["a"])
|
|
||||||
logger.debug(generator)
|
|
||||||
|
|
||||||
optimizers = dict(
|
|
||||||
g=build_optimizer(generator.parameters(), config.optimizers.generator),
|
|
||||||
d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
|
|
||||||
)
|
|
||||||
logger.info(f"build optimizers:\n{optimizers}")
|
|
||||||
|
|
||||||
lr_schedulers = build_lr_schedulers(optimizers, config)
|
|
||||||
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
|
|
||||||
|
|
||||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
|
||||||
gan_loss_cfg.pop("weight")
|
|
||||||
gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
|
||||||
|
|
||||||
edge_loss_cfg = OmegaConf.to_container(config.loss.edge)
|
|
||||||
edge_loss_cfg.pop("weight")
|
|
||||||
edge_loss = EdgeLoss(**edge_loss_cfg).to(idist.device())
|
|
||||||
|
|
||||||
perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
|
|
||||||
perceptual_loss_cfg.pop("weight")
|
|
||||||
perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
|
|
||||||
|
|
||||||
recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
|
|
||||||
|
|
||||||
image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()}
|
|
||||||
|
|
||||||
def _step(engine, batch):
|
|
||||||
batch = convert_tensor(batch, idist.device())
|
|
||||||
real = dict(a=batch["a"], b=batch["b"])
|
|
||||||
fake = dict(
|
|
||||||
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
|
|
||||||
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
|
|
||||||
)
|
|
||||||
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
|
|
||||||
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
|
|
||||||
|
|
||||||
optimizers["g"].zero_grad()
|
|
||||||
loss_g = dict()
|
|
||||||
for d in "ab":
|
|
||||||
discriminators[d].requires_grad_(False)
|
|
||||||
pred_fake = discriminators[d](fake[d])
|
|
||||||
loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True)
|
|
||||||
_, t = perceptual_loss(fake[d], real[d])
|
|
||||||
loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t
|
|
||||||
loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], batch["edge_a"])
|
|
||||||
loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"])
|
|
||||||
loss_g["recon_b"] = config.loss.recon.weight * recon_loss(rec_b, real["b"])
|
|
||||||
loss_g["recon_bb"] = config.loss.recon.weight * recon_loss(rec_bb, real["b"])
|
|
||||||
sum(loss_g.values()).backward()
|
|
||||||
optimizers["g"].step()
|
|
||||||
|
|
||||||
for discriminator in discriminators.values():
|
|
||||||
discriminator.requires_grad_(True)
|
|
||||||
|
|
||||||
optimizers["d"].zero_grad()
|
|
||||||
loss_d = dict()
|
|
||||||
for k in discriminators.keys():
|
|
||||||
pred_real = discriminators[k](real[k])
|
|
||||||
pred_fake = discriminators[k](image_buffers[k].query(fake[k].detach()))
|
|
||||||
loss_d[f"gan_{k}"] = (gan_loss(pred_real, True, is_discriminator=True) +
|
|
||||||
gan_loss(pred_fake, False, is_discriminator=True)) / 2
|
|
||||||
sum(loss_d.values()).backward()
|
|
||||||
optimizers["d"].step()
|
|
||||||
|
|
||||||
generated_img = {f"real_{k}": real[k].detach() for k in real}
|
|
||||||
generated_img["rec_b"] = rec_b.detach()
|
|
||||||
generated_img["rec_bb"] = rec_b.detach()
|
|
||||||
generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
|
|
||||||
generated_img.update({f"edge_{k}": batch[f"edge_{k}"].expand(-1, 3, -1, -1).detach() for k in "ab"})
|
|
||||||
return {
|
|
||||||
"loss": {
|
|
||||||
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
|
|
||||||
"d": {ln: loss_d[ln].mean().item() for ln in loss_d},
|
|
||||||
},
|
|
||||||
"img": generated_img
|
|
||||||
}
|
|
||||||
|
|
||||||
trainer = Engine(_step)
|
|
||||||
trainer.logger = logger
|
|
||||||
for lr_shd in lr_schedulers.values():
|
|
||||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
|
|
||||||
|
|
||||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
|
|
||||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
|
|
||||||
|
|
||||||
to_save = dict(trainer=trainer)
|
|
||||||
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
|
|
||||||
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
|
|
||||||
to_save.update({"generator": generator})
|
|
||||||
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
|
|
||||||
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
|
|
||||||
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
|
|
||||||
|
|
||||||
def output_transform(output):
|
|
||||||
loss = dict()
|
|
||||||
for tl in output["loss"]:
|
|
||||||
if isinstance(output["loss"][tl], dict):
|
|
||||||
for l in output["loss"][tl]:
|
|
||||||
loss[f"{tl}_{l}"] = output["loss"][tl][l]
|
|
||||||
else:
|
|
||||||
loss[tl] = output["loss"][tl]
|
|
||||||
return loss
|
|
||||||
|
|
||||||
iter_per_epoch = len(train_data_loader)
|
|
||||||
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
|
|
||||||
if tensorboard_handler is not None:
|
|
||||||
tensorboard_handler.attach(
|
|
||||||
trainer,
|
|
||||||
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
|
|
||||||
event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
|
|
||||||
)
|
|
||||||
|
|
||||||
@trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
|
|
||||||
def show_images(engine):
|
|
||||||
output = engine.state.output
|
|
||||||
image_order = dict(
|
|
||||||
a=["edge_a", "real_a", "fake_a", "fake_b"],
|
|
||||||
b=["edge_b", "real_b", "rec_b", "rec_bb"]
|
|
||||||
)
|
|
||||||
for k in "ab":
|
|
||||||
tensorboard_handler.writer.add_image(
|
|
||||||
f"train/{k}",
|
|
||||||
make_2d_grid([output["img"][o] for o in image_order[k]]),
|
|
||||||
engine.state.iteration
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
g = torch.Generator()
|
|
||||||
g.manual_seed(config.misc.random_seed)
|
|
||||||
random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
|
|
||||||
test_images = dict(
|
|
||||||
a=[[], [], [], []],
|
|
||||||
b=[[], [], [], []]
|
|
||||||
)
|
|
||||||
for i in range(random_start, random_start + 10):
|
|
||||||
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
|
||||||
for k in batch:
|
|
||||||
batch[k] = batch[k].view(1, *batch[k].size())
|
|
||||||
|
|
||||||
real = dict(a=batch["a"], b=batch["b"])
|
|
||||||
fake = dict(
|
|
||||||
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
|
|
||||||
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
|
|
||||||
)
|
|
||||||
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
|
|
||||||
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
|
|
||||||
|
|
||||||
test_images["a"][0].append(batch["edge_a"])
|
|
||||||
test_images["a"][1].append(batch["a"])
|
|
||||||
test_images["a"][2].append(fake["a"])
|
|
||||||
test_images["a"][3].append(fake["b"])
|
|
||||||
test_images["b"][0].append(batch["edge_b"])
|
|
||||||
test_images["b"][1].append(batch["b"])
|
|
||||||
test_images["b"][2].append(rec_b)
|
|
||||||
test_images["b"][3].append(rec_bb)
|
|
||||||
for n in "ab":
|
|
||||||
tensorboard_handler.writer.add_image(
|
|
||||||
f"test/{n}",
|
|
||||||
make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
|
|
||||||
engine.state.iteration
|
|
||||||
)
|
|
||||||
|
|
||||||
return trainer
|
|
||||||
|
|
||||||
|
|
||||||
def run(task, config, logger):
|
|
||||||
assert torch.backends.cudnn.enabled
|
|
||||||
torch.backends.cudnn.benchmark = True
|
|
||||||
logger.info(f"start task {task}")
|
|
||||||
with read_write(config):
|
|
||||||
config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
|
|
||||||
|
|
||||||
if task == "train":
|
|
||||||
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
|
||||||
logger.info(f"train with dataset:\n{train_dataset}")
|
|
||||||
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
|
||||||
trainer = get_trainer(config, logger, train_data_loader)
|
|
||||||
if idist.get_rank() == 0:
|
|
||||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
|
||||||
trainer.state.test_dataset = test_dataset
|
|
||||||
try:
|
|
||||||
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
|
|
||||||
except Exception:
|
|
||||||
import traceback
|
|
||||||
print(traceback.format_exc())
|
|
||||||
else:
|
|
||||||
return NotImplemented(f"invalid task: {task}")
|
|
||||||
@ -145,6 +145,7 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
|
|||||||
loss[tl] = output["loss"][tl]
|
loss[tl] = output["loss"][tl]
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
pairs_per_iteration = config.data.train.dataloader.batch_size
|
||||||
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
|
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
|
||||||
if tensorboard_handler is not None:
|
if tensorboard_handler is not None:
|
||||||
tensorboard_handler.attach(
|
tensorboard_handler.attach(
|
||||||
@ -159,7 +160,8 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
|
|||||||
test_images = {}
|
test_images = {}
|
||||||
for k in output["img"]:
|
for k in output["img"]:
|
||||||
image_list = output["img"][k]
|
image_list = output["img"][k]
|
||||||
tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list), engine.state.iteration)
|
tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list),
|
||||||
|
engine.state.iteration * pairs_per_iteration)
|
||||||
test_images[k] = []
|
test_images[k] = []
|
||||||
for i in range(len(image_list)):
|
for i in range(len(image_list)):
|
||||||
test_images[k].append([])
|
test_images[k].append([])
|
||||||
@ -182,6 +184,6 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
|
|||||||
tensorboard_handler.writer.add_image(
|
tensorboard_handler.writer.add_image(
|
||||||
f"test/{k}",
|
f"test/{k}",
|
||||||
make_2d_grid([torch.cat(ti) for ti in test_images[k]]),
|
make_2d_grid([torch.cat(ti) for ti in test_images[k]]),
|
||||||
engine.state.iteration
|
engine.state.iteration * pairs_per_iteration
|
||||||
)
|
)
|
||||||
return trainer
|
return trainer
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user