23333

commit 0bec02bf6d, parent f7b7b78669
configs/synthesizers/GauGAN.yml (new file, 167 lines)
@@ -0,0 +1,167 @@
name: huawei-GauGAN-3
engine: GauGAN
result_dir: ./result
max_pairs: 1000000

misc:
  random_seed: 324

handler:
  clear_cuda_cache: True
  set_epoch_for_dist_sampler: True
  checkpoint:
    epoch_interval: 1  # checkpoint once per `epoch_interval` epoch
    n_saved: 2
  tensorboard:
    scalar: 100  # log scalar `scalar` times per epoch
    image: 4  # log image `image` times per epoch
  test:
    random: True
    images: 10

model:
  generator:
    _type: SPADEGenerator
    _add_spectral_norm: True
    in_channels: 3
    out_channels: 3
    num_blocks: 7
    use_vae: False
    num_z_dim: 256
#  discriminator:
#    _type: MultiScaleDiscriminator
#    _add_spectral_norm: True
#    num_scale: 2
#    down_sample_method: "bilinear"
#    discriminator_cfg:
#      _type: PatchDiscriminator
#      in_channels: 3
#      base_channels: 64
#      num_conv: 4
#      need_intermediate_feature: True
  discriminator:
    _type: PatchDiscriminator
    _add_spectral_norm: True
    in_channels: 3
    base_channels: 64
    num_conv: 4
    need_intermediate_feature: True


loss:
  gan:
    loss_type: hinge
    weight: 1.0
    real_label_val: 1
    fake_label_val: 0.0
  perceptual:
    layer_weights:
      "1": 0.03125
      "6": 0.0625
      "11": 0.125
      "20": 0.25
      "29": 1
    criterion: 'L1'
    style_loss: False
    perceptual_loss: True
    weight: 2
  mgc:
    weight: 5
  fm:
    weight: 5
  edge:
    weight: 0
    hed_pretrained_model_path: ./network-bsds500.pytorch

optimizers:
  generator:
    _type: Adam
    lr: 1e-4
    betas: [ 0, 0.9 ]
    weight_decay: 0.0001
  discriminator:
    _type: Adam
    lr: 4e-4
    betas: [ 0, 0.9 ]
    weight_decay: 0.0001

data:
  train:
    scheduler:
      start_proportion: 0.5
      target_lr: 0
    buffer_size: 50
    dataloader:
      batch_size: 1
      shuffle: True
      num_workers: 2
      pin_memory: True
      drop_last: True
    dataset:
      _type: GenerationUnpairedDataset
      root_a: "/data/face2cartoon/all_face"
      root_b: "/data/selfie2anime/trainB/"
      random_pair: True
      pipeline_a:
        - Load
        - RandomCrop:
            size: [ 178, 178 ]
        - Resize:
            size: [ 256, 256 ]
        - RandomHorizontalFlip
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
      pipeline_b:
        - Load
        - Resize:
            size: [ 286, 286 ]
        - RandomCrop:
            size: [ 256, 256 ]
        - RandomHorizontalFlip
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
  test:
    which: video_dataset
    dataloader:
      batch_size: 1
      shuffle: False
      num_workers: 1
      pin_memory: False
      drop_last: False
    dataset:
      _type: GenerationUnpairedDataset
      root_a: "/data/face2cartoon/test/human"
      root_b: "/data/face2cartoon/test/anime"
      random_pair: True
      pipeline_a:
        - Load
        - Resize:
            size: [ 256, 256 ]
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
      pipeline_b:
        - Load
        - Resize:
            size: [ 256, 256 ]
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
    video_dataset:
      _type: SingleFolderDataset
      root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
      with_path: True
      pipeline:
        - Load
        - Resize:
            size: [ 256, 256 ]
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
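For orientation, here is a minimal sketch of how a config like this can be consumed. It assumes the repo's OmegaConf-based config handling (OmegaConf is imported in the loss utilities below); the actual entry point may wrap this differently.

    # Sketch only: load the config and inspect the registry type keys.
    from omegaconf import OmegaConf

    config = OmegaConf.load("configs/synthesizers/GauGAN.yml")
    print(config.name)                      # huawei-GauGAN-3
    print(config.model.generator["_type"])  # SPADEGenerator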
engine/GauGAN.py (new file, 86 lines)
@@ -0,0 +1,86 @@
from itertools import chain

import torch

from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from engine.util.container import GANImageBuffer, LossContainer
from engine.util.loss import gan_loss, feature_match_loss, perceptual_loss
from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
from model.weight_init import generation_init_weights


class GauGANEngineKernel(EngineKernel):
    def __init__(self, config):
        super().__init__(config)

        self.gan_loss = gan_loss(config.loss.gan)
        self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
        self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "exponential_decline"))
        self.perceptual_loss = LossContainer(config.loss.perceptual.weight, perceptual_loss(config.loss.perceptual))

        self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
                              self.discriminators.keys()}

    def build_models(self) -> (dict, dict):
        generators = dict(
            main=build_model(self.config.model.generator)
        )
        discriminators = dict(
            b=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["b"])
        self.logger.debug(generators["main"])

        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)

        return generators, discriminators

    def setup_after_g(self):
        # re-enable discriminator gradients for the discriminator update
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        # freeze discriminators while the generator is being optimized
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        images = dict()
        with torch.set_grad_enabled(not inference):
            images["a2b"] = self.generators["main"](batch["a"])
        return images

    def criterion_generators(self, batch, generated) -> dict:
        loss = dict()
        prediction_fake = self.discriminators["b"](generated["a2b"])
        loss["gan"] = self.config.loss.gan.weight * self.gan_loss(prediction_fake, True)
        loss["mgc"] = self.mgc_loss(generated["a2b"], batch["a"])
        loss["perceptual"] = self.perceptual_loss(generated["a2b"], batch["a"])
        if self.fm_loss.weight > 0:
            prediction_real = self.discriminators["b"](batch["b"])
            loss["feature_match"] = self.fm_loss(prediction_fake, prediction_real)
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        loss = dict()
        generated_image = self.image_buffers["b"].query(generated["a2b"].detach())
        loss["b"] = (self.gan_loss(self.discriminators["b"](generated_image), False, is_discriminator=True) +
                     self.gan_loss(self.discriminators["b"](batch["b"]), True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        Returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}

        :param batch: input batch of images
        :param generated: dict of generated images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        return dict(
            a=[batch["a"].detach(), generated["a2b"].detach()],
        )


def run(task, config, _):
    kernel = GauGANEngineKernel(config)
    run_kernel(task, config, kernel)
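The setup_before_g/setup_after_g hooks imply the usual alternating GAN update: discriminators are frozen while the generator steps, then re-enabled for their own step, with criterion_discriminators drawing fakes from a GANImageBuffer to stabilize training. A sketch of that loop in plain PyTorch — the real loop lives in run_kernel, and opt_g/opt_d here are hypothetical optimizers:

    # Sketch of the alternating update the hooks imply; not the repo's actual loop.
    def train_step(kernel, batch, opt_g, opt_d):
        kernel.setup_before_g()                # freeze discriminators
        generated = kernel.forward(batch)
        g_loss = sum(kernel.criterion_generators(batch, generated).values())
        opt_g.zero_grad()
        g_loss.backward()
        opt_g.step()

        kernel.setup_after_g()                 # unfreeze discriminators
        d_loss = sum(kernel.criterion_discriminators(batch, generated).values())
        opt_d.zero_grad()
        d_loss.backward()
        opt_d.step()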
engine/util/loss.py
@@ -4,29 +4,20 @@ import torch.nn as nn
 import torch.nn.functional as F
 from omegaconf import OmegaConf
 
+from loss.I2I.perceptual_loss import PerceptualLoss
 from loss.gan import GANLoss
 
 
 def gan_loss(config):
     gan_loss_cfg = OmegaConf.to_container(config)
     gan_loss_cfg.pop("weight")
-    gl = GANLoss(**gan_loss_cfg).to(idist.device())
-
-    def gan_loss_fn(prediction, target_is_real: bool, is_discriminator=False):
-        if isinstance(prediction, torch.Tensor):
-            # origin
-            return gl(prediction, target_is_real, is_discriminator)
-        elif isinstance(prediction, list) and isinstance(prediction[0], list):
-            # for multi scale discriminator, e.g. MultiScaleDiscriminator
-            loss = 0
-            for p in prediction:
-                loss += gl(p[-1], target_is_real, is_discriminator)
-            return loss
-        elif isinstance(prediction, list) and isinstance(prediction[0], torch.Tensor):
-            # for discriminator set `need_intermediate_feature` true
-            return gl(prediction[-1], target_is_real, is_discriminator)
-        else:
-            raise NotImplementedError("not support discriminator output")
-
-    return gan_loss_fn
+    return GANLoss(**gan_loss_cfg).to(idist.device())
+
+
+def perceptual_loss(config):
+    perceptual_loss_cfg = OmegaConf.to_container(config)
+    perceptual_loss_cfg.pop("weight")
+    return PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
 
 
 def pixel_loss(level):
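For reference, the standard hinge formulation that `loss_type: hinge` in the config selects — a sketch of the textbook math, not necessarily the exact implementation in loss/gan.py below:

    import torch.nn.functional as F

    def hinge_d_loss(pred, target_is_real: bool):
        # discriminator side: E[max(0, 1 - D(x))] for real, E[max(0, 1 + D(G(z)))] for fake
        if target_is_real:
            return F.relu(1.0 - pred).mean()
        return F.relu(1.0 + pred).mean()

    def hinge_g_loss(pred):
        # generator side: -E[D(G(z))]
        return -pred.mean()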
loss/gan.py (20 changed lines)
@@ -1,4 +1,5 @@
 import torch.nn as nn
+import torch
 import torch.nn.functional as F
 
 
@@ -10,7 +11,7 @@ class GANLoss(nn.Module):
         self.fake_label_val = fake_label_val
         self.loss_type = loss_type
 
-    def forward(self, prediction, target_is_real: bool, is_discriminator=False):
+    def single_forward(self, prediction, target_is_real: bool, is_discriminator=False):
         """
         gan loss forward
         :param prediction: network prediction
@@ -37,3 +38,20 @@ class GANLoss(nn.Module):
             return loss
         else:
             raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')
+
+    def forward(self, prediction, target_is_real: bool, is_discriminator=False):
+        if isinstance(prediction, torch.Tensor):
+            # plain single-discriminator output (the original code path)
+            return self.single_forward(prediction, target_is_real, is_discriminator)
+        elif isinstance(prediction, list):
+            # multi-scale discriminator, e.g. MultiScaleDiscriminator
+            loss = 0
+            for p in prediction:
+                loss += self.single_forward(p[-1], target_is_real, is_discriminator)
+            return loss
+        elif isinstance(prediction, tuple):
+            # single discriminator with `need_intermediate_feature` set to True
+            return self.single_forward(prediction[-1], target_is_real, is_discriminator)
+        else:
+            raise NotImplementedError(f"unsupported discriminator output: {prediction}")
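The three prediction layouts the new forward dispatches on, illustrated below; the tensor shapes are made up for the sketch, and only the container structure matters:

    import torch
    from loss.gan import GANLoss

    loss_fn = GANLoss(loss_type="hinge", real_label_val=1.0, fake_label_val=0.0)

    plain = torch.randn(1, 1, 30, 30)              # bare prediction map
    with_features = (torch.randn(1, 64, 64, 64),   # intermediate features first,
                     torch.randn(1, 1, 30, 30))    # final prediction last
    multi_scale = [[torch.randn(1, 1, 30, 30)],    # one inner list per scale,
                   [torch.randn(1, 1, 15, 15)]]    # prediction last in each

    for pred in (plain, with_features, multi_scale):
        loss_fn(pred, target_is_real=True)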
@@ -3,3 +3,4 @@ import model.base.normalization
 import model.image_translation.UGATIT
 import model.image_translation.CycleGAN
 import model.image_translation.pix2pixHD
+import model.image_translation.GauGAN
@@ -7,7 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from model.base.module import ResidualBlock, Conv2dBlock, LinearBlock
+from model import MODEL
 
 class StyleEncoder(nn.Module):
     def __init__(self, in_channels, style_dim, num_conv, end_size=(4, 4), base_channels=64,
@@ -122,7 +122,7 @@ class ImprovedSPADEGenerator(nn.Module):
     def forward(self, seg, style=None):
         pass
 
+@MODEL.register_module()
 class SPADEGenerator(nn.Module):
     def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
                  padding_mode='reflect', activation_type="LeakyReLU"):
@@ -156,11 +156,8 @@ class SPADEGenerator(nn.Module):
             )
         ))
         self.sequence = nn.Sequential(*sequence)
-        self.output_converter = nn.Sequential(
-            ReverseConv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
-                               padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"),
-            nn.Tanh()
-        )
+        self.output_converter = Conv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
+                                            padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE")
 
     def forward(self, seg, z=None):
         if self.use_vae:
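With @MODEL.register_module() in place, the config's `_type: SPADEGenerator` resolves to this class. Built directly with the config's values — a sketch only, since build_model in engine/util/build.py presumably goes through the MODEL registry and also applies `_add_spectral_norm`:

    import torch
    from model.image_translation.GauGAN import SPADEGenerator

    g = SPADEGenerator(in_channels=3, out_channels=3, num_blocks=7,
                       use_vae=False, num_z_dim=256)
    # 1x3x256x256 matches the config's Resize size; the Tanh output_converter
    # keeps generated values in [-1, 1], matching the Normalize pipeline.
    fake = g(torch.randn(1, 3, 256, 256))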
model/weight_init.py
@@ -65,6 +65,7 @@ def generation_init_weights(module, init_type='normal', init_gain=0.02):
         elif classname.find('BatchNorm2d') != -1:
             # BatchNorm Layer's weight is not a matrix;
             # only normal distribution applies.
-            normal_init(m, 1.0, init_gain)
+            if m.weight is not None:
+                normal_init(m, 1.0, init_gain)
 
     assert isinstance(module, nn.Module)
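The new guard matters because BatchNorm2d created with affine=False has no learnable parameters, so normal_init would fail on it. A quick check:

    import torch.nn as nn

    bn = nn.BatchNorm2d(8, affine=False)
    print(bn.weight)  # None -> the `if m.weight is not None` check skips these layers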