This commit is contained in:
budui 2020-10-23 16:14:37 +08:00
parent f7b7b78669
commit 0bec02bf6d
7 changed files with 287 additions and 26 deletions

View File

@ -0,0 +1,167 @@
name: huawei-GauGAN-3
engine: GauGAN
result_dir: ./result
max_pairs: 1000000
misc:
random_seed: 324
handler:
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 4 # log image `image` times per epoch
test:
random: True
images: 10
model:
generator:
_type: SPADEGenerator
_add_spectral_norm: True
in_channels: 3
out_channels: 3
num_blocks: 7
use_vae: False
num_z_dim: 256
# discriminator:
# _type: MultiScaleDiscriminator
# _add_spectral_norm: True
# num_scale: 2
# down_sample_method: "bilinear"
# discriminator_cfg:
# _type: PatchDiscriminator
# in_channels: 3
# base_channels: 64
# num_conv: 4
# need_intermediate_feature: True
discriminator:
_type: PatchDiscriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_conv: 4
need_intermediate_feature: True
loss:
gan:
loss_type: hinge
weight: 1.0
real_label_val: 1
fake_label_val: 0.0
perceptual:
layer_weights:
"1": 0.03125
"6": 0.0625
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L1'
style_loss: False
perceptual_loss: True
weight: 2
mgc:
weight: 5
fm:
weight: 5
edge:
weight: 0
hed_pretrained_model_path: ./network-bsds500.pytorch
optimizers:
generator:
_type: Adam
lr: 1e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 4e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
data:
train:
scheduler:
start_proportion: 0.5
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 1
shuffle: True
num_workers: 2
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/face2cartoon/all_face"
root_b: "/data/selfie2anime/trainB/"
random_pair: True
pipeline_a:
- Load
- RandomCrop:
size: [ 178, 178 ]
- Resize:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 286, 286 ]
- RandomCrop:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: video_dataset
dataloader:
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/face2cartoon/test/human"
root_b: "/data/face2cartoon/test/anime"
random_pair: True
pipeline_a:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
video_dataset:
_type: SingleFolderDataset
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
with_path: True
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]

86
engine/GauGAN.py Normal file
View File

