Compare commits
3 Commits
7a85499edf
...
e71e8d95d0
| Author | SHA1 | Date | |
|---|---|---|---|
| e71e8d95d0 | |||
| 89b54105c7 | |||
| 715a2e64a1 |
@ -1,6 +1,6 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="PublishConfigData" autoUpload="Always" serverName="15d" remoteFilesAllowedToDisappearOnAutoupload="false">
|
<component name="PublishConfigData" autoUpload="Always" serverName="22d" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||||
<serverData>
|
<serverData>
|
||||||
<paths name="14d">
|
<paths name="14d">
|
||||||
<serverdata>
|
<serverdata>
|
||||||
@ -16,6 +16,13 @@
|
|||||||
</mappings>
|
</mappings>
|
||||||
</serverdata>
|
</serverdata>
|
||||||
</paths>
|
</paths>
|
||||||
|
<paths name="22d">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping deploy="/raycv" local="$PROJECT_DIR$" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
</serverData>
|
</serverData>
|
||||||
<option name="myAutoUpload" value="ALWAYS" />
|
<option name="myAutoUpload" value="ALWAYS" />
|
||||||
</component>
|
</component>
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="15d-python" project-jdk-type="Python SDK" />
|
<component name="ProjectRootManager" version="2" project-jdk-name="22d-base" project-jdk-type="Python SDK" />
|
||||||
</project>
|
</project>
|
||||||
6
.idea/other.xml
Normal file
6
.idea/other.xml
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="PySciProjectComponent">
|
||||||
|
<option name="PY_SCI_VIEW_SUGGESTED" value="true" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
@ -2,7 +2,7 @@
|
|||||||
<module type="PYTHON_MODULE" version="4">
|
<module type="PYTHON_MODULE" version="4">
|
||||||
<component name="NewModuleRootManager">
|
<component name="NewModuleRootManager">
|
||||||
<content url="file://$MODULE_DIR$" />
|
<content url="file://$MODULE_DIR$" />
|
||||||
<orderEntry type="jdk" jdkName="15d-python" jdkType="Python SDK" />
|
<orderEntry type="jdk" jdkName="22d-base" jdkType="Python SDK" />
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
</component>
|
</component>
|
||||||
<component name="TestRunnerService">
|
<component name="TestRunnerService">
|
||||||
|
|||||||
132
configs/synthesizers/TAHG.yml
Normal file
132
configs/synthesizers/TAHG.yml
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
name: TAHG
|
||||||
|
engine: TAHG
|
||||||
|
result_dir: ./result
|
||||||
|
max_pairs: 1000000
|
||||||
|
|
||||||
|
distributed:
|
||||||
|
model:
|
||||||
|
# broadcast_buffers: False
|
||||||
|
|
||||||
|
misc:
|
||||||
|
random_seed: 324
|
||||||
|
|
||||||
|
checkpoint:
|
||||||
|
epoch_interval: 1 # one checkpoint every 1 epoch
|
||||||
|
n_saved: 2
|
||||||
|
|
||||||
|
interval:
|
||||||
|
print_per_iteration: 10 # print once per 10 iteration
|
||||||
|
tensorboard:
|
||||||
|
scalar: 100
|
||||||
|
image: 2
|
||||||
|
|
||||||
|
model:
|
||||||
|
generator:
|
||||||
|
_type: TAHG-Generator
|
||||||
|
style_in_channels: 3
|
||||||
|
content_in_channels: 1
|
||||||
|
num_blocks: 4
|
||||||
|
discriminator:
|
||||||
|
_type: TAHG-Discriminator
|
||||||
|
in_channels: 3
|
||||||
|
|
||||||
|
loss:
|
||||||
|
gan:
|
||||||
|
loss_type: lsgan
|
||||||
|
real_label_val: 1.0
|
||||||
|
fake_label_val: 0.0
|
||||||
|
weight: 1.0
|
||||||
|
edge:
|
||||||
|
criterion: 'L1'
|
||||||
|
hed_pretrained_model_path: "./network-bsds500.pytorch"
|
||||||
|
weight: 1
|
||||||
|
perceptual:
|
||||||
|
layer_weights:
|
||||||
|
"3": 1.0
|
||||||
|
# "0": 1.0
|
||||||
|
# "5": 1.0
|
||||||
|
# "10": 1.0
|
||||||
|
# "19": 1.0
|
||||||
|
criterion: 'L2'
|
||||||
|
style_loss: True
|
||||||
|
perceptual_loss: False
|
||||||
|
weight: 20
|
||||||
|
recon:
|
||||||
|
level: 1
|
||||||
|
weight: 1
|
||||||
|
|
||||||
|
optimizers:
|
||||||
|
generator:
|
||||||
|
_type: Adam
|
||||||
|
lr: 0.0001
|
||||||
|
betas: [ 0.5, 0.999 ]
|
||||||
|
weight_decay: 0.0001
|
||||||
|
discriminator:
|
||||||
|
_type: Adam
|
||||||
|
lr: 1e-4
|
||||||
|
betas: [ 0.5, 0.999 ]
|
||||||
|
weight_decay: 0.0001
|
||||||
|
|
||||||
|
data:
|
||||||
|
train:
|
||||||
|
scheduler:
|
||||||
|
start_proportion: 0.5
|
||||||
|
target_lr: 0
|
||||||
|
buffer_size: 50
|
||||||
|
dataloader:
|
||||||
|
batch_size: 160
|
||||||
|
shuffle: True
|
||||||
|
num_workers: 2
|
||||||
|
pin_memory: True
|
||||||
|
drop_last: True
|
||||||
|
dataset:
|
||||||
|
_type: GenerationUnpairedDatasetWithEdge
|
||||||
|
root_a: "/data/i2i/VoxCeleb2Anime/trainA"
|
||||||
|
root_b: "/data/i2i/VoxCeleb2Anime/trainB"
|
||||||
|
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
|
||||||
|
edge_type: "hed"
|
||||||
|
size: [128, 128]
|
||||||
|
random_pair: True
|
||||||
|
pipeline:
|
||||||
|
- Load
|
||||||
|
- Resize:
|
||||||
|
size: [128, 128]
|
||||||
|
- ToTensor
|
||||||
|
- Normalize:
|
||||||
|
mean: [ 0.5, 0.5, 0.5 ]
|
||||||
|
std: [ 0.5, 0.5, 0.5 ]
|
||||||
|
test:
|
||||||
|
dataloader:
|
||||||
|
batch_size: 8
|
||||||
|
shuffle: False
|
||||||
|
num_workers: 1
|
||||||
|
pin_memory: False
|
||||||
|
drop_last: False
|
||||||
|
dataset:
|
||||||
|
_type: GenerationUnpairedDatasetWithEdge
|
||||||
|
root_a: "/data/i2i/VoxCeleb2Anime/testA"
|
||||||
|
root_b: "/data/i2i/VoxCeleb2Anime/testB"
|
||||||
|
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
|
||||||
|
edge_type: "hed"
|
||||||
|
random_pair: False
|
||||||
|
size: [128, 128]
|
||||||
|
pipeline:
|
||||||
|
- Load
|
||||||
|
- Resize:
|
||||||
|
size: [128, 128]
|
||||||
|
- ToTensor
|
||||||
|
- Normalize:
|
||||||
|
mean: [ 0.5, 0.5, 0.5 ]
|
||||||
|
std: [ 0.5, 0.5, 0.5 ]
|
||||||
|
video_dataset:
|
||||||
|
_type: SingleFolderDataset
|
||||||
|
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
|
||||||
|
with_path: True
|
||||||
|
pipeline:
|
||||||
|
- Load
|
||||||
|
- Resize:
|
||||||
|
size: [ 256, 256 ]
|
||||||
|
- ToTensor
|
||||||
|
- Normalize:
|
||||||
|
mean: [ 0.5, 0.5, 0.5 ]
|
||||||
|
std: [ 0.5, 0.5, 0.5 ]
|
||||||
@ -1,10 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from torchvision.datasets import ImageFolder
|
from torchvision.datasets import ImageFolder
|
||||||
|
from torchvision.transforms import functional as F
|
||||||
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader
|
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader
|
||||||
|
|
||||||
import lmdb
|
import lmdb
|
||||||
@ -171,3 +174,37 @@ class GenerationUnpairedDataset(Dataset):
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET.register_module()
|
||||||
|
class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||||
|
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, size=(256, 256)):
|
||||||
|
self.edge_type = edge_type
|
||||||
|
self.size = size
|
||||||
|
self.edges_path = Path(edges_path)
|
||||||
|
assert self.edges_path.exists()
|
||||||
|
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
|
||||||
|
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
|
||||||
|
self.random_pair = random_pair
|
||||||
|
|
||||||
|
def get_edge(self, origin_path):
|
||||||
|
op = Path(origin_path)
|
||||||
|
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{self.edge_type}.png"
|
||||||
|
img = Image.open(edge_path).resize(self.size)
|
||||||
|
return F.to_tensor(img)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
a_idx = idx % len(self.A)
|
||||||
|
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
||||||
|
output = dict()
|
||||||
|
output["a"], path_a = self.A[a_idx]
|
||||||
|
output["b"], path_b = self.B[b_idx]
|
||||||
|
output["edge_a"] = self.get_edge(path_a)
|
||||||
|
output["edge_b"] = self.get_edge(path_b)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return max(len(self.A), len(self.B))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
||||||
|
|||||||
0
data/util/__init__.py
Normal file
0
data/util/__init__.py
Normal file
67
data/util/dlib_landmark.py
Normal file
67
data/util/dlib_landmark.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from skimage import feature
|
||||||
|
|
||||||
|
# https://ibug.doc.ic.ac.uk/resources/facial-point-annotations/
|
||||||
|
DLIB_LANDMARKS_PART_LIST = [
|
||||||
|
[list(range(0, 17)) + list(range(68, 83)) + [0]], # face
|
||||||
|
[range(17, 22)], # right eyebrow
|
||||||
|
[range(22, 27)], # left eyebrow
|
||||||
|
[[28, 31], range(31, 36), [35, 28]], # nose
|
||||||
|
[[36, 37, 38, 39], [39, 40, 41, 36]], # right eye
|
||||||
|
[[42, 43, 44, 45], [45, 46, 47, 42]], # left eye
|
||||||
|
[range(48, 55), [54, 55, 56, 57, 58, 59, 48]], # mouth
|
||||||
|
[range(60, 65), [64, 65, 66, 67, 60]] # tongue
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def dist_tensor(key_points, size=(256, 256)):
|
||||||
|
dist_list = []
|
||||||
|
for edge_list in DLIB_LANDMARKS_PART_LIST:
|
||||||
|
for edge in edge_list:
|
||||||
|
pts = key_points[edge, :]
|
||||||
|
im_edge = np.zeros(size, np.uint8)
|
||||||
|
cv2.polylines(im_edge, [pts], isClosed=False, color=255, thickness=1)
|
||||||
|
im_dist = cv2.distanceTransform(255 - im_edge, cv2.DIST_L1, 3)
|
||||||
|
im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8)
|
||||||
|
dist_list.append(im_dist)
|
||||||
|
return np.stack(dist_list)
|
||||||
|
|
||||||
|
|
||||||
|
def read_keypoints(kp_path, origin_size=(256, 256), size=(256, 256), thickness=1):
|
||||||
|
key_points = np.loadtxt(kp_path, delimiter=",").astype(np.int32)
|
||||||
|
if origin_size != size:
|
||||||
|
# resize key_points using simplest way...
|
||||||
|
key_points = (key_points * (np.array(size) / np.array(origin_size))).astype(np.int32)
|
||||||
|
|
||||||
|
# add upper half face by symmetry
|
||||||
|
face_pts = key_points[:17, :]
|
||||||
|
face_baseline_y = (face_pts[0, 1] + face_pts[-1, 1]) // 2
|
||||||
|
upper_symmetry_face_pts = face_pts[1:-1, :].copy()
|
||||||
|
# keep x untouched
|
||||||
|
upper_symmetry_face_pts[:, 1] = face_baseline_y + (face_baseline_y - upper_symmetry_face_pts[:, 1]) * 2 // 3
|
||||||
|
key_points = np.vstack((key_points, upper_symmetry_face_pts[::-1, :]))
|
||||||
|
|
||||||
|
assert key_points.shape == (83, 2)
|
||||||
|
|
||||||
|
part_labels = np.zeros((len(DLIB_LANDMARKS_PART_LIST), *size), np.uint8)
|
||||||
|
part_edge = np.zeros(size, np.uint8)
|
||||||
|
for i, edge_list in enumerate(DLIB_LANDMARKS_PART_LIST):
|
||||||
|
indices = [item for sublist in edge_list for item in sublist]
|
||||||
|
pts = key_points[indices, :]
|
||||||
|
cv2.fillPoly(part_labels[i], pts=[pts], color=1)
|
||||||
|
if i in [1, 2]:
|
||||||
|
# some part of landmarks is a line
|
||||||
|
cv2.polylines(part_edge, [pts], isClosed=False, color=1, thickness=thickness)
|
||||||
|
else:
|
||||||
|
cv2.drawContours(part_edge, [pts], 0, color=1, thickness=thickness)
|
||||||
|
|
||||||
|
return key_points, part_labels, part_edge
|
||||||
|
|
||||||
|
|
||||||
|
def edge_map(img, part_labels, part_edge, remove_edge_within_face=True):
|
||||||
|
edges = feature.canny(np.array(img.convert("L")))
|
||||||
|
if remove_edge_within_face:
|
||||||
|
edges = edges * (part_labels.sum(0) == 0) # remove edges within face
|
||||||
|
edges = part_edge + edges
|
||||||
|
return edges
|
||||||
245
engine/TAHG.py
Normal file
245
engine/TAHG.py
Normal file
@ -0,0 +1,245 @@
|
|||||||
|
from itertools import chain
|
||||||
|
from math import ceil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchvision
|
||||||
|
|
||||||
|
import ignite.distributed as idist
|
||||||
|
from ignite.engine import Events, Engine
|
||||||
|
from ignite.metrics import RunningAverage
|
||||||
|
from ignite.utils import convert_tensor
|
||||||
|
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
|
||||||
|
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
|
||||||
|
|
||||||
|
from omegaconf import OmegaConf, read_write
|
||||||
|
|
||||||
|
import data
|
||||||
|
from loss.gan import GANLoss
|
||||||
|
from model.weight_init import generation_init_weights
|
||||||
|
from model.GAN.residual_generator import GANImageBuffer
|
||||||
|
from loss.I2I.edge_loss import EdgeLoss
|
||||||
|
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||||
|
from util.image import make_2d_grid
|
||||||
|
from util.handler import setup_common_handlers, setup_tensorboard_handler
|
||||||
|
from util.build import build_model, build_optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def build_lr_schedulers(optimizers, config):
|
||||||
|
g_milestones_values = [
|
||||||
|
(0, config.optimizers.generator.lr),
|
||||||
|
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
|
||||||
|
(config.max_iteration, config.data.train.scheduler.target_lr)
|
||||||
|
]
|
||||||
|
d_milestones_values = [
|
||||||
|
(0, config.optimizers.discriminator.lr),
|
||||||
|
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
|
||||||
|
(config.max_iteration, config.data.train.scheduler.target_lr)
|
||||||
|
]
|
||||||
|
return dict(
|
||||||
|
g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
|
||||||
|
d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_trainer(config, logger, train_data_loader):
|
||||||
|
generator = build_model(config.model.generator, config.distributed.model)
|
||||||
|
discriminators = dict(
|
||||||
|
a=build_model(config.model.discriminator, config.distributed.model),
|
||||||
|
b=build_model(config.model.discriminator, config.distributed.model),
|
||||||
|
)
|
||||||
|
generation_init_weights(generator)
|
||||||
|
for m in discriminators.values():
|
||||||
|
generation_init_weights(m)
|
||||||
|
|
||||||
|
logger.debug(discriminators["a"])
|
||||||
|
logger.debug(generator)
|
||||||
|
|
||||||
|
optimizers = dict(
|
||||||
|
g=build_optimizer(generator.parameters(), config.optimizers.generator),
|
||||||
|
d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
|
||||||
|
)
|
||||||
|
logger.info(f"build optimizers:\n{optimizers}")
|
||||||
|
|
||||||
|
lr_schedulers = build_lr_schedulers(optimizers, config)
|
||||||
|
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
|
||||||
|
|
||||||
|
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||||
|
gan_loss_cfg.pop("weight")
|
||||||
|
gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||||
|
|
||||||
|
edge_loss_cfg = OmegaConf.to_container(config.loss.edge)
|
||||||
|
edge_loss_cfg.pop("weight")
|
||||||
|
edge_loss = EdgeLoss(**edge_loss_cfg).to(idist.device())
|
||||||
|
|
||||||
|
perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
|
||||||
|
perceptual_loss_cfg.pop("weight")
|
||||||
|
perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
|
||||||
|
|
||||||
|
recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
|
||||||
|
|
||||||
|
image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()}
|
||||||
|
|
||||||
|
def _step(engine, batch):
|
||||||
|
batch = convert_tensor(batch, idist.device())
|
||||||
|
real = dict(a=batch["a"], b=batch["b"])
|
||||||
|
fake = dict(
|
||||||
|
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
|
||||||
|
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
|
||||||
|
)
|
||||||
|
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
|
||||||
|
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
|
||||||
|
|
||||||
|
optimizers["g"].zero_grad()
|
||||||
|
loss_g = dict()
|
||||||
|
for d in "ab":
|
||||||
|
discriminators[d].requires_grad_(False)
|
||||||
|
pred_fake = discriminators[d](fake[d])
|
||||||
|
loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True)
|
||||||
|
_, t = perceptual_loss(fake[d], real[d])
|
||||||
|
loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t
|
||||||
|
loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], batch["edge_a"])
|
||||||
|
loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"])
|
||||||
|
loss_g["recon_b"] = config.loss.recon.weight * recon_loss(rec_b, real["b"])
|
||||||
|
loss_g["recon_bb"] = config.loss.recon.weight * recon_loss(rec_bb, real["b"])
|
||||||
|
sum(loss_g.values()).backward()
|
||||||
|
optimizers["g"].step()
|
||||||
|
|
||||||
|
for discriminator in discriminators.values():
|
||||||
|
discriminator.requires_grad_(True)
|
||||||
|
|
||||||
|
optimizers["d"].zero_grad()
|
||||||
|
loss_d = dict()
|
||||||
|
for k in discriminators.keys():
|
||||||
|
pred_real = discriminators[k](real[k])
|
||||||
|
pred_fake = discriminators[k](image_buffers[k].query(fake[k].detach()))
|
||||||
|
loss_d[f"gan_{k}"] = (gan_loss(pred_real, True, is_discriminator=True) +
|
||||||
|
gan_loss(pred_fake, False, is_discriminator=True)) / 2
|
||||||
|
sum(loss_d.values()).backward()
|
||||||
|
optimizers["d"].step()
|
||||||
|
|
||||||
|
generated_img = {f"real_{k}": real[k].detach() for k in real}
|
||||||
|
generated_img["rec_b"] = rec_b.detach()
|
||||||
|
generated_img["rec_bb"] = rec_b.detach()
|
||||||
|
generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
|
||||||
|
generated_img.update({f"edge_{k}": batch[f"edge_{k}"].expand(-1, 3, -1, -1).detach() for k in "ab"})
|
||||||
|
return {
|
||||||
|
"loss": {
|
||||||
|
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
|
||||||
|
"d": {ln: loss_d[ln].mean().item() for ln in loss_d},
|
||||||
|
},
|
||||||
|
"img": generated_img
|
||||||
|
}
|
||||||
|
|
||||||
|
trainer = Engine(_step)
|
||||||
|
trainer.logger = logger
|
||||||
|
for lr_shd in lr_schedulers.values():
|
||||||
|
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
|
||||||
|
|
||||||
|
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
|
||||||
|
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
|
||||||
|
|
||||||
|
to_save = dict(trainer=trainer)
|
||||||
|
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
|
||||||
|
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
|
||||||
|
to_save.update({"generator": generator})
|
||||||
|
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
|
||||||
|
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
|
||||||
|
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
|
||||||
|
|
||||||
|
def output_transform(output):
|
||||||
|
loss = dict()
|
||||||
|
for tl in output["loss"]:
|
||||||
|
if isinstance(output["loss"][tl], dict):
|
||||||
|
for l in output["loss"][tl]:
|
||||||
|
loss[f"{tl}_{l}"] = output["loss"][tl][l]
|
||||||
|
else:
|
||||||
|
loss[tl] = output["loss"][tl]
|
||||||
|
return loss
|
||||||
|
|
||||||
|
iter_per_epoch = len(train_data_loader)
|
||||||
|
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
|
||||||
|
if tensorboard_handler is not None:
|
||||||
|
tensorboard_handler.attach(
|
||||||
|
trainer,
|
||||||
|
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
|
||||||
|
event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
|
||||||
|
)
|
||||||
|
|
||||||
|
@trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
|
||||||
|
def show_images(engine):
|
||||||
|
output = engine.state.output
|
||||||
|
image_order = dict(
|
||||||
|
a=["edge_a", "real_a", "fake_a", "fake_b"],
|
||||||
|
b=["edge_b", "real_b", "rec_b", "rec_bb"]
|
||||||
|
)
|
||||||
|
for k in "ab":
|
||||||
|
tensorboard_handler.writer.add_image(
|
||||||
|
f"train/{k}",
|
||||||
|
make_2d_grid([output["img"][o] for o in image_order[k]]),
|
||||||
|
engine.state.iteration
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
g = torch.Generator()
|
||||||
|
g.manual_seed(config.misc.random_seed)
|
||||||
|
random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
|
||||||
|
test_images = dict(
|
||||||
|
a=[[], [], [], []],
|
||||||
|
b=[[], [], [], []]
|
||||||
|
)
|
||||||
|
for i in range(random_start, random_start + 10):
|
||||||
|
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||||
|
for k in batch:
|
||||||
|
batch[k] = batch[k].view(1, *batch[k].size())
|
||||||
|
|
||||||
|
real = dict(a=batch["a"], b=batch["b"])
|
||||||
|
fake = dict(
|
||||||
|
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
|
||||||
|
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
|
||||||
|
)
|
||||||
|
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
|
||||||
|
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
|
||||||
|
|
||||||
|
test_images["a"][0].append(batch["edge_a"])
|
||||||
|
test_images["a"][1].append(batch["a"])
|
||||||
|
test_images["a"][2].append(fake["a"])
|
||||||
|
test_images["a"][3].append(fake["b"])
|
||||||
|
test_images["b"][0].append(batch["edge_b"])
|
||||||
|
test_images["b"][1].append(batch["b"])
|
||||||
|
test_images["b"][2].append(rec_b)
|
||||||
|
test_images["b"][3].append(rec_bb)
|
||||||
|
for n in "ab":
|
||||||
|
tensorboard_handler.writer.add_image(
|
||||||
|
f"test/{n}",
|
||||||
|
make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
|
||||||
|
engine.state.iteration
|
||||||
|
)
|
||||||
|
|
||||||
|
return trainer
|
||||||
|
|
||||||
|
|
||||||
|
def run(task, config, logger):
|
||||||
|
assert torch.backends.cudnn.enabled
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
logger.info(f"start task {task}")
|
||||||
|
with read_write(config):
|
||||||
|
config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
|
||||||
|
|
||||||
|
if task == "train":
|
||||||
|
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
||||||
|
logger.info(f"train with dataset:\n{train_dataset}")
|
||||||
|
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
||||||
|
trainer = get_trainer(config, logger, train_data_loader)
|
||||||
|
if idist.get_rank() == 0:
|
||||||
|
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||||
|
trainer.state.test_dataset = test_dataset
|
||||||
|
try:
|
||||||
|
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
|
||||||
|
except Exception:
|
||||||
|
import traceback
|
||||||
|
print(traceback.format_exc())
|
||||||
|
else:
|
||||||
|
return NotImplemented(f"invalid task: {task}")
|
||||||
0
loss/I2I/__init__.py
Normal file
0
loss/I2I/__init__.py
Normal file
129
loss/I2I/edge_loss.py
Normal file
129
loss/I2I/edge_loss.py
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class HED(nn.Module):
|
||||||
|
def __init__(self, pretrained_model_path, norm_img=True):
|
||||||
|
"""
|
||||||
|
HED module to get edge
|
||||||
|
:param pretrained_model_path: path to pretrained HED.
|
||||||
|
:param norm_img(bool): If True, the image will be normed to [0, 1]. Note that
|
||||||
|
this is different from the `use_input_norm` which norm the input in
|
||||||
|
in forward function of vgg according to the statistics of dataset.
|
||||||
|
Importantly, the input image must be in range [-1, 1].
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.norm_img = norm_img
|
||||||
|
|
||||||
|
self.vgg_nets = nn.ModuleList([torch.nn.Sequential(
|
||||||
|
torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
|
||||||
|
torch.nn.ReLU(inplace=False),
|
||||||
|
torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
|
||||||
|
torch.nn.ReLU(inplace=False)
|
||||||
|
), torch.nn.Sequential(
|
||||||
|
torch.nn.MaxPool2d(kernel_size=2, stride=2),
|
||||||
|
torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
|
||||||
|
torch.nn.ReLU(inplace=False),
|
||||||
|
torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
|
||||||
|
torch.nn.ReLU(inplace=False)
|
||||||
|
), torch.nn.Sequential(
|
||||||
|
torch.nn.MaxPool2d(kernel_size=2, stride=2),
|
||||||
|
torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
|
||||||
|
torch.nn.ReLU(inplace=False),
|
||||||
|
torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
|
||||||
|
torch.nn.ReLU(inplace=False),
|
||||||
|
torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
|
||||||
|
torch.nn.ReLU(inplace=False)
|
||||||
|
), torch.nn.Sequential(
|
||||||
|
torch.nn.MaxPool2d(kernel_size=2, stride=2),
|
||||||
|
torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
|
||||||
|
torch.nn.ReLU(inplace=False),
|
||||||
|
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
||||||
|
torch.nn.ReLU(inplace=False),
|
||||||
|
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
||||||
|
torch.nn.ReLU(inplace=False)
|
||||||
|
), torch.nn.Sequential(
|
||||||
|
torch.nn.MaxPool2d(kernel_size=2, stride=2),
|
||||||
|
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
||||||
|
torch.nn.ReLU(inplace=False),
|
||||||
|
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
||||||
|
torch.nn.ReLU(inplace=False),
|
||||||
|
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
||||||
|
torch.nn.ReLU(inplace=False)
|
||||||
|
)])
|
||||||
|
|
||||||
|
self.score_nets = nn.ModuleList([
|
||||||
|
torch.nn.Conv2d(in_channels=i, out_channels=1, kernel_size=1, stride=1, padding=0)
|
||||||
|
for i in [64, 128, 256, 512, 512]
|
||||||
|
])
|
||||||
|
|
||||||
|
self.combine_net = torch.nn.Sequential(
|
||||||
|
torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
|
||||||
|
torch.nn.Sigmoid()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.load_weights(pretrained_model_path)
|
||||||
|
self.register_buffer('mean', torch.Tensor([104.00698793, 116.66876762, 122.67891434]).view(1, 3, 1, 1))
|
||||||
|
for v in self.parameters():
|
||||||
|
v.requies_grad = False
|
||||||
|
|
||||||
|
def load_weights(self, pretrained_model_path):
|
||||||
|
checkpoint_path = Path(pretrained_model_path)
|
||||||
|
if not checkpoint_path.exists():
|
||||||
|
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
|
||||||
|
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
|
||||||
|
m = {"One": "0", "Two": "1", "Thr": "2", "Fou": "3", "Fiv": "4"}
|
||||||
|
|
||||||
|
def replace_key(key):
|
||||||
|
if key.startswith("moduleVgg"):
|
||||||
|
return f"vgg_nets.{m[key[9:12]]}{key[12:]}"
|
||||||
|
elif key.startswith("moduleScore"):
|
||||||
|
return f"score_nets.{m[key[11:14]]}{key[14:]}"
|
||||||
|
elif key.startswith("moduleCombine"):
|
||||||
|
return f"combine_net{key[13:]}"
|
||||||
|
else:
|
||||||
|
raise ValueError("wrong checkpoint for HED")
|
||||||
|
|
||||||
|
module_dict = {replace_key(k): v for k, v in ckp.items()}
|
||||||
|
self.load_state_dict(module_dict, strict=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.norm_img:
|
||||||
|
x = (x + 1.) * 0.5
|
||||||
|
x = x * 255.0 - self.mean
|
||||||
|
img_size = (x.size(2), x.size(3))
|
||||||
|
|
||||||
|
to_combine = []
|
||||||
|
for i in range(5):
|
||||||
|
x = self.vgg_nets[i](x)
|
||||||
|
score_x = self.score_nets[i](x)
|
||||||
|
to_combine.append(F.interpolate(input=score_x, size=img_size, mode='bilinear', align_corners=False))
|
||||||
|
out = self.combine_net(torch.cat(to_combine, 1))
|
||||||
|
return out.clamp(0.0, 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
class EdgeLoss(nn.Module):
|
||||||
|
def __init__(self, edge_extractor_type="HED", norm_img=True, criterion='L1', **kwargs):
|
||||||
|
super(EdgeLoss, self).__init__()
|
||||||
|
if edge_extractor_type == "HED":
|
||||||
|
pretrained_model_path = kwargs.get("hed_pretrained_model_path")
|
||||||
|
self.edge_extractor = HED(pretrained_model_path, norm_img)
|
||||||
|
else:
|
||||||
|
raise NotImplemented(f"do not support edge_extractor_type {edge_extractor_type}")
|
||||||
|
|
||||||
|
if criterion == 'L1':
|
||||||
|
self.criterion = nn.L1Loss()
|
||||||
|
elif criterion == "L2":
|
||||||
|
self.criterion = nn.MSELoss()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'{criterion} criterion has not been supported in this version.')
|
||||||
|
|
||||||
|
def forward(self, x, gt, gt_is_edge=True):
|
||||||
|
edge = self.edge_extractor(x)
|
||||||
|
if not gt_is_edge:
|
||||||
|
gt = self.edge_extractor(gt.detach())
|
||||||
|
loss = self.criterion(edge, gt)
|
||||||
|
return loss
|
||||||
155
loss/I2I/perceptual_loss.py
Normal file
155
loss/I2I/perceptual_loss.py
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torchvision.models.vgg as vgg
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualVGG(nn.Module):
|
||||||
|
"""VGG network used in calculating perceptual loss.
|
||||||
|
In this implementation, we allow users to choose whether use normalization
|
||||||
|
in the input feature and the type of vgg network. Note that the pretrained
|
||||||
|
path must fit the vgg type.
|
||||||
|
Args:
|
||||||
|
layer_name_list (list[str]): According to the index in this list,
|
||||||
|
forward function will return the corresponding features. This
|
||||||
|
list contains the name each layer in `vgg.feature`. An example
|
||||||
|
of this list is ['4', '10'].
|
||||||
|
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
|
||||||
|
use_input_norm (bool): If True, normalize the input image.
|
||||||
|
Importantly, the input feature must in the range [0, 1].
|
||||||
|
Default: True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, layer_name_list, vgg_type='vgg19', use_input_norm=True):
|
||||||
|
super(PerceptualVGG, self).__init__()
|
||||||
|
self.layer_name_list = layer_name_list
|
||||||
|
self.use_input_norm = use_input_norm
|
||||||
|
|
||||||
|
# get vgg model and load pretrained vgg weight
|
||||||
|
# remove _vgg from attributes to avoid `find_unused_parameters` bug
|
||||||
|
_vgg = getattr(vgg, vgg_type)(pretrained=True)
|
||||||
|
num_layers = max(map(int, layer_name_list)) + 1
|
||||||
|
assert len(_vgg.features) >= num_layers
|
||||||
|
# only borrow layers that will be used from _vgg to avoid unused params
|
||||||
|
self.vgg_layers = _vgg.features[:num_layers]
|
||||||
|
|
||||||
|
if self.use_input_norm:
|
||||||
|
# the mean is for image with range [0, 1]
|
||||||
|
self.register_buffer(
|
||||||
|
'mean',
|
||||||
|
torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
||||||
|
# the std is for image with range [-1, 1]
|
||||||
|
self.register_buffer(
|
||||||
|
'std',
|
||||||
|
torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
||||||
|
|
||||||
|
for v in self.vgg_layers.parameters():
|
||||||
|
v.requies_grad = False
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Forward function.
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor with shape (n, c, h, w).
|
||||||
|
Returns:
|
||||||
|
Tensor: Forward results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.use_input_norm:
|
||||||
|
x = (x - self.mean) / self.std
|
||||||
|
output = {}
|
||||||
|
|
||||||
|
for i, l in enumerate(self.vgg_layers):
|
||||||
|
x = l(x)
|
||||||
|
if str(i) in self.layer_name_list:
|
||||||
|
output[str(i)] = x.clone()
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualLoss(nn.Module):
|
||||||
|
"""Perceptual loss with commonly used style loss.
|
||||||
|
Args:
|
||||||
|
layer_weights (dict): The weight for each layer of vgg feature.
|
||||||
|
Here is an example: {'4': 1., '9': 1., '18': 1.}, which means the
|
||||||
|
5th, 10th and 18th feature layer will be extracted with weight 1.0
|
||||||
|
in calculating losses.
|
||||||
|
vgg_type (str): The type of vgg network used as feature extractor.
|
||||||
|
Default: 'vgg19'.
|
||||||
|
use_input_norm (bool): If True, normalize the input image in vgg.
|
||||||
|
Default: True.
|
||||||
|
perceptual_loss (bool): If `perceptual_loss == True`, the perceptual
|
||||||
|
loss will be calculated.
|
||||||
|
Default: True.
|
||||||
|
style_loss (bool): If `style_loss == False`, the style loss will be calculated.
|
||||||
|
Default: False.
|
||||||
|
norm_img (bool): If True, the image will be normed to [0, 1]. Note that
|
||||||
|
this is different from the `use_input_norm` which norm the input in
|
||||||
|
in forward function of vgg according to the statistics of dataset.
|
||||||
|
Importantly, the input image must be in range [-1, 1].
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, layer_weights, vgg_type='vgg19', use_input_norm=True, perceptual_loss=True,
|
||||||
|
style_loss=False, norm_img=True, criterion='L1'):
|
||||||
|
super(PerceptualLoss, self).__init__()
|
||||||
|
self.norm_img = norm_img
|
||||||
|
self.perceptual_loss = perceptual_loss
|
||||||
|
self.style_loss = style_loss
|
||||||
|
self.layer_weights = layer_weights
|
||||||
|
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
|
||||||
|
use_input_norm=use_input_norm)
|
||||||
|
|
||||||
|
if criterion == 'L1':
|
||||||
|
self.criterion = torch.nn.L1Loss()
|
||||||
|
elif criterion == "L2":
|
||||||
|
self.criterion = torch.nn.MSELoss()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'{criterion} criterion has not been supported in this version.')
|
||||||
|
|
||||||
|
def forward(self, x, gt):
|
||||||
|
"""Forward function.
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor with shape (n, c, h, w).
|
||||||
|
gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
|
||||||
|
Returns:
|
||||||
|
Tensor: Forward results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.norm_img:
|
||||||
|
x = (x + 1.) * 0.5
|
||||||
|
gt = (gt + 1.) * 0.5
|
||||||
|
# extract vgg features
|
||||||
|
x_features = self.vgg(x)
|
||||||
|
gt_features = self.vgg(gt.detach())
|
||||||
|
|
||||||
|
# calculate preceptual loss
|
||||||
|
if self.perceptual_loss:
|
||||||
|
percep_loss = 0
|
||||||
|
for k in x_features.keys():
|
||||||
|
percep_loss += self.criterion(
|
||||||
|
x_features[k], gt_features[k]) * self.layer_weights[k]
|
||||||
|
else:
|
||||||
|
percep_loss = None
|
||||||
|
|
||||||
|
# calculate style loss
|
||||||
|
if self.style_loss:
|
||||||
|
style_loss = 0
|
||||||
|
for k in x_features.keys():
|
||||||
|
style_loss += self.criterion(
|
||||||
|
self._gram_mat(x_features[k]),
|
||||||
|
self._gram_mat(gt_features[k])) * self.layer_weights[k]
|
||||||
|
else:
|
||||||
|
style_loss = None
|
||||||
|
|
||||||
|
return percep_loss, style_loss
|
||||||
|
|
||||||
|
def _gram_mat(self, x):
|
||||||
|
"""Calculate Gram matrix.
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): Tensor with shape of (n, c, h, w).
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Gram matrix.
|
||||||
|
"""
|
||||||
|
(n, c, h, w) = x.size()
|
||||||
|
features = x.view(n, c, w * h)
|
||||||
|
features_t = features.transpose(1, 2)
|
||||||
|
gram = features.bmm(features_t) / (c * h * w)
|
||||||
|
return gram
|
||||||
@ -87,10 +87,17 @@ class ContentEncoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
def __init__(self, out_channels, base_channels=64, num_down_sampling=2, padding_mode='reflect', norm_type="LN"):
|
def __init__(self, out_channels, base_channels=64, num_blocks=4, num_down_sampling=2, padding_mode='reflect',
|
||||||
|
norm_type="LN"):
|
||||||
super(Decoder, self).__init__()
|
super(Decoder, self).__init__()
|
||||||
norm_layer = select_norm_layer(norm_type)
|
norm_layer = select_norm_layer(norm_type)
|
||||||
use_bias = norm_type == "IN"
|
use_bias = norm_type == "IN"
|
||||||
|
|
||||||
|
res_block_channels = (2 ** 2) * base_channels
|
||||||
|
|
||||||
|
self.resnet = nn.Sequential(
|
||||||
|
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
|
||||||
|
|
||||||
# up sampling
|
# up sampling
|
||||||
submodules = []
|
submodules = []
|
||||||
for i in range(num_down_sampling):
|
for i in range(num_down_sampling):
|
||||||
@ -109,6 +116,7 @@ class Decoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
x = self.resnet(x)
|
||||||
x = self.decoder(x)
|
x = self.decoder(x)
|
||||||
x = self.end_conv(x)
|
x = self.end_conv(x)
|
||||||
return x
|
return x
|
||||||
@ -142,12 +150,16 @@ class Fusion(nn.Module):
|
|||||||
|
|
||||||
@MODEL.register_module("TAHG-Generator")
|
@MODEL.register_module("TAHG-Generator")
|
||||||
class Generator(nn.Module):
|
class Generator(nn.Module):
|
||||||
def __init__(self, style_in_channels, content_in_channels, out_channels, style_dim=512, num_blocks=8,
|
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, num_blocks=8,
|
||||||
base_channels=64, padding_mode="reflect"):
|
base_channels=64, padding_mode="reflect"):
|
||||||
super(Generator, self).__init__()
|
super(Generator, self).__init__()
|
||||||
self.num_blocks = num_blocks
|
self.num_blocks = num_blocks
|
||||||
self.style_encoder = VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
|
self.style_encoders = nn.ModuleDict({
|
||||||
|
"a": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
|
||||||
|
padding_mode=padding_mode, norm_type="NONE"),
|
||||||
|
"b": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
|
||||||
padding_mode=padding_mode, norm_type="NONE")
|
padding_mode=padding_mode, norm_type="NONE")
|
||||||
|
})
|
||||||
self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
|
self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
|
||||||
padding_mode=padding_mode, norm_type="IN")
|
padding_mode=padding_mode, norm_type="IN")
|
||||||
res_block_channels = 2 ** 2 * base_channels
|
res_block_channels = 2 ** 2 * base_channels
|
||||||
@ -155,8 +167,8 @@ class Generator(nn.Module):
|
|||||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
|
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
|
||||||
])
|
])
|
||||||
self.decoders = nn.ModuleDict({
|
self.decoders = nn.ModuleDict({
|
||||||
"a": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode),
|
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode),
|
||||||
"b": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode)
|
"b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode)
|
||||||
})
|
})
|
||||||
|
|
||||||
self.fc = nn.Sequential(
|
self.fc = nn.Sequential(
|
||||||
@ -168,10 +180,45 @@ class Generator(nn.Module):
|
|||||||
|
|
||||||
def forward(self, content_img, style_img, which_decoder: str = "a"):
|
def forward(self, content_img, style_img, which_decoder: str = "a"):
|
||||||
x = self.content_encoder(content_img)
|
x = self.content_encoder(content_img)
|
||||||
styles = self.fusion(self.fc(self.style_encoder(style_img)))
|
styles = self.fusion(self.fc(self.style_encoders[which_decoder](style_img)))
|
||||||
styles = torch.chunk(styles, self.num_blocks * 2, dim=1)
|
styles = torch.chunk(styles, self.num_blocks * 2, dim=1)
|
||||||
for i, ar in enumerate(self.adain_res):
|
for i, ar in enumerate(self.adain_res):
|
||||||
ar.norm1.set_style(styles[2 * i])
|
ar.norm1.set_style(styles[2 * i])
|
||||||
ar.norm2.set_style(styles[2 * i + 1])
|
ar.norm2.set_style(styles[2 * i + 1])
|
||||||
x = ar(x)
|
x = ar(x)
|
||||||
return self.decoders[which_decoder](x)
|
return self.decoders[which_decoder](x)
|
||||||
|
|
||||||
|
|
||||||
|
@MODEL.register_module("TAHG-Discriminator")
|
||||||
|
class Discriminator(nn.Module):
|
||||||
|
def __init__(self, in_channels=3, base_channels=64, num_down_sampling=2, num_blocks=3, norm_type="IN",
|
||||||
|
padding_mode="reflect"):
|
||||||
|
super(Discriminator, self).__init__()
|
||||||
|
|
||||||
|
norm_layer = select_norm_layer(norm_type)
|
||||||
|
use_bias = norm_type == "IN"
|
||||||
|
|
||||||
|
sequence = [nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
|
||||||
|
bias=use_bias),
|
||||||
|
norm_layer(num_features=base_channels),
|
||||||
|
nn.ReLU(inplace=True)
|
||||||
|
)]
|
||||||
|
# stacked intermediate layers,
|
||||||
|
# gradually increasing the number of filters
|
||||||
|
multiple_now = 1
|
||||||
|
for n in range(1, num_down_sampling + 1):
|
||||||
|
multiple_prev = multiple_now
|
||||||
|
multiple_now = min(2 ** n, 4)
|
||||||
|
sequence += [
|
||||||
|
nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=3,
|
||||||
|
padding=1, stride=2, bias=use_bias),
|
||||||
|
norm_layer(base_channels * multiple_now),
|
||||||
|
nn.LeakyReLU(0.2, inplace=True)
|
||||||
|
]
|
||||||
|
for _ in range(num_blocks):
|
||||||
|
sequence.append(ResidualBlock(base_channels * multiple_now, padding_mode, norm_type))
|
||||||
|
self.model = nn.Sequential(*sequence)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.model(x)
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
from model.registry import MODEL
|
from model.registry import MODEL
|
||||||
import model.GAN.residual_generator
|
import model.GAN.residual_generator
|
||||||
|
import model.GAN.TAHG
|
||||||
|
import model.GAN.UGATIT
|
||||||
import model.fewshot
|
import model.fewshot
|
||||||
|
|||||||
@ -37,7 +37,6 @@ class LayerNorm2d(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
|
ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
|
||||||
x = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
|
x = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
|
||||||
print(x.size())
|
|
||||||
if self.affine:
|
if self.affine:
|
||||||
return self.channel_gamma * x + self.channel_beta
|
return self.channel_gamma * x + self.channel_beta
|
||||||
return x
|
return x
|
||||||
|
|||||||
32
tool/process/detect_landmark.py
Normal file
32
tool/process/detect_landmark.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import dlib
|
||||||
|
from pathlib import Path
|
||||||
|
from PIL import Image
|
||||||
|
import sys
|
||||||
|
|
||||||
|
imagepaths = Path(sys.argv[1])
|
||||||
|
print(imagepaths)
|
||||||
|
phase = imagepaths.name
|
||||||
|
print(phase)
|
||||||
|
|
||||||
|
detector = dlib.get_frontal_face_detector()
|
||||||
|
predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
|
||||||
|
|
||||||
|
if not os.path.isdir(phase):
|
||||||
|
os.makedirs(phase)
|
||||||
|
for ip in imagepaths.glob("*.jpg"):
|
||||||
|
img = np.asarray(Image.open(ip))
|
||||||
|
img.setflags(write=True)
|
||||||
|
dets = detector(img, 1)
|
||||||
|
if len(dets) > 0:
|
||||||
|
shape = predictor(img, dets[0])
|
||||||
|
points = np.empty([68, 2], dtype=int)
|
||||||
|
for b in range(68):
|
||||||
|
points[b, 0] = shape.part(b).x
|
||||||
|
points[b, 1] = shape.part(b).y
|
||||||
|
|
||||||
|
save_name = os.path.join(phase, ip.name[:-4] + '.txt')
|
||||||
|
np.savetxt(save_name, points, fmt='%d', delimiter=',')
|
||||||
|
else:
|
||||||
|
print(ip)
|
||||||
45
tool/process/generate_edge.py
Normal file
45
tool/process/generate_edge.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
import numpy as np
|
||||||
|
from skimage import feature
|
||||||
|
from pathlib import Path
|
||||||
|
from torchvision.datasets.folder import is_image_file, default_loader
|
||||||
|
from torchvision.transforms import functional as F
|
||||||
|
from loss.I2I.edge_loss import HED
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
import fire
|
||||||
|
|
||||||
|
|
||||||
|
def canny_edge(img):
|
||||||
|
edge = feature.canny(np.array(img.convert("L")))
|
||||||
|
return edge
|
||||||
|
|
||||||
|
|
||||||
|
def generate(image_folder, edge_type, save_folder, device="cuda:0"):
|
||||||
|
assert edge_type in ["canny", "hed"]
|
||||||
|
image_folder = Path(image_folder)
|
||||||
|
save_folder = Path(save_folder)
|
||||||
|
if edge_type == "hed":
|
||||||
|
edge_extractor = HED("/root/network-bsds500.pytorch", norm_img=False).to(device)
|
||||||
|
elif edge_type == "canny":
|
||||||
|
edge_extractor = canny_edge
|
||||||
|
else:
|
||||||
|
raise NotImplemented
|
||||||
|
for p in image_folder.glob("*"):
|
||||||
|
if is_image_file(p.as_posix()):
|
||||||
|
rgb_img = default_loader(p)
|
||||||
|
print(p)
|
||||||
|
if edge_type == "hed":
|
||||||
|
with torch.no_grad():
|
||||||
|
img_tensor = F.to_tensor(rgb_img).to(device)
|
||||||
|
edge_tensor = edge_extractor(img_tensor)
|
||||||
|
edge = F.to_pil_image(edge_tensor.clamp(0, 1.0).squeeze().detach().cpu())
|
||||||
|
edge.save(save_folder / f"{p.stem}.{edge_type}.png")
|
||||||
|
elif edge_type == "canny":
|
||||||
|
edge = edge_extractor(rgb_img)
|
||||||
|
Image.fromarray(edge).save(save_folder / f"{p.stem}.{edge_type}.png")
|
||||||
|
else:
|
||||||
|
raise NotImplemented
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
fire.Fire(generate)
|
||||||
@ -88,16 +88,17 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
|||||||
engine.terminate()
|
engine.terminate()
|
||||||
|
|
||||||
|
|
||||||
def setup_tensorboard_handler(trainer: Engine, config, output_transform):
|
def setup_tensorboard_handler(trainer: Engine, config, output_transform, iter_per_epoch):
|
||||||
if config.interval.tensorboard is None:
|
if config.interval.tensorboard is None:
|
||||||
return None
|
return None
|
||||||
if idist.get_rank() == 0:
|
if idist.get_rank() == 0:
|
||||||
# Create a logger
|
# Create a logger
|
||||||
tb_logger = TensorboardLogger(log_dir=config.output_dir)
|
tb_logger = TensorboardLogger(log_dir=config.output_dir)
|
||||||
|
basic_event = Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
|
||||||
tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"),
|
tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"),
|
||||||
event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
|
event_name=basic_event)
|
||||||
tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform),
|
tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform),
|
||||||
event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
|
event_name=basic_event)
|
||||||
|
|
||||||
@trainer.on(Events.COMPLETED)
|
@trainer.on(Events.COMPLETED)
|
||||||
@idist.one_rank_only()
|
@idist.one_rank_only()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user