Compare commits

...

3 Commits

Author SHA1 Message Date
e71e8d95d0 TAHG 0.0.3 2020-09-01 09:02:04 +08:00
89b54105c7 TAHG 0.0.2 2020-08-30 14:44:40 +08:00
715a2e64a1 TANG 0.0.1 2020-08-30 09:34:23 +08:00
18 changed files with 918 additions and 14 deletions

View File

@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="Always" serverName="15d" remoteFilesAllowedToDisappearOnAutoupload="false">
<component name="PublishConfigData" autoUpload="Always" serverName="22d" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="14d">
<serverdata>
@ -16,6 +16,13 @@
</mappings>
</serverdata>
</paths>
<paths name="22d">
<serverdata>
<mappings>
<mapping deploy="/raycv" local="$PROJECT_DIR$" />
</mappings>
</serverdata>
</paths>
</serverData>
<option name="myAutoUpload" value="ALWAYS" />
</component>

View File

@ -1,4 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="15d-python" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="22d-base" project-jdk-type="Python SDK" />
</project>

6
.idea/other.xml Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PySciProjectComponent">
<option name="PY_SCI_VIEW_SUGGESTED" value="true" />
</component>
</project>

View File

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="15d-python" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="22d-base" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="TestRunnerService">

View File

@ -0,0 +1,132 @@
name: TAHG
engine: TAHG
result_dir: ./result
max_pairs: 1000000
distributed:
model:
# broadcast_buffers: False
misc:
random_seed: 324
checkpoint:
epoch_interval: 1 # one checkpoint every 1 epoch
n_saved: 2
interval:
print_per_iteration: 10 # print once per 10 iteration
tensorboard:
scalar: 100
image: 2
model:
generator:
_type: TAHG-Generator
style_in_channels: 3
content_in_channels: 1
num_blocks: 4
discriminator:
_type: TAHG-Discriminator
in_channels: 3
loss:
gan:
loss_type: lsgan
real_label_val: 1.0
fake_label_val: 0.0
weight: 1.0
edge:
criterion: 'L1'
hed_pretrained_model_path: "./network-bsds500.pytorch"
weight: 1
perceptual:
layer_weights:
"3": 1.0
# "0": 1.0
# "5": 1.0
# "10": 1.0
# "19": 1.0
criterion: 'L2'
style_loss: True
perceptual_loss: False
weight: 20
recon:
level: 1
weight: 1
optimizers:
generator:
_type: Adam
lr: 0.0001
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 1e-4
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
data:
train:
scheduler:
start_proportion: 0.5
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 160
shuffle: True
num_workers: 2
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDatasetWithEdge
root_a: "/data/i2i/VoxCeleb2Anime/trainA"
root_b: "/data/i2i/VoxCeleb2Anime/trainB"
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
edge_type: "hed"
size: [128, 128]
random_pair: True
pipeline:
- Load
- Resize:
size: [128, 128]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
dataloader:
batch_size: 8
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDatasetWithEdge
root_a: "/data/i2i/VoxCeleb2Anime/testA"
root_b: "/data/i2i/VoxCeleb2Anime/testB"
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
edge_type: "hed"
random_pair: False
size: [128, 128]
pipeline:
- Load
- Resize:
size: [128, 128]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
video_dataset:
_type: SingleFolderDataset
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
with_path: True
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]

View File

@ -1,10 +1,13 @@
import os
import pickle
from pathlib import Path
from collections import defaultdict
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.transforms import functional as F
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader
import lmdb
@ -171,3 +174,37 @@ class GenerationUnpairedDataset(Dataset):
def __repr__(self):
return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, size=(256, 256)):
self.edge_type = edge_type
self.size = size
self.edges_path = Path(edges_path)
assert self.edges_path.exists()
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
self.random_pair = random_pair
def get_edge(self, origin_path):
op = Path(origin_path)
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{self.edge_type}.png"
img = Image.open(edge_path).resize(self.size)
return F.to_tensor(img)
def __getitem__(self, idx):
a_idx = idx % len(self.A)
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
output = dict()
output["a"], path_a = self.A[a_idx]
output["b"], path_b = self.B[b_idx]
output["edge_a"] = self.get_edge(path_a)
output["edge_b"] = self.get_edge(path_b)
return output
def __len__(self):
return max(len(self.A), len(self.B))
def __repr__(self):
return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"

0
data/util/__init__.py Normal file
View File

View File

@ -0,0 +1,67 @@
import numpy as np
import cv2
from skimage import feature
# https://ibug.doc.ic.ac.uk/resources/facial-point-annotations/
DLIB_LANDMARKS_PART_LIST = [
[list(range(0, 17)) + list(range(68, 83)) + [0]], # face
[range(17, 22)], # right eyebrow
[range(22, 27)], # left eyebrow
[[28, 31], range(31, 36), [35, 28]], # nose
[[36, 37, 38, 39], [39, 40, 41, 36]], # right eye
[[42, 43, 44, 45], [45, 46, 47, 42]], # left eye
[range(48, 55), [54, 55, 56, 57, 58, 59, 48]], # mouth
[range(60, 65), [64, 65, 66, 67, 60]] # tongue
]
def dist_tensor(key_points, size=(256, 256)):
dist_list = []
for edge_list in DLIB_LANDMARKS_PART_LIST:
for edge in edge_list:
pts = key_points[edge, :]
im_edge = np.zeros(size, np.uint8)
cv2.polylines(im_edge, [pts], isClosed=False, color=255, thickness=1)
im_dist = cv2.distanceTransform(255 - im_edge, cv2.DIST_L1, 3)
im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8)
dist_list.append(im_dist)
return np.stack(dist_list)
def read_keypoints(kp_path, origin_size=(256, 256), size=(256, 256), thickness=1):
key_points = np.loadtxt(kp_path, delimiter=",").astype(np.int32)
if origin_size != size:
# resize key_points using simplest way...
key_points = (key_points * (np.array(size) / np.array(origin_size))).astype(np.int32)
# add upper half face by symmetry
face_pts = key_points[:17, :]
face_baseline_y = (face_pts[0, 1] + face_pts[-1, 1]) // 2
upper_symmetry_face_pts = face_pts[1:-1, :].copy()
# keep x untouched
upper_symmetry_face_pts[:, 1] = face_baseline_y + (face_baseline_y - upper_symmetry_face_pts[:, 1]) * 2 // 3
key_points = np.vstack((key_points, upper_symmetry_face_pts[::-1, :]))
assert key_points.shape == (83, 2)
part_labels = np.zeros((len(DLIB_LANDMARKS_PART_LIST), *size), np.uint8)
part_edge = np.zeros(size, np.uint8)
for i, edge_list in enumerate(DLIB_LANDMARKS_PART_LIST):
indices = [item for sublist in edge_list for item in sublist]
pts = key_points[indices, :]
cv2.fillPoly(part_labels[i], pts=[pts], color=1)
if i in [1, 2]:
# some part of landmarks is a line
cv2.polylines(part_edge, [pts], isClosed=False, color=1, thickness=thickness)
else:
cv2.drawContours(part_edge, [pts], 0, color=1, thickness=thickness)
return key_points, part_labels, part_edge
def edge_map(img, part_labels, part_edge, remove_edge_within_face=True):
edges = feature.canny(np.array(img.convert("L")))
if remove_edge_within_face:
edges = edges * (part_labels.sum(0) == 0) # remove edges within face
edges = part_edge + edges
return edges

245
engine/TAHG.py Normal file
View File

@ -0,0 +1,245 @@
from itertools import chain
from math import ceil
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from omegaconf import OmegaConf, read_write
import data
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.perceptual_loss import PerceptualLoss
from util.image import make_2d_grid
from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_model, build_optimizer
def build_lr_schedulers(optimizers, config):
g_milestones_values = [
(0, config.optimizers.generator.lr),
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
d_milestones_values = [
(0, config.optimizers.discriminator.lr),
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
return dict(
g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
)
def get_trainer(config, logger, train_data_loader):
generator = build_model(config.model.generator, config.distributed.model)
discriminators = dict(
a=build_model(config.model.discriminator, config.distributed.model),
b=build_model(config.model.discriminator, config.distributed.model),
)
generation_init_weights(generator)
for m in discriminators.values():
generation_init_weights(m)
logger.debug(discriminators["a"])
logger.debug(generator)
optimizers = dict(
g=build_optimizer(generator.parameters(), config.optimizers.generator),
d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
)
logger.info(f"build optimizers:\n{optimizers}")
lr_schedulers = build_lr_schedulers(optimizers, config)
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
edge_loss_cfg = OmegaConf.to_container(config.loss.edge)
edge_loss_cfg.pop("weight")
edge_loss = EdgeLoss(**edge_loss_cfg).to(idist.device())
perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
perceptual_loss_cfg.pop("weight")
perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()}
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
real = dict(a=batch["a"], b=batch["b"])
fake = dict(
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
)
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
optimizers["g"].zero_grad()
loss_g = dict()
for d in "ab":
discriminators[d].requires_grad_(False)
pred_fake = discriminators[d](fake[d])
loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True)
_, t = perceptual_loss(fake[d], real[d])
loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t
loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], batch["edge_a"])
loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"])
loss_g["recon_b"] = config.loss.recon.weight * recon_loss(rec_b, real["b"])
loss_g["recon_bb"] = config.loss.recon.weight * recon_loss(rec_bb, real["b"])
sum(loss_g.values()).backward()
optimizers["g"].step()
for discriminator in discriminators.values():
discriminator.requires_grad_(True)
optimizers["d"].zero_grad()
loss_d = dict()
for k in discriminators.keys():
pred_real = discriminators[k](real[k])
pred_fake = discriminators[k](image_buffers[k].query(fake[k].detach()))
loss_d[f"gan_{k}"] = (gan_loss(pred_real, True, is_discriminator=True) +
gan_loss(pred_fake, False, is_discriminator=True)) / 2
sum(loss_d.values()).backward()
optimizers["d"].step()
generated_img = {f"real_{k}": real[k].detach() for k in real}
generated_img["rec_b"] = rec_b.detach()
generated_img["rec_bb"] = rec_b.detach()
generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
generated_img.update({f"edge_{k}": batch[f"edge_{k}"].expand(-1, 3, -1, -1).detach() for k in "ab"})
return {
"loss": {
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
"d": {ln: loss_d[ln].mean().item() for ln in loss_d},
},
"img": generated_img
}
trainer = Engine(_step)
trainer.logger = logger
for lr_shd in lr_schedulers.values():
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
to_save = dict(trainer=trainer)
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
to_save.update({"generator": generator})
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
def output_transform(output):
loss = dict()
for tl in output["loss"]:
if isinstance(output["loss"][tl], dict):
for l in output["loss"][tl]:
loss[f"{tl}_{l}"] = output["loss"][tl][l]
else:
loss[tl] = output["loss"][tl]
return loss
iter_per_epoch = len(train_data_loader)
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
if tensorboard_handler is not None:
tensorboard_handler.attach(
trainer,
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
)
@trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
def show_images(engine):
output = engine.state.output
image_order = dict(
a=["edge_a", "real_a", "fake_a", "fake_b"],
b=["edge_b", "real_b", "rec_b", "rec_bb"]
)
for k in "ab":
tensorboard_handler.writer.add_image(
f"train/{k}",
make_2d_grid([output["img"][o] for o in image_order[k]]),
engine.state.iteration
)
with torch.no_grad():
g = torch.Generator()
g.manual_seed(config.misc.random_seed)
random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
test_images = dict(
a=[[], [], [], []],
b=[[], [], [], []]
)
for i in range(random_start, random_start + 10):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
for k in batch:
batch[k] = batch[k].view(1, *batch[k].size())
real = dict(a=batch["a"], b=batch["b"])
fake = dict(
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
)
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
test_images["a"][0].append(batch["edge_a"])
test_images["a"][1].append(batch["a"])
test_images["a"][2].append(fake["a"])
test_images["a"][3].append(fake["b"])
test_images["b"][0].append(batch["edge_b"])
test_images["b"][1].append(batch["b"])
test_images["b"][2].append(rec_b)
test_images["b"][3].append(rec_bb)
for n in "ab":
tensorboard_handler.writer.add_image(
f"test/{n}",
make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
engine.state.iteration
)
return trainer
def run(task, config, logger):
assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True
logger.info(f"start task {task}")
with read_write(config):
config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
if task == "train":
train_dataset = data.DATASET.build_with(config.data.train.dataset)
logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
trainer = get_trainer(config, logger, train_data_loader)
if idist.get_rank() == 0:
test_dataset = data.DATASET.build_with(config.data.test.dataset)
trainer.state.test_dataset = test_dataset
try:
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
except Exception:
import traceback
print(traceback.format_exc())
else:
return NotImplemented(f"invalid task: {task}")

0
loss/I2I/__init__.py Normal file
View File

129
loss/I2I/edge_loss.py Normal file
View File

@ -0,0 +1,129 @@
from pathlib import Path
import torch
import torch.nn as nn
from torch.nn import functional as F
class HED(nn.Module):
def __init__(self, pretrained_model_path, norm_img=True):
"""
HED module to get edge
:param pretrained_model_path: path to pretrained HED.
:param norm_img(bool): If True, the image will be normed to [0, 1]. Note that
this is different from the `use_input_norm` which norm the input in
in forward function of vgg according to the statistics of dataset.
Importantly, the input image must be in range [-1, 1].
"""
super().__init__()
self.norm_img = norm_img
self.vgg_nets = nn.ModuleList([torch.nn.Sequential(
torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False),
torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False)
), torch.nn.Sequential(
torch.nn.MaxPool2d(kernel_size=2, stride=2),
torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False),
torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False)
), torch.nn.Sequential(
torch.nn.MaxPool2d(kernel_size=2, stride=2),
torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False),
torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False),
torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False)
), torch.nn.Sequential(
torch.nn.MaxPool2d(kernel_size=2, stride=2),
torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False),
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False),
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False)
), torch.nn.Sequential(
torch.nn.MaxPool2d(kernel_size=2, stride=2),
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False),
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False),
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False)
)])
self.score_nets = nn.ModuleList([
torch.nn.Conv2d(in_channels=i, out_channels=1, kernel_size=1, stride=1, padding=0)
for i in [64, 128, 256, 512, 512]
])
self.combine_net = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
torch.nn.Sigmoid()
)
self.load_weights(pretrained_model_path)
self.register_buffer('mean', torch.Tensor([104.00698793, 116.66876762, 122.67891434]).view(1, 3, 1, 1))
for v in self.parameters():
v.requies_grad = False
def load_weights(self, pretrained_model_path):
checkpoint_path = Path(pretrained_model_path)
if not checkpoint_path.exists():
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
m = {"One": "0", "Two": "1", "Thr": "2", "Fou": "3", "Fiv": "4"}
def replace_key(key):
if key.startswith("moduleVgg"):
return f"vgg_nets.{m[key[9:12]]}{key[12:]}"
elif key.startswith("moduleScore"):
return f"score_nets.{m[key[11:14]]}{key[14:]}"
elif key.startswith("moduleCombine"):
return f"combine_net{key[13:]}"
else:
raise ValueError("wrong checkpoint for HED")
module_dict = {replace_key(k): v for k, v in ckp.items()}
self.load_state_dict(module_dict, strict=True)
def forward(self, x):
if self.norm_img:
x = (x + 1.) * 0.5
x = x * 255.0 - self.mean
img_size = (x.size(2), x.size(3))
to_combine = []
for i in range(5):
x = self.vgg_nets[i](x)
score_x = self.score_nets[i](x)
to_combine.append(F.interpolate(input=score_x, size=img_size, mode='bilinear', align_corners=False))
out = self.combine_net(torch.cat(to_combine, 1))
return out.clamp(0.0, 1.0)
class EdgeLoss(nn.Module):
def __init__(self, edge_extractor_type="HED", norm_img=True, criterion='L1', **kwargs):
super(EdgeLoss, self).__init__()
if edge_extractor_type == "HED":
pretrained_model_path = kwargs.get("hed_pretrained_model_path")
self.edge_extractor = HED(pretrained_model_path, norm_img)
else:
raise NotImplemented(f"do not support edge_extractor_type {edge_extractor_type}")
if criterion == 'L1':
self.criterion = nn.L1Loss()
elif criterion == "L2":
self.criterion = nn.MSELoss()
else:
raise NotImplementedError(f'{criterion} criterion has not been supported in this version.')
def forward(self, x, gt, gt_is_edge=True):
edge = self.edge_extractor(x)
if not gt_is_edge:
gt = self.edge_extractor(gt.detach())
loss = self.criterion(edge, gt)
return loss

155
loss/I2I/perceptual_loss.py Normal file
View File

@ -0,0 +1,155 @@
import torch
import torch.nn as nn
import torchvision.models.vgg as vgg
class PerceptualVGG(nn.Module):
"""VGG network used in calculating perceptual loss.
In this implementation, we allow users to choose whether use normalization
in the input feature and the type of vgg network. Note that the pretrained
path must fit the vgg type.
Args:
layer_name_list (list[str]): According to the index in this list,
forward function will return the corresponding features. This
list contains the name each layer in `vgg.feature`. An example
of this list is ['4', '10'].
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image.
Importantly, the input feature must in the range [0, 1].
Default: True.
"""
def __init__(self, layer_name_list, vgg_type='vgg19', use_input_norm=True):
super(PerceptualVGG, self).__init__()
self.layer_name_list = layer_name_list
self.use_input_norm = use_input_norm
# get vgg model and load pretrained vgg weight
# remove _vgg from attributes to avoid `find_unused_parameters` bug
_vgg = getattr(vgg, vgg_type)(pretrained=True)
num_layers = max(map(int, layer_name_list)) + 1
assert len(_vgg.features) >= num_layers
# only borrow layers that will be used from _vgg to avoid unused params
self.vgg_layers = _vgg.features[:num_layers]
if self.use_input_norm:
# the mean is for image with range [0, 1]
self.register_buffer(
'mean',
torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
# the std is for image with range [-1, 1]
self.register_buffer(
'std',
torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
for v in self.vgg_layers.parameters():
v.requies_grad = False
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
if self.use_input_norm:
x = (x - self.mean) / self.std
output = {}
for i, l in enumerate(self.vgg_layers):
x = l(x)
if str(i) in self.layer_name_list:
output[str(i)] = x.clone()
return output
class PerceptualLoss(nn.Module):
"""Perceptual loss with commonly used style loss.
Args:
layer_weights (dict): The weight for each layer of vgg feature.
Here is an example: {'4': 1., '9': 1., '18': 1.}, which means the
5th, 10th and 18th feature layer will be extracted with weight 1.0
in calculating losses.
vgg_type (str): The type of vgg network used as feature extractor.
Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image in vgg.
Default: True.
perceptual_loss (bool): If `perceptual_loss == True`, the perceptual
loss will be calculated.
Default: True.
style_loss (bool): If `style_loss == False`, the style loss will be calculated.
Default: False.
norm_img (bool): If True, the image will be normed to [0, 1]. Note that
this is different from the `use_input_norm` which norm the input in
in forward function of vgg according to the statistics of dataset.
Importantly, the input image must be in range [-1, 1].
"""
def __init__(self, layer_weights, vgg_type='vgg19', use_input_norm=True, perceptual_loss=True,
style_loss=False, norm_img=True, criterion='L1'):
super(PerceptualLoss, self).__init__()
self.norm_img = norm_img
self.perceptual_loss = perceptual_loss
self.style_loss = style_loss
self.layer_weights = layer_weights
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
use_input_norm=use_input_norm)
if criterion == 'L1':
self.criterion = torch.nn.L1Loss()
elif criterion == "L2":
self.criterion = torch.nn.MSELoss()
else:
raise NotImplementedError(f'{criterion} criterion has not been supported in this version.')
def forward(self, x, gt):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
if self.norm_img:
x = (x + 1.) * 0.5
gt = (gt + 1.) * 0.5
# extract vgg features
x_features = self.vgg(x)
gt_features = self.vgg(gt.detach())
# calculate preceptual loss
if self.perceptual_loss:
percep_loss = 0
for k in x_features.keys():
percep_loss += self.criterion(
x_features[k], gt_features[k]) * self.layer_weights[k]
else:
percep_loss = None
# calculate style loss
if self.style_loss:
style_loss = 0
for k in x_features.keys():
style_loss += self.criterion(
self._gram_mat(x_features[k]),
self._gram_mat(gt_features[k])) * self.layer_weights[k]
else:
style_loss = None
return percep_loss, style_loss
def _gram_mat(self, x):
"""Calculate Gram matrix.
Args:
x (torch.Tensor): Tensor with shape of (n, c, h, w).
Returns:
torch.Tensor: Gram matrix.
"""
(n, c, h, w) = x.size()
features = x.view(n, c, w * h)
features_t = features.transpose(1, 2)
gram = features.bmm(features_t) / (c * h * w)
return gram

View File

@ -87,10 +87,17 @@ class ContentEncoder(nn.Module):
class Decoder(nn.Module):
def __init__(self, out_channels, base_channels=64, num_down_sampling=2, padding_mode='reflect', norm_type="LN"):
def __init__(self, out_channels, base_channels=64, num_blocks=4, num_down_sampling=2, padding_mode='reflect',
norm_type="LN"):
super(Decoder, self).__init__()
norm_layer = select_norm_layer(norm_type)
use_bias = norm_type == "IN"
res_block_channels = (2 ** 2) * base_channels
self.resnet = nn.Sequential(
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
# up sampling
submodules = []
for i in range(num_down_sampling):
@ -109,6 +116,7 @@ class Decoder(nn.Module):
)
def forward(self, x):
x = self.resnet(x)
x = self.decoder(x)
x = self.end_conv(x)
return x
@ -142,12 +150,16 @@ class Fusion(nn.Module):
@MODEL.register_module("TAHG-Generator")
class Generator(nn.Module):
def __init__(self, style_in_channels, content_in_channels, out_channels, style_dim=512, num_blocks=8,
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, num_blocks=8,
base_channels=64, padding_mode="reflect"):
super(Generator, self).__init__()
self.num_blocks = num_blocks
self.style_encoder = VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
padding_mode=padding_mode, norm_type="NONE")
self.style_encoders = nn.ModuleDict({
"a": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
padding_mode=padding_mode, norm_type="NONE"),
"b": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
padding_mode=padding_mode, norm_type="NONE")
})
self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
padding_mode=padding_mode, norm_type="IN")
res_block_channels = 2 ** 2 * base_channels
@ -155,8 +167,8 @@ class Generator(nn.Module):
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
])
self.decoders = nn.ModuleDict({
"a": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode),
"b": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode)
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode),
"b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode)
})
self.fc = nn.Sequential(
@ -168,10 +180,45 @@ class Generator(nn.Module):
def forward(self, content_img, style_img, which_decoder: str = "a"):
x = self.content_encoder(content_img)
styles = self.fusion(self.fc(self.style_encoder(style_img)))
styles = self.fusion(self.fc(self.style_encoders[which_decoder](style_img)))
styles = torch.chunk(styles, self.num_blocks * 2, dim=1)
for i, ar in enumerate(self.adain_res):
ar.norm1.set_style(styles[2 * i])
ar.norm2.set_style(styles[2 * i + 1])
x = ar(x)
return self.decoders[which_decoder](x)
@MODEL.register_module("TAHG-Discriminator")
class Discriminator(nn.Module):
def __init__(self, in_channels=3, base_channels=64, num_down_sampling=2, num_blocks=3, norm_type="IN",
padding_mode="reflect"):
super(Discriminator, self).__init__()
norm_layer = select_norm_layer(norm_type)
use_bias = norm_type == "IN"
sequence = [nn.Sequential(
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
bias=use_bias),
norm_layer(num_features=base_channels),
nn.ReLU(inplace=True)
)]
# stacked intermediate layers,
# gradually increasing the number of filters
multiple_now = 1
for n in range(1, num_down_sampling + 1):
multiple_prev = multiple_now
multiple_now = min(2 ** n, 4)
sequence += [
nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=3,
padding=1, stride=2, bias=use_bias),
norm_layer(base_channels * multiple_now),
nn.LeakyReLU(0.2, inplace=True)
]
for _ in range(num_blocks):
sequence.append(ResidualBlock(base_channels * multiple_now, padding_mode, norm_type))
self.model = nn.Sequential(*sequence)
def forward(self, x):
return self.model(x)

View File

@ -1,3 +1,5 @@
from model.registry import MODEL
import model.GAN.residual_generator
import model.GAN.TAHG
import model.GAN.UGATIT
import model.fewshot

View File

@ -37,7 +37,6 @@ class LayerNorm2d(nn.Module):
def forward(self, x):
ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
x = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
print(x.size())
if self.affine:
return self.channel_gamma * x + self.channel_beta
return x

View File

@ -0,0 +1,32 @@
import os
import numpy as np
import dlib
from pathlib import Path
from PIL import Image
import sys
imagepaths = Path(sys.argv[1])
print(imagepaths)
phase = imagepaths.name
print(phase)
detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
if not os.path.isdir(phase):
os.makedirs(phase)
for ip in imagepaths.glob("*.jpg"):
img = np.asarray(Image.open(ip))
img.setflags(write=True)
dets = detector(img, 1)
if len(dets) > 0:
shape = predictor(img, dets[0])
points = np.empty([68, 2], dtype=int)
for b in range(68):
points[b, 0] = shape.part(b).x
points[b, 1] = shape.part(b).y
save_name = os.path.join(phase, ip.name[:-4] + '.txt')
np.savetxt(save_name, points, fmt='%d', delimiter=',')
else:
print(ip)

View File

@ -0,0 +1,45 @@
import numpy as np
from skimage import feature
from pathlib import Path
from torchvision.datasets.folder import is_image_file, default_loader
from torchvision.transforms import functional as F
from loss.I2I.edge_loss import HED
import torch
from PIL import Image
import fire
def canny_edge(img):
edge = feature.canny(np.array(img.convert("L")))
return edge
def generate(image_folder, edge_type, save_folder, device="cuda:0"):
assert edge_type in ["canny", "hed"]
image_folder = Path(image_folder)
save_folder = Path(save_folder)
if edge_type == "hed":
edge_extractor = HED("/root/network-bsds500.pytorch", norm_img=False).to(device)
elif edge_type == "canny":
edge_extractor = canny_edge
else:
raise NotImplemented
for p in image_folder.glob("*"):
if is_image_file(p.as_posix()):
rgb_img = default_loader(p)
print(p)
if edge_type == "hed":
with torch.no_grad():
img_tensor = F.to_tensor(rgb_img).to(device)
edge_tensor = edge_extractor(img_tensor)
edge = F.to_pil_image(edge_tensor.clamp(0, 1.0).squeeze().detach().cpu())
edge.save(save_folder / f"{p.stem}.{edge_type}.png")
elif edge_type == "canny":
edge = edge_extractor(rgb_img)
Image.fromarray(edge).save(save_folder / f"{p.stem}.{edge_type}.png")
else:
raise NotImplemented
if __name__ == '__main__':
fire.Fire(generate)

View File

@ -88,16 +88,17 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
engine.terminate()
def setup_tensorboard_handler(trainer: Engine, config, output_transform):
def setup_tensorboard_handler(trainer: Engine, config, output_transform, iter_per_epoch):
if config.interval.tensorboard is None:
return None
if idist.get_rank() == 0:
# Create a logger
tb_logger = TensorboardLogger(log_dir=config.output_dir)
basic_event = Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"),
event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
event_name=basic_event)
tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform),
event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
event_name=basic_event)
@trainer.on(Events.COMPLETED)
@idist.one_rank_only()