@ -0,0 +1,86 @@
from itertools import chain
import torch
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from engine.util.container import GANImageBuffer, LossContainer
from engine.util.loss import gan_loss, feature_match_loss, perceptual_loss
from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
from model.weight_init import generation_init_weights
class GauGANEngineKernel(EngineKernel):
def __init__(self, config):
super().__init__(config)
self.gan_loss = gan_loss(config.loss.gan)
self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "exponential_decline"))
self.perceptual_loss = LossContainer(config.loss.perceptual.weight, perceptual_loss(config.loss.perceptual))
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
self.discriminators.keys()}
def build_models(self) -> (dict, dict):
generators = dict(
main=build_model(self.config.model.generator)
)
discriminators = dict(
b=build_model(self.config.model.discriminator)
)
self.logger.debug(discriminators["b"])
self.logger.debug(generators["main"])
for m in chain(generators.values(), discriminators.values()):
generation_init_weights(m)
return generators, discriminators
def setup_after_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(True)
def setup_before_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(False)
def forward(self, batch, inference=False) -> dict:
images = dict()
with torch.set_grad_enabled(not inference):
images["a2b"] = self.generators["main"](batch["a"])
return images
def criterion_generators(self, batch, generated) -> dict:
loss = dict()
prediction_fake = self.discriminators["b"](generated["a2b"])
loss["gan"] = self.config.loss.gan.weight * self.gan_loss(prediction_fake, True)
loss["mgc"] = self.mgc_loss(generated["a2b"], batch["a"])
loss["perceptual"] = self.perceptual_loss(generated["a2b"], batch["a"])
if self.fm_loss.weight > 0:
prediction_real = self.discriminators["b"](batch["b"])
loss["feature_match"] = self.fm_loss(prediction_fake, prediction_real)
return loss
def criterion_discriminators(self, batch, generated) -> dict:
loss = dict()
generated_image = self.image_buffers["b"].query(generated["a2b"].detach())
loss["b"] = (self.gan_loss(self.discriminators["b"](generated_image), False, is_discriminator=True) +
self.gan_loss(self.discriminators["b"](batch["b"]), True, is_discriminator=True)) / 2
return loss
def intermediate_images(self, batch, generated) -> dict:
"""
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
:param batch:
:param generated: dict of images
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
return dict(
a=[batch["a"].detach(), generated["a2b"].detach()],
)
def run(task, config, _):
kernel = GauGANEngineKernel(config)
run_kernel(task, config, kernel)

View File

@ -4,29 +4,20 @@ import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
def gan_loss(config):
gan_loss_cfg = OmegaConf.to_container(config)
gan_loss_cfg.pop("weight")
gl = GANLoss(**gan_loss_cfg).to(idist.device())
def gan_loss_fn(prediction, target_is_real: bool, is_discriminator=False):
if isinstance(prediction, torch.Tensor):
# origin
return gl(prediction, target_is_real, is_discriminator)
elif isinstance(prediction, list) and isinstance(prediction[0], list):
# for multi scale discriminator, e.g. MultiScaleDiscriminator
loss = 0
for p in prediction:
loss += gl(p[-1], target_is_real, is_discriminator)
return loss
elif isinstance(prediction, list) and isinstance(prediction[0], torch.Tensor):
# for discriminator set `need_intermediate_feature` true
return gl(prediction[-1], target_is_real, is_discriminator)
else:
raise NotImplementedError("not support discriminator output")
return gan_loss_fn
return GANLoss(**gan_loss_cfg).to(idist.device())
def perceptual_loss(config):
perceptual_loss_cfg = OmegaConf.to_container(config)
perceptual_loss_cfg.pop("weight")
return PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
def pixel_loss(level):

View File

@ -1,4 +1,5 @@
import torch.nn as nn
import torch
import torch.nn.functional as F
@ -10,7 +11,7 @@ class GANLoss(nn.Module):
self.fake_label_val = fake_label_val
self.loss_type = loss_type
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
def single_forward(self, prediction, target_is_real: bool, is_discriminator=False):
"""
gan loss forward
:param prediction: network prediction
@ -37,3 +38,20 @@ class GANLoss(nn.Module):
return loss
else:
raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
if isinstance(prediction, torch.Tensor):
# origin
return self.single_forward(prediction, target_is_real, is_discriminator)
elif isinstance(prediction, list):
# for multi scale discriminator, e.g. MultiScaleDiscriminator
loss = 0
for p in prediction:
loss += self.single_forward(p[-1], target_is_real, is_discriminator)
return loss
elif isinstance(prediction, tuple):
# for single discriminator set `need_intermediate_feature` true
return self.single_forward(prediction[-1], target_is_real, is_discriminator)
else:
raise NotImplementedError(f"not support discriminator output: {prediction}")

View File

@ -3,3 +3,4 @@ import model.base.normalization
import model.image_translation.UGATIT
import model.image_translation.CycleGAN
import model.image_translation.pix2pixHD
import model.image_translation.GauGAN

View File

@ -7,7 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
from model.base.module import ResidualBlock, Conv2dBlock, LinearBlock
from model import MODEL
class StyleEncoder(nn.Module):
def __init__(self, in_channels, style_dim, num_conv, end_size=(4, 4), base_channels=64,
@ -122,7 +122,7 @@ class ImprovedSPADEGenerator(nn.Module):
def forward(self, seg, style=None):
pass
@MODEL.register_module()
class SPADEGenerator(nn.Module):
def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
padding_mode='reflect', activation_type="LeakyReLU"):
@ -156,11 +156,8 @@ class SPADEGenerator(nn.Module):
)
))
self.sequence = nn.Sequential(*sequence)
self.output_converter = nn.Sequential(
ReverseConv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"),
nn.Tanh()
)
self.output_converter = Conv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE")
def forward(self, seg, z=None):
if self.use_vae:

View File

@ -65,7 +65,8 @@ def generation_init_weights(module, init_type='normal', init_gain=0.02):
elif classname.find('BatchNorm2d') != -1:
# BatchNorm Layer's weight is not a matrix;
# only normal distribution applies.
normal_init(m, 1.0, init_gain)
if m.weight is not None:
normal_init(m, 1.0, init_gain)
assert isinstance(module, nn.Module)
module.apply(init_func)