v2

parent 0019d4034c
commit 376f5caeb7
@@ -1,6 +1,6 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="PublishConfigData" autoUpload="Always" serverName="21d" remoteFilesAllowedToDisappearOnAutoupload="false">
+  <component name="PublishConfigData" autoUpload="Always" serverName="14d" remoteFilesAllowedToDisappearOnAutoupload="false">
     <serverData>
       <paths name="14d">
         <serverdata>
@@ -1,4 +1,4 @@
-name: selfie2anime-cycleGAN
+name: huawei-cycylegan-7
 engine: CycleGAN
 result_dir: ./result
 max_pairs: 1000000
@@ -27,18 +27,33 @@ model:
     out_channels: 3
     base_channels: 64
     num_blocks: 9
+    use_transpose_conv: False
+    pre_activation: True
+#  discriminator:
+#    _type: MultiScaleDiscriminator
+#    _add_spectral_norm: True
+#    num_scale: 2
+#    down_sample_method: "bilinear"
+#    discriminator_cfg:
+#      _type: PatchDiscriminator
+#      in_channels: 3
+#      base_channels: 64
+#      num_conv: 4
+#      need_intermediate_feature: True
   discriminator:
     _type: PatchDiscriminator
     _add_spectral_norm: True
     in_channels: 3
     base_channels: 64
     num_conv: 4
+    need_intermediate_feature: False
+
 
 loss:
   gan:
-    loss_type: lsgan
+    loss_type: hinge
     weight: 1.0
-    real_label_val: 1.0
+    real_label_val: 1
     fake_label_val: 0.0
   cycle:
     level: 1
@@ -47,17 +62,22 @@ loss:
     level: 1
     weight: 10.0
   mgc:
-    weight: 5
+    weight: 1
+  fm:
+    weight: 0
+  edge:
+    weight: 0
+    hed_pretrained_model_path: ./network-bsds500.pytorch
 
 optimizers:
   generator:
     _type: Adam
-    lr: 0.0001
+    lr: 1e-4
     betas: [ 0.5, 0.999 ]
     weight_decay: 0.0001
   discriminator:
     _type: Adam
-    lr: 1e-4
+    lr: 4e-4
     betas: [ 0.5, 0.999 ]
     weight_decay: 0.0001
 
@@ -75,10 +95,21 @@ data:
     drop_last: True
     dataset:
       _type: GenerationUnpairedDataset
-      root_a: "/data/i2i/selfie2anime/trainA"
-      root_b: "/data/i2i/selfie2anime/trainB"
+      root_a: "/data/face2cartoon/all_face"
+      root_b: "/data/selfie2anime/trainB/"
       random_pair: True
-      pipeline:
+      pipeline_a:
+        - Load
+        - RandomCrop:
+            size: [ 178, 178 ]
+        - Resize:
+            size: [ 256, 256 ]
+        - RandomHorizontalFlip
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+      pipeline_b:
         - Load
         - Resize:
             size: [ 286, 286 ]
@@ -99,10 +130,18 @@ data:
     drop_last: False
     dataset:
       _type: GenerationUnpairedDataset
-      root_a: "/data/i2i/selfie2anime/testA"
-      root_b: "/data/i2i/selfie2anime/testB"
-      random_pair: False
-      pipeline:
+      root_a: "/data/face2cartoon/test/human"
+      root_b: "/data/face2cartoon/test/anime"
+      random_pair: True
+      pipeline_a:
+        - Load
+        - Resize:
+            size: [ 256, 256 ]
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+      pipeline_b:
         - Load
         - Resize:
             size: [ 256, 256 ]
@@ -38,9 +38,9 @@ class SingleFolderDataset(Dataset):
 
 @DATASET.register_module()
 class GenerationUnpairedDataset(Dataset):
-    def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False):
-        self.A = SingleFolderDataset(root_a, pipeline, with_path)
-        self.B = SingleFolderDataset(root_b, pipeline, with_path)
+    def __init__(self, root_a, root_b, random_pair, pipeline_a, pipeline_b, with_path=False):
+        self.A = SingleFolderDataset(root_a, pipeline_a, with_path)
+        self.B = SingleFolderDataset(root_b, pipeline_b, with_path)
         self.with_path = with_path
         self.random_pair = random_pair
 
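With pipeline split into pipeline_a and pipeline_b, each domain now gets its own preprocessing (the crop-then-resize chain for the face domain above, a plain resize for the anime domain). A rough usage sketch of the new signature; the real pipelines are whatever the repo's pipeline builder produces from the YAML lists, torchvision transforms only stand in for them here, and the import path of the dataset class is assumed:

    import torchvision.transforms as T
    from data.dataset import GenerationUnpairedDataset  # assumed module path for the class patched above

    # Stand-ins for the YAML pipeline_a / pipeline_b lists.
    pipeline_a = T.Compose([T.RandomCrop(178), T.Resize(256), T.RandomHorizontalFlip(),
                            T.ToTensor(), T.Normalize([0.5] * 3, [0.5] * 3)])
    pipeline_b = T.Compose([T.Resize(286), T.ToTensor(), T.Normalize([0.5] * 3, [0.5] * 3)])

    train_set = GenerationUnpairedDataset(
        root_a="/data/face2cartoon/all_face",
        root_b="/data/selfie2anime/trainB/",
        random_pair=True,
        pipeline_a=pipeline_a,
        pipeline_b=pipeline_b,
    )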
@@ -1,11 +1,13 @@
 from itertools import chain
 
+import ignite.distributed as idist
 import torch
 
 from engine.base.i2i import EngineKernel, run_kernel
 from engine.util.build import build_model
 from engine.util.container import GANImageBuffer, LossContainer
-from engine.util.loss import pixel_loss, gan_loss
+from engine.util.loss import pixel_loss, gan_loss, feature_match_loss
+from loss.I2I.edge_loss import EdgeLoss
 from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
 from model.weight_init import generation_init_weights
 
@@ -17,7 +19,10 @@ class CycleGANEngineKernel(EngineKernel):
         self.gan_loss = gan_loss(config.loss.gan)
         self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
         self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
-        self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss())
+        self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
+        self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "same"))
+        self.edge_loss = LossContainer(config.loss.edge.weight, EdgeLoss(
+            "HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(idist.device()))
         self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
                               self.discriminators.keys()}
 
@@ -64,8 +69,12 @@ class CycleGANEngineKernel(EngineKernel):
             loss[f"cycle_{ph}"] = self.cycle_loss(generated["a2b2a" if ph == "a" else "b2a2b"], batch[ph])
             loss[f"id_{ph}"] = self.id_loss(generated[f"{ph}2{ph}"], batch[ph])
             loss[f"mgc_{ph}"] = self.mgc_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph])
-            loss[f"gan_{ph}"] = self.config.loss.gan.weight * self.gan_loss(
-                self.discriminators[ph](generated["a2b" if ph == "b" else "b2a"]), True)
+            prediction_fake = self.discriminators[ph](generated["a2b" if ph == "b" else "b2a"])
+            loss[f"gan_{ph}"] = self.config.loss.gan.weight * self.gan_loss(prediction_fake, True)
+            if self.fm_loss.weight > 0:
+                prediction_real = self.discriminators[ph](batch[ph])
+                loss[f"feature_match_{ph}"] = self.fm_loss(prediction_fake, prediction_real)
+            loss[f"edge_{ph}"] = self.edge_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph], gt_is_edge=False)
         return loss
 
     def criterion_discriminators(self, batch, generated) -> dict:
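Assuming the training step simply sums the entries of the dict returned above, the per-direction generator objective after this commit is roughly

    L_G = w_gan * L_gan + w_cycle * L_cycle + w_id * L_id + w_mgc * L_mgc + w_fm * L_fm + w_edge * L_edge

where w_gan is applied explicitly in the code and the other weights live inside the LossContainer wrappers. With the shipped config (fm and edge weights set to 0) the two new terms are effectively disabled: the feature-matching branch is skipped by the weight > 0 guard and the edge term contributes nothing.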
@@ -101,9 +101,12 @@ class EngineKernel(object):
 
 
 def _remove_no_grad_loss(loss_dict):
+    need_to_pop = []
     for k in loss_dict:
         if not isinstance(loss_dict[k], torch.Tensor):
-            loss_dict.pop(k)
+            need_to_pop.append(k)
+    for k in need_to_pop:
+        loss_dict.pop(k)
     return loss_dict
 
 
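The old version popped keys from loss_dict while iterating over it, which raises RuntimeError ("dictionary changed size during iteration") as soon as an entry is actually removed; collecting the keys first and popping in a second pass avoids that. A minimal self-contained illustration of the bug and the fix:

    losses = {"gan_a": 1.0, "edge_a": None}

    # Old pattern -- fails once a key is removed:
    #   for k in losses:
    #       if losses[k] is None:
    #           losses.pop(k)   # RuntimeError: dictionary changed size during iteration

    # New pattern -- defer the removal:
    to_pop = [k for k in losses if losses[k] is None]
    for k in to_pop:
        losses.pop(k)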
@@ -23,3 +23,19 @@ def mse_loss(x, target_flag):
 
 def bce_loss(x, target_flag):
     return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
+
+
+def feature_match_loss(level, weight_policy):
+    compare_loss = pixel_loss(level)
+    assert weight_policy in ["same", "exponential_decline"]
+
+    def fm_loss(generated_features, target_features):
+        num_scale = len(generated_features)
+        loss = 0
+        for s_i in range(num_scale):
+            for i in range(len(generated_features[s_i]) - 1):
+                weight = 1 if weight_policy == "same" else 2 ** i
+                loss += weight * compare_loss(generated_features[s_i][i], target_features[s_i][i].detach()) / num_scale
+        return loss
+
+    return fm_loss
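feature_match_loss expects, for both arguments, a list with one entry per discriminator scale, where each entry is the list of outputs a discriminator produces with intermediate features enabled (as a MultiScaleDiscriminator wrapping PatchDiscriminators with need_intermediate_feature=True would return); the last element of each inner list, the final logit map, is skipped by the len(...) - 1 bound. A rough sketch with dummy tensors (import path taken from the hunk above, shapes purely illustrative):

    import torch
    from engine.util.loss import feature_match_loss

    fm = feature_match_loss(level=1, weight_policy="same")

    # One scale, three outputs: two intermediate feature maps plus the final logits (skipped).
    fake = [[torch.randn(2, 64, 32, 32), torch.randn(2, 128, 16, 16), torch.randn(2, 1, 15, 15)]]
    real = [[torch.randn(2, 64, 32, 32), torch.randn(2, 128, 16, 16), torch.randn(2, 1, 15, 15)]]
    print(fm(fake, real))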
@@ -105,10 +105,12 @@ class MGCLoss(nn.Module):
     Minimal Geometry-Distortion Constraint Loss from https://openreview.net/forum?id=R5M7Mxl1xZ
     """
 
-    def __init__(self, beta=0.5, lambda_=0.05, device=idist.device()):
+    def __init__(self, mi_to_loss_way="opposite", beta=0.5, lambda_=0.05, device=idist.device()):
         super().__init__()
         self.beta = beta
         self.lambda_ = lambda_
+        assert mi_to_loss_way in ["opposite", "reciprocal"]
+        self.mi_to_loss_way = mi_to_loss_way
         mu_y, mu_x = torch.meshgrid([torch.arange(-1, 1.25, 0.25), torch.arange(-1, 1.25, 0.25)])
         self.mu_x = mu_x.flatten().to(device)
         self.mu_y = mu_y.flatten().to(device)
@@ -134,6 +136,8 @@ class MGCLoss(nn.Module):
 
     def forward(self, fake, real):
         rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_, self.R)
+        if self.mi_to_loss_way == "reciprocal":
+            return 1/rSMI.mean()
         return -rSMI.mean()
 
 
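The new mi_to_loss_way argument selects how the mutual-information score rSMI is turned into something to minimize: "opposite" keeps the previous behaviour (loss = -rSMI.mean()), while "reciprocal" uses loss = 1 / rSMI.mean(). A minimal sketch of the new constructor argument (tensor shapes purely illustrative):

    import torch
    from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss

    mgc = MGCLoss("opposite")        # matches what CycleGANEngineKernel now passes
    # mgc = MGCLoss("reciprocal")    # alternative mapping added in this commit
    fake, real = torch.rand(2, 3, 64, 64), torch.rand(2, 3, 64, 64)
    loss = mgc(fake, real)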
loss/gan.py (14 changed lines)
@@ -1,5 +1,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
+import torch
 
 
 class GANLoss(nn.Module):
@@ -10,7 +11,7 @@ class GANLoss(nn.Module):
         self.fake_label_val = fake_label_val
         self.loss_type = loss_type
 
-    def forward(self, prediction, target_is_real: bool, is_discriminator=False):
+    def single_forward(self, prediction, target_is_real: bool, is_discriminator=False):
         """
         gan loss forward
         :param prediction: network prediction
@@ -37,3 +38,14 @@ class GANLoss(nn.Module):
             return loss
         else:
             raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')
+
+    def forward(self, prediction, target_is_real: bool, is_discriminator=False):
+        if isinstance(prediction, torch.Tensor):
+            # origin
+            return self.single_forward(prediction, target_is_real, is_discriminator)
+        elif isinstance(prediction, list):
+            # for multi scale discriminator, e.g. MultiScaleDiscriminator
+            loss = 0
+            for p in prediction:
+                loss += self.single_forward(p[-1], target_is_real, is_discriminator)
+            return loss
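GANLoss.forward now dispatches on the prediction type: a plain tensor goes through the original single-prediction path, while a list (what MultiScaleDiscriminator below returns) is reduced by taking the last element of each per-scale output as the logit map and summing the per-scale losses. A rough usage sketch; the constructor arguments are assumed from the attributes visible above:

    import torch
    from loss.gan import GANLoss

    criterion = GANLoss(loss_type="hinge", real_label_val=1, fake_label_val=0.0)  # assumed ctor signature

    patch_logits = torch.randn(4, 1, 30, 30)             # single PatchDiscriminator output
    multi_scale = [[torch.randn(4, 1, 30, 30)],          # one list per scale; the last element of each
                   [torch.randn(4, 1, 15, 15)]]          # list is treated as the logit map
    loss_g = criterion(patch_logits, True) + criterion(multi_scale, True)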
@@ -2,3 +2,4 @@ from model.registry import MODEL, NORMALIZATION
 import model.base.normalization
 import model.image_translation.UGATIT
 import model.image_translation.CycleGAN
+import model.image_translation.pix2pixHD
model/image_translation/pix2pixHD.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+from model import MODEL
+
+
+@MODEL.register_module()
+class MultiScaleDiscriminator(nn.Module):
+    def __init__(self, num_scale, discriminator_cfg, down_sample_method="avg"):
+        super().__init__()
+        assert down_sample_method in ["avg", "bilinear"]
+        self.down_sample_method = down_sample_method
+
+        self.discriminator_list = nn.ModuleList([
+            MODEL.build_with(discriminator_cfg) for _ in range(num_scale)
+        ])
+
+    def down_sample(self, x):
+        if self.down_sample_method == "avg":
+            return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
+        if self.down_sample_method == "bilinear":
+            return F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True)
+
+    def forward(self, x):
+        results = []
+        for discriminator in self.discriminator_list:
+            results.append(discriminator(x))
+            x = self.down_sample(x)
+        return results
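A rough sketch of how the new discriminator maps onto the commented-out YAML block in the model config above; it assumes PatchDiscriminator is registered through the package imports so that MODEL.build_with can resolve the nested _type:

    import torch
    import model   # model/__init__.py registers the translation models, including pix2pixHD
    from model.image_translation.pix2pixHD import MultiScaleDiscriminator

    disc = MultiScaleDiscriminator(
        num_scale=2,
        down_sample_method="bilinear",
        discriminator_cfg={"_type": "PatchDiscriminator", "in_channels": 3, "base_channels": 64,
                           "num_conv": 4, "need_intermediate_feature": True},
    )
    outs = disc(torch.randn(1, 3, 256, 256))   # one result per scale; the second scale sees a 128x128 input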
@@ -53,11 +53,9 @@ class _Registry:
         else:
             raise TypeError(f'cfg must be a dict or a str, but got {type(cfg)}')
 
-        for k in args:
-            assert isinstance(k, str)
-            if k.startswith("_"):
-                warnings.warn(f"got param start with `_`: {k}, will remove it")
-                args.pop(k)
+        for invalid_key in [k for k in args.keys() if k.startswith("_")]:
+            warnings.warn(f"got param start with `_`: {invalid_key}, will remove it")
+            args.pop(invalid_key)
 
         if not (isinstance(default_args, dict) or default_args is None):
             raise TypeError('default_args must be a dict or None, '