update a lot

Ray Wong 2020-09-07 21:38:10 +08:00
parent ab545843bf
commit 97ded53b30
6 changed files with 452 additions and 4 deletions

@@ -0,0 +1,146 @@
name: self2anime-TSIT
engine: TSIT
result_dir: ./result
max_pairs: 1500000
handler:
  clear_cuda_cache: True
  set_epoch_for_dist_sampler: True
checkpoint:
  epoch_interval: 1 # checkpoint once every `epoch_interval` epochs
  n_saved: 2
tensorboard:
  scalar: 100 # log scalars `scalar` times per epoch
  image: 2 # log images `image` times per epoch
misc:
  random_seed: 324
model:
  generator:
    _type: TSIT-Generator
    _bn_to_sync_bn: True
    style_in_channels: 3
    content_in_channels: 3
    num_blocks: 5
    input_layer_type: "conv7x7"
  discriminator:
    _type: MultiScaleDiscriminator
    num_scale: 2
    discriminator_cfg:
      _type: PatchDiscriminator
      in_channels: 3
      base_channels: 64
      use_spectral: True
      need_intermediate_feature: True
loss:
  gan:
    loss_type: hinge
    real_label_val: 1.0
    fake_label_val: 0.0
    weight: 1.0
  perceptual:
    layer_weights:
      "1": 0.03125
      "6": 0.0625
      "11": 0.125
      "20": 0.25
      "29": 1
    criterion: 'L1'
    style_loss: False
    perceptual_loss: True
    weight: 1
  style:
    layer_weights:
      "1": 0.03125
      "6": 0.0625
      "11": 0.125
      "20": 0.25
      "29": 1
    criterion: 'L2'
    style_loss: True
    perceptual_loss: False
    weight: 0
  fm:
    level: 1
    weight: 1
optimizers:
  generator:
    _type: Adam
    lr: 0.0001
    betas: [ 0, 0.9 ]
    weight_decay: 0.0001
  discriminator:
    _type: Adam
    lr: 4e-4
    betas: [ 0, 0.9 ]
    weight_decay: 0.0001
data:
  train:
    scheduler:
      start_proportion: 0.5
      target_lr: 0
    buffer_size: 50
    dataloader:
      batch_size: 1
      shuffle: True
      num_workers: 2
      pin_memory: True
      drop_last: True
    dataset:
      _type: GenerationUnpairedDatasetWithEdge
      root_a: "/data/i2i/VoxCeleb2Anime/trainA"
      root_b: "/data/i2i/VoxCeleb2Anime/trainB"
      edges_path: "/data/i2i/VoxCeleb2Anime/edges"
      landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
      edge_type: "landmark_hed"
      size: [ 128, 128 ]
      random_pair: True
      pipeline:
        - Load
        - Resize:
            size: [ 128, 128 ]
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
  test:
    dataloader:
      batch_size: 8
      shuffle: False
      num_workers: 1
      pin_memory: False
      drop_last: False
    dataset:
      _type: GenerationUnpairedDatasetWithEdge
      root_a: "/data/i2i/VoxCeleb2Anime/testA"
      root_b: "/data/i2i/VoxCeleb2Anime/testB"
      edges_path: "/data/i2i/VoxCeleb2Anime/edges"
      landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
      edge_type: "landmark_hed"
      random_pair: False
      size: [ 128, 128 ]
      pipeline:
        - Load
        - Resize:
            size: [ 128, 128 ]
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
  video_dataset:
    _type: SingleFolderDataset
    root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
    with_path: True
    pipeline:
      - Load
      - Resize:
          size: [ 256, 256 ]
      - ToTensor
      - Normalize:
          mean: [ 0.5, 0.5, 0.5 ]
          std: [ 0.5, 0.5, 0.5 ]
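
Note: engine/TSIT.py below reads this file through OmegaConf, so a minimal loading sketch looks like the following (the path is illustrative; the commit does not show where the YAML is stored).

from omegaconf import OmegaConf

config = OmegaConf.load("config/self2anime-TSIT.yaml")  # hypothetical path
print(config.model.generator["_type"])       # "TSIT-Generator"
print(config.loss.perceptual.layer_weights)  # VGG layer index -> weight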

engine/TSIT.py Normal file
@@ -0,0 +1,106 @@
from itertools import chain

import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf

from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from model.weight_init import generation_init_weights


class TSITEngineKernel(EngineKernel):
    def __init__(self, config):
        super().__init__(config)
        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
        perceptual_loss_cfg.pop("weight")
        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()

    def build_models(self) -> (dict, dict):
        generators = dict(
            main=build_model(self.config.model.generator)
        )
        discriminators = dict(
            b=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["b"])
        self.logger.debug(generators["main"])
        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)
        return generators, discriminators

    def setup_after_g(self):
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        with torch.set_grad_enabled(not inference):
            fake = dict(
                b=self.generators["main"](content_img=batch["a"], style_img=batch["b"])
            )
        return fake

    def criterion_generators(self, batch, generated) -> dict:
        loss = dict()
        loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
        loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
        for phase in "b":  # only domain "b" has a discriminator
            pred_fake = self.discriminators[phase](generated[phase])
            loss[f"gan_{phase}"] = 0
            for sub_pred_fake in pred_fake:
                # last output is actual prediction
                loss[f"gan_{phase}"] += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)
            if self.config.loss.fm.weight > 0 and phase == "b":
                pred_real = self.discriminators[phase](batch[phase])
                loss_fm = 0
                num_scale_discriminator = len(pred_fake)
                for i in range(num_scale_discriminator):
                    # last output is the final prediction, so we exclude it
                    num_intermediate_outputs = len(pred_fake[i]) - 1
                    for j in range(num_intermediate_outputs):
                        loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
                loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        loss = dict()
        for phase in self.discriminators.keys():
            pred_real = self.discriminators[phase](batch[phase])
            pred_fake = self.discriminators[phase](generated[phase].detach())
            loss[f"gan_{phase}"] = 0
            for i in range(len(pred_fake)):
                loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
                                         + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        return dict(
            b=[batch["a"].detach(), batch["b"].detach(), generated["b"].detach()]
        )


def run(task, config, _):
    kernel = TSITEngineKernel(config)
    run_kernel(task, config, kernel)
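
The feature-matching term in criterion_generators relies on an output contract: with need_intermediate_feature: True, each scale of the MultiScaleDiscriminator yields a list whose last element is the real/fake prediction map and whose earlier elements are intermediate feature maps. A self-contained sketch of that contract with stand-in tensors (shapes are illustrative):

import torch
import torch.nn as nn

# Two scales, each returning one intermediate feature map plus a logit map.
pred_fake = [[torch.randn(1, 64, 32, 32), torch.randn(1, 1, 30, 30)] for _ in range(2)]
pred_real = [[torch.randn(1, 64, 32, 32), torch.randn(1, 1, 30, 30)] for _ in range(2)]

fm_loss = nn.L1Loss()
loss_fm = 0
for scale_fake, scale_real in zip(pred_fake, pred_real):
    # skip the last element of each scale: it is the prediction, not a feature
    for feat_fake, feat_real in zip(scale_fake[:-1], scale_real[:-1]):
        loss_fm += fm_loss(feat_fake, feat_real.detach()) / len(pred_fake)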

model/GAN/TSIT.py Normal file
@@ -0,0 +1,192 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from model import MODEL
from model.normalization import AdaptiveInstanceNorm2d
from model.normalization import select_norm_layer


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, padding_mode='zeros', norm_type="IN", use_bias=None,
                 use_spectral=True):
        super().__init__()
        self.padding_mode = padding_mode
        self.use_bias = use_bias
        self.use_spectral = use_spectral
        if use_bias is None:
            # Only for IN, use bias since it does not have affine parameters.
            self.use_bias = norm_type == "IN"
        norm_layer = select_norm_layer(norm_type)
        self.main = nn.Sequential(
            self.conv_block(in_channels, in_channels),
            norm_layer(in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            self.conv_block(in_channels, out_channels),
            norm_layer(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.skip = nn.Sequential(
            self.conv_block(in_channels, out_channels, padding=0, kernel_size=1),
            norm_layer(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def conv_block(self, in_channels, out_channels, kernel_size=3, padding=1):
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding,
                         padding_mode=self.padding_mode, bias=self.use_bias)
        if self.use_spectral:
            return nn.utils.spectral_norm(conv)
        else:
            return conv

    def forward(self, x):
        return self.main(x) + self.skip(x)


class Interpolation(nn.Module):
    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
        super(Interpolation, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners,
                             recompute_scale_factor=False)

    def __repr__(self):
        return f"Interpolation(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"


class FADE(nn.Module):
    def __init__(self, use_spectral, features_channels, in_channels, affine=False, track_running_stats=True):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
        self.alpha_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
                                     padding_mode="zeros")
        self.beta_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
                                    padding_mode="zeros")

    def forward(self, x, feature):
        alpha = self.alpha_conv(feature)
        beta = self.beta_conv(feature)
        x = self.bn(x)
        return alpha * x + beta


class FADEResBlock(nn.Module):
    def __init__(self, use_spectral, features_channels, in_channels, out_channels):
        super().__init__()
        self.main = nn.Sequential(
            FADE(use_spectral, features_channels, in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            conv_block(use_spectral, in_channels, in_channels, kernel_size=3, padding=1),
            FADE(use_spectral, features_channels, in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            conv_block(use_spectral, in_channels, out_channels, kernel_size=3, padding=1),
        )
        self.skip = nn.Sequential(
            FADE(use_spectral, features_channels, in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            conv_block(use_spectral, in_channels, out_channels, kernel_size=1, padding=0),
        )
        self.up_sample = Interpolation(2, mode="nearest")

    @staticmethod
    def forward_with_fade(module, x, feature):
        for layer in module:
            if layer.__class__.__name__ == "FADE":
                x = layer(x, feature)
            else:
                x = layer(x)
        return x

    def forward(self, x, feature):
        # residual sum of the main branch and the 1x1 skip branch
        out = self.forward_with_fade(self.main, x, feature) + self.forward_with_fade(self.skip, x, feature)
        return self.up_sample(out)


def conv_block(use_spectral, in_channels, out_channels, **kwargs):
    conv = nn.Conv2d(in_channels, out_channels, **kwargs)
    return nn.utils.spectral_norm(conv) if use_spectral else conv


@MODEL.register_module("TSIT-Generator")
class TSITGenerator(nn.Module):
    def __init__(self, num_blocks=7, base_channels=64, content_in_channels=3, style_in_channels=3,
                 out_channels=3, use_spectral=True, input_layer_type="conv1x1"):
        super().__init__()
        self.num_blocks = num_blocks
        self.base_channels = base_channels
        self.use_spectral = use_spectral
        self.content_input_layer = self.build_input_layer(content_in_channels, base_channels, input_layer_type)
        self.style_input_layer = self.build_input_layer(style_in_channels, base_channels, input_layer_type)
        self.content_stream = self.build_stream()
        self.style_stream = self.build_stream()
        self.generator = self.build_generator()
        self.end_conv = nn.Sequential(
            conv_block(use_spectral, base_channels, out_channels, kernel_size=7, padding=3, padding_mode="zeros"),
            nn.Tanh()
        )

    def build_generator(self):
        stream_sequence = []
        multiple_now = min(2 ** self.num_blocks, 2 ** 4)
        for i in range(1, self.num_blocks + 1):
            m = self.num_blocks - i
            multiple_prev = multiple_now
            multiple_now = min(2 ** m, 2 ** 4)
            stream_sequence.append(nn.Sequential(
                AdaptiveInstanceNorm2d(multiple_prev * self.base_channels),
                FADEResBlock(self.use_spectral, multiple_prev * self.base_channels, multiple_prev * self.base_channels,
                             multiple_now * self.base_channels)
            ))
        return nn.ModuleList(stream_sequence)

    def build_input_layer(self, in_channels, out_channels, input_layer_type="conv7x7"):
        if input_layer_type == "conv7x7":
            return nn.Sequential(
                conv_block(self.use_spectral, in_channels, out_channels, kernel_size=7, stride=1,
                           padding_mode="zeros", padding=3, bias=True),
                select_norm_layer("IN")(out_channels),
                nn.ReLU(inplace=True)
            )
        elif input_layer_type == "conv1x1":
            return conv_block(self.use_spectral, in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            raise NotImplementedError(input_layer_type)

    def build_stream(self):
        multiple_now = 1
        stream_sequence = []
        for i in range(1, self.num_blocks + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 4)
            stream_sequence.append(nn.Sequential(
                Interpolation(scale_factor=0.5, mode="nearest"),
                ResBlock(multiple_prev * self.base_channels, multiple_now * self.base_channels,
                         use_spectral=self.use_spectral)
            ))
        return nn.ModuleList(stream_sequence)

    def forward(self, content_img, style_img):
        c = self.content_input_layer(content_img)
        s = self.style_input_layer(style_img)
        content_features = []
        style_features = []
        for i in range(self.num_blocks):
            s = self.style_stream[i](s)
            c = self.content_stream[i](c)
            content_features.append(c)
            style_features.append(s)
        z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
        for i in range(self.num_blocks):
            m = -i - 1
            layer = self.generator[i]
            layer[0].set_style(torch.cat(torch.std_mean(style_features[m], dim=[2, 3]), dim=1))
            z = layer[0](z)
            z = layer[1](z, content_features[m])
        return self.end_conv(z)
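
A quick shape check for the generator, assuming AdaptiveInstanceNorm2d and select_norm_layer behave as the call sites above expect; the sizes mirror the config (num_blocks: 5, 128x128 inputs):

import torch

# Sketch only: requires the model package from this commit on PYTHONPATH.
net = TSITGenerator(num_blocks=5, input_layer_type="conv7x7")
content = torch.randn(1, 3, 128, 128)  # domain "a" image
style = torch.randn(1, 3, 128, 128)    # domain "b" image
out = net(content, style)
print(out.shape)  # torch.Size([1, 3, 128, 128]): five 2x downsamples in the streams, five 2x upsamples in the generator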

@@ -4,3 +4,4 @@ import model.GAN.TAFG
 import model.GAN.UGATIT
 import model.GAN.wrapper
 import model.GAN.base
+import model.GAN.TSIT
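
This import exists for its side effect: loading model.GAN.TSIT runs the @MODEL.register_module("TSIT-Generator") decorator, which is what lets the config's _type: TSIT-Generator resolve to a class. A sketch of the resulting lookup (the accessor name is an assumption; the commit only shows build_model consuming the registry):

import model.GAN  # registers TSIT-Generator, UGATIT, ... as a side effect
from model import MODEL

generator_cls = MODEL.get("TSIT-Generator")  # hypothetical accessor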

@@ -1,6 +1,7 @@
-import torch.nn as nn
 import functools
+import torch
+import torch.nn as nn
 def select_norm_layer(norm_type):
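
The body of select_norm_layer is outside this hunk; model/GAN/TSIT.py calls it as select_norm_layer("IN")(channels), so a plausible shape, assumed rather than confirmed, is:

import functools
import torch.nn as nn

def select_norm_layer(norm_type):
    # Assumed sketch consistent with the call sites above; not the actual body.
    if norm_type == "IN":
        # no affine parameters, matching the use_bias comment in ResBlock
        return functools.partial(nn.InstanceNorm2d, affine=False)
    if norm_type == "BN":
        return functools.partial(nn.BatchNorm2d, affine=True)
    raise NotImplementedError(norm_type)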

run.sh
@@ -5,16 +5,18 @@ TASK=$2
 GPUS=$3
 MORE_ARG=${*:4}
+RANDOM_MASTER=$(shuf -i 2000-65000 -n 1)
 _command="print(len('${GPUS}'.split(',')))"
 GPU_COUNT=$(python3 -c "${_command}")
 echo "GPU_COUNT:${GPU_COUNT}"
 echo CUDA_VISIBLE_DEVICES=$GPUS \
 PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
 main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed "$MORE_ARG"
 CUDA_VISIBLE_DEVICES=$GPUS \
 PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
+--master_port=${RANDOM_MASTER} \
 main.py "$TASK" "$CONFIG" $MORE_ARG --backup_config --setup_output_dir --setup_random_seed