Compare commits
4 Commits
0019d4034c
...
8998c30c23
| Author | SHA1 | Date | |
|---|---|---|---|
| 8998c30c23 | |||
| 0bec02bf6d | |||
| f7b7b78669 | |||
| 376f5caeb7 |
@ -1,6 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" autoUpload="Always" serverName="21d" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||
<component name="PublishConfigData" autoUpload="Always" serverName="14d" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||
<serverData>
|
||||
<paths name="14d">
|
||||
<serverdata>
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="15d-python" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="14d-python" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
@ -2,7 +2,7 @@
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="15d-python" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="14d-python" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="TestRunnerService">
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
name: selfie2anime-cycleGAN
|
||||
name: huawei-cycylegan-7
|
||||
engine: CycleGAN
|
||||
result_dir: ./result
|
||||
max_pairs: 1000000
|
||||
@ -27,18 +27,33 @@ model:
|
||||
out_channels: 3
|
||||
base_channels: 64
|
||||
num_blocks: 9
|
||||
use_transpose_conv: False
|
||||
pre_activation: True
|
||||
# discriminator:
|
||||
# _type: MultiScaleDiscriminator
|
||||
# _add_spectral_norm: True
|
||||
# num_scale: 2
|
||||
# down_sample_method: "bilinear"
|
||||
# discriminator_cfg:
|
||||
# _type: PatchDiscriminator
|
||||
# in_channels: 3
|
||||
# base_channels: 64
|
||||
# num_conv: 4
|
||||
# need_intermediate_feature: True
|
||||
discriminator:
|
||||
_type: PatchDiscriminator
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
num_conv: 4
|
||||
need_intermediate_feature: False
|
||||
|
||||
|
||||
loss:
|
||||
gan:
|
||||
loss_type: lsgan
|
||||
loss_type: hinge
|
||||
weight: 1.0
|
||||
real_label_val: 1.0
|
||||
real_label_val: 1
|
||||
fake_label_val: 0.0
|
||||
cycle:
|
||||
level: 1
|
||||
@ -47,17 +62,22 @@ loss:
|
||||
level: 1
|
||||
weight: 10.0
|
||||
mgc:
|
||||
weight: 5
|
||||
weight: 1
|
||||
fm:
|
||||
weight: 0
|
||||
edge:
|
||||
weight: 0
|
||||
hed_pretrained_model_path: ./network-bsds500.pytorch
|
||||
|
||||
optimizers:
|
||||
generator:
|
||||
_type: Adam
|
||||
lr: 0.0001
|
||||
lr: 1e-4
|
||||
betas: [ 0.5, 0.999 ]
|
||||
weight_decay: 0.0001
|
||||
discriminator:
|
||||
_type: Adam
|
||||
lr: 1e-4
|
||||
lr: 4e-4
|
||||
betas: [ 0.5, 0.999 ]
|
||||
weight_decay: 0.0001
|
||||
|
||||
@ -75,10 +95,21 @@ data:
|
||||
drop_last: True
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/i2i/selfie2anime/trainA"
|
||||
root_b: "/data/i2i/selfie2anime/trainB"
|
||||
root_a: "/data/face2cartoon/all_face"
|
||||
root_b: "/data/selfie2anime/trainB/"
|
||||
random_pair: True
|
||||
pipeline:
|
||||
pipeline_a:
|
||||
- Load
|
||||
- RandomCrop:
|
||||
size: [ 178, 178 ]
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
pipeline_b:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 286, 286 ]
|
||||
@ -99,10 +130,18 @@ data:
|
||||
drop_last: False
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/i2i/selfie2anime/testA"
|
||||
root_b: "/data/i2i/selfie2anime/testB"
|
||||
random_pair: False
|
||||
pipeline:
|
||||
root_a: "/data/face2cartoon/test/human"
|
||||
root_b: "/data/face2cartoon/test/anime"
|
||||
random_pair: True
|
||||
pipeline_a:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
pipeline_b:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
|
||||
167
configs/synthesizers/GauGAN.yml
Normal file
167
configs/synthesizers/GauGAN.yml
Normal file
@ -0,0 +1,167 @@
|
||||
name: huawei-GauGAN-3
|
||||
engine: GauGAN
|
||||
result_dir: ./result
|
||||
max_pairs: 1000000
|
||||
|
||||
misc:
|
||||
random_seed: 324
|
||||
|
||||
handler:
|
||||
clear_cuda_cache: True
|
||||
set_epoch_for_dist_sampler: True
|
||||
checkpoint:
|
||||
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
|
||||
n_saved: 2
|
||||
tensorboard:
|
||||
scalar: 100 # log scalar `scalar` times per epoch
|
||||
image: 4 # log image `image` times per epoch
|
||||
test:
|
||||
random: True
|
||||
images: 10
|
||||
|
||||
model:
|
||||
generator:
|
||||
_type: SPADEGenerator
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
out_channels: 3
|
||||
num_blocks: 7
|
||||
use_vae: False
|
||||
num_z_dim: 256
|
||||
# discriminator:
|
||||
# _type: MultiScaleDiscriminator
|
||||
# _add_spectral_norm: True
|
||||
# num_scale: 2
|
||||
# down_sample_method: "bilinear"
|
||||
# discriminator_cfg:
|
||||
# _type: PatchDiscriminator
|
||||
# in_channels: 3
|
||||
# base_channels: 64
|
||||
# num_conv: 4
|
||||
# need_intermediate_feature: True
|
||||
discriminator:
|
||||
_type: PatchDiscriminator
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
num_conv: 4
|
||||
need_intermediate_feature: True
|
||||
|
||||
|
||||
loss:
|
||||
gan:
|
||||
loss_type: hinge
|
||||
weight: 1.0
|
||||
real_label_val: 1
|
||||
fake_label_val: 0.0
|
||||
perceptual:
|
||||
layer_weights:
|
||||
"1": 0.03125
|
||||
"6": 0.0625
|
||||
"11": 0.125
|
||||
"20": 0.25
|
||||
"29": 1
|
||||
criterion: 'L1'
|
||||
style_loss: False
|
||||
perceptual_loss: True
|
||||
weight: 2
|
||||
mgc:
|
||||
weight: 5
|
||||
fm:
|
||||
weight: 5
|
||||
edge:
|
||||
weight: 0
|
||||
hed_pretrained_model_path: ./network-bsds500.pytorch
|
||||
|
||||
optimizers:
|
||||
generator:
|
||||
_type: Adam
|
||||
lr: 1e-4
|
||||
betas: [ 0, 0.9 ]
|
||||
weight_decay: 0.0001
|
||||
discriminator:
|
||||
_type: Adam
|
||||
lr: 4e-4
|
||||
betas: [ 0, 0.9 ]
|
||||
weight_decay: 0.0001
|
||||
|
||||
data:
|
||||
train:
|
||||
scheduler:
|
||||
start_proportion: 0.5
|
||||
target_lr: 0
|
||||
buffer_size: 50
|
||||
dataloader:
|
||||
batch_size: 1
|
||||
shuffle: True
|
||||
num_workers: 2
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/face2cartoon/all_face"
|
||||
root_b: "/data/selfie2anime/trainB/"
|
||||
random_pair: True
|
||||
pipeline_a:
|
||||
- Load
|
||||
- RandomCrop:
|
||||
size: [ 178, 178 ]
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
pipeline_b:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 286, 286 ]
|
||||
- RandomCrop:
|
||||
size: [ 256, 256 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
test:
|
||||
which: video_dataset
|
||||
dataloader:
|
||||
batch_size: 1
|
||||
shuffle: False
|
||||
num_workers: 1
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/face2cartoon/test/human"
|
||||
root_b: "/data/face2cartoon/test/anime"
|
||||
random_pair: True
|
||||
pipeline_a:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
pipeline_b:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
video_dataset:
|
||||
_type: SingleFolderDataset
|
||||
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
|
||||
with_path: True
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
@ -1,7 +1,10 @@
|
||||
name: VoxCeleb2Anime-TSIT
|
||||
engine: TSIT
|
||||
name: huawei-TSIT-1
|
||||
engine: GauGAN
|
||||
result_dir: ./result
|
||||
max_pairs: 1500000
|
||||
max_pairs: 1000000
|
||||
|
||||
misc:
|
||||
random_seed: 324
|
||||
|
||||
handler:
|
||||
clear_cuda_cache: True
|
||||
@ -16,34 +19,39 @@ handler:
|
||||
random: True
|
||||
images: 10
|
||||
|
||||
|
||||
misc:
|
||||
random_seed: 324
|
||||
|
||||
model:
|
||||
generator:
|
||||
_type: TSIT-Generator
|
||||
_bn_to_sync_bn: True
|
||||
style_in_channels: 3
|
||||
content_in_channels: 3
|
||||
num_blocks: 5
|
||||
input_layer_type: "conv7x7"
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
out_channels: 3
|
||||
num_blocks: 7
|
||||
# discriminator:
|
||||
# _type: MultiScaleDiscriminator
|
||||
# _add_spectral_norm: True
|
||||
# num_scale: 2
|
||||
# down_sample_method: "bilinear"
|
||||
# discriminator_cfg:
|
||||
# _type: PatchDiscriminator
|
||||
# in_channels: 3
|
||||
# base_channels: 64
|
||||
# num_conv: 4
|
||||
# need_intermediate_feature: True
|
||||
discriminator:
|
||||
_type: MultiScaleDiscriminator
|
||||
num_scale: 2
|
||||
discriminator_cfg:
|
||||
_type: PatchDiscriminator
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
use_spectral: True
|
||||
need_intermediate_feature: True
|
||||
_type: PatchDiscriminator
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
num_conv: 4
|
||||
need_intermediate_feature: True
|
||||
|
||||
|
||||
loss:
|
||||
gan:
|
||||
loss_type: hinge
|
||||
real_label_val: 1.0
|
||||
fake_label_val: 0.0
|
||||
weight: 1.0
|
||||
real_label_val: 1
|
||||
fake_label_val: 0.0
|
||||
perceptual:
|
||||
layer_weights:
|
||||
"1": 0.03125
|
||||
@ -55,25 +63,18 @@ loss:
|
||||
style_loss: False
|
||||
perceptual_loss: True
|
||||
weight: 1
|
||||
style:
|
||||
layer_weights:
|
||||
"1": 0.03125
|
||||
"6": 0.0625
|
||||
"11": 0.125
|
||||
"20": 0.25
|
||||
"29": 1
|
||||
criterion: 'L2'
|
||||
style_loss: True
|
||||
perceptual_loss: False
|
||||
weight: 0
|
||||
mgc:
|
||||
weight: 5
|
||||
fm:
|
||||
level: 1
|
||||
weight: 1
|
||||
edge:
|
||||
weight: 0
|
||||
hed_pretrained_model_path: ./network-bsds500.pytorch
|
||||
|
||||
optimizers:
|
||||
generator:
|
||||
_type: Adam
|
||||
lr: 0.0001
|
||||
lr: 1e-4
|
||||
betas: [ 0, 0.9 ]
|
||||
weight_decay: 0.0001
|
||||
discriminator:
|
||||
@ -87,24 +88,35 @@ data:
|
||||
scheduler:
|
||||
start_proportion: 0.5
|
||||
target_lr: 0
|
||||
buffer_size: 50
|
||||
buffer_size: 0
|
||||
dataloader:
|
||||
batch_size: 8
|
||||
batch_size: 1
|
||||
shuffle: True
|
||||
num_workers: 2
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/i2i/faces/CelebA-Asian/trainA"
|
||||
root_b: "/data/i2i/anime/your-name/faces"
|
||||
root_a: "/data/face2cartoon/all_face"
|
||||
root_b: "/data/selfie2anime/trainB/"
|
||||
random_pair: True
|
||||
pipeline:
|
||||
pipeline_a:
|
||||
- Load
|
||||
- RandomCrop:
|
||||
size: [ 178, 178 ]
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
pipeline_b:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 170, 144 ]
|
||||
size: [ 286, 286 ]
|
||||
- RandomCrop:
|
||||
size: [ 128, 128 ]
|
||||
size: [ 256, 256 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
@ -113,22 +125,28 @@ data:
|
||||
test:
|
||||
which: video_dataset
|
||||
dataloader:
|
||||
batch_size: 8
|
||||
batch_size: 1
|
||||
shuffle: False
|
||||
num_workers: 1
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/i2i/faces/CelebA-Asian/testA"
|
||||
root_b: "/data/i2i/anime/your-name/faces"
|
||||
random_pair: False
|
||||
pipeline:
|
||||
root_a: "/data/face2cartoon/test/human"
|
||||
root_b: "/data/face2cartoon/test/anime"
|
||||
random_pair: True
|
||||
pipeline_a:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 170, 144 ]
|
||||
- RandomCrop:
|
||||
size: [ 128, 128 ]
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
pipeline_b:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
|
||||
@ -38,9 +38,9 @@ class SingleFolderDataset(Dataset):
|
||||
|
||||
@DATASET.register_module()
|
||||
class GenerationUnpairedDataset(Dataset):
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False):
|
||||
self.A = SingleFolderDataset(root_a, pipeline, with_path)
|
||||
self.B = SingleFolderDataset(root_b, pipeline, with_path)
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline_a, pipeline_b, with_path=False):
|
||||
self.A = SingleFolderDataset(root_a, pipeline_a, with_path)
|
||||
self.B = SingleFolderDataset(root_b, pipeline_b, with_path)
|
||||
self.with_path = with_path
|
||||
self.random_pair = random_pair
|
||||
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
from itertools import chain
|
||||
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel
|
||||
from engine.util.build import build_model
|
||||
from engine.util.container import GANImageBuffer, LossContainer
|
||||
from engine.util.loss import pixel_loss, gan_loss
|
||||
from engine.util.loss import pixel_loss, gan_loss, feature_match_loss
|
||||
from loss.I2I.edge_loss import EdgeLoss
|
||||
from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
|
||||
from model.weight_init import generation_init_weights
|
||||
|
||||
@ -17,7 +19,10 @@ class CycleGANEngineKernel(EngineKernel):
|
||||
self.gan_loss = gan_loss(config.loss.gan)
|
||||
self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
|
||||
self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
|
||||
self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss())
|
||||
self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
|
||||
self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "same"))
|
||||
self.edge_loss = LossContainer(config.loss.edge.weight, EdgeLoss(
|
||||
"HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(idist.device()))
|
||||
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
|
||||
self.discriminators.keys()}
|
||||
|
||||
@ -64,8 +69,12 @@ class CycleGANEngineKernel(EngineKernel):
|
||||
loss[f"cycle_{ph}"] = self.cycle_loss(generated["a2b2a" if ph == "a" else "b2a2b"], batch[ph])
|
||||
loss[f"id_{ph}"] = self.id_loss(generated[f"{ph}2{ph}"], batch[ph])
|
||||
loss[f"mgc_{ph}"] = self.mgc_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph])
|
||||
loss[f"gan_{ph}"] = self.config.loss.gan.weight * self.gan_loss(
|
||||
self.discriminators[ph](generated["a2b" if ph == "b" else "b2a"]), True)
|
||||
prediction_fake = self.discriminators[ph](generated["a2b" if ph == "b" else "b2a"])
|
||||
loss[f"gan_{ph}"] = self.config.loss.gan.weight * self.gan_loss(prediction_fake, True)
|
||||
if self.fm_loss.weight > 0:
|
||||
prediction_real = self.discriminators[ph](batch[ph])
|
||||
loss[f"feature_match_{ph}"] = self.fm_loss(prediction_fake, prediction_real)
|
||||
loss[f"edge_{ph}"] = self.edge_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph], gt_is_edge=False)
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
|
||||
86
engine/GauGAN.py
Normal file
86
engine/GauGAN.py
Normal file
@ -0,0 +1,86 @@
|
||||
from itertools import chain
|
||||
|
||||
import torch
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel
|
||||
from engine.util.build import build_model
|
||||
from engine.util.container import GANImageBuffer, LossContainer
|
||||
from engine.util.loss import gan_loss, feature_match_loss, perceptual_loss
|
||||
from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
|
||||
from model.weight_init import generation_init_weights
|
||||
|
||||
|
||||
class GauGANEngineKernel(EngineKernel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.gan_loss = gan_loss(config.loss.gan)
|
||||
self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
|
||||
self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "same"))
|
||||
self.perceptual_loss = LossContainer(config.loss.perceptual.weight, perceptual_loss(config.loss.perceptual))
|
||||
|
||||
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
|
||||
self.discriminators.keys()}
|
||||
|
||||
def build_models(self) -> (dict, dict):
|
||||
generators = dict(
|
||||
main=build_model(self.config.model.generator)
|
||||
)
|
||||
discriminators = dict(
|
||||
b=build_model(self.config.model.discriminator)
|
||||
)
|
||||
self.logger.debug(discriminators["b"])
|
||||
self.logger.debug(generators["main"])
|
||||
|
||||
for m in chain(generators.values(), discriminators.values()):
|
||||
generation_init_weights(m)
|
||||
|
||||
return generators, discriminators
|
||||
|
||||
def setup_after_g(self):
|
||||
for discriminator in self.discriminators.values():
|
||||
discriminator.requires_grad_(True)
|
||||
|
||||
def setup_before_g(self):
|
||||
for discriminator in self.discriminators.values():
|
||||
discriminator.requires_grad_(False)
|
||||
|
||||
def forward(self, batch, inference=False) -> dict:
|
||||
images = dict()
|
||||
with torch.set_grad_enabled(not inference):
|
||||
images["a2b"] = self.generators["main"](batch["a"])
|
||||
return images
|
||||
|
||||
def criterion_generators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
prediction_fake = self.discriminators["b"](generated["a2b"])
|
||||
loss["gan"] = self.config.loss.gan.weight * self.gan_loss(prediction_fake, True)
|
||||
loss["mgc"] = self.mgc_loss(generated["a2b"], batch["a"])
|
||||
loss["perceptual"] = self.perceptual_loss(generated["a2b"], batch["a"])
|
||||
if self.fm_loss.weight > 0:
|
||||
prediction_real = self.discriminators["b"](batch["b"])
|
||||
loss["feature_match"] = self.fm_loss(prediction_fake, prediction_real)
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
generated_image = self.image_buffers["b"].query(generated["a2b"].detach())
|
||||
loss["b"] = (self.gan_loss(self.discriminators["b"](generated_image), False, is_discriminator=True) +
|
||||
self.gan_loss(self.discriminators["b"](batch["b"]), True, is_discriminator=True)) / 2
|
||||
return loss
|
||||
|
||||
def intermediate_images(self, batch, generated) -> dict:
|
||||
"""
|
||||
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
:param batch:
|
||||
:param generated: dict of images
|
||||
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
"""
|
||||
return dict(
|
||||
a=[batch["a"].detach(), generated["a2b"].detach()],
|
||||
)
|
||||
|
||||
|
||||
def run(task, config, _):
|
||||
kernel = GauGANEngineKernel(config)
|
||||
run_kernel(task, config, kernel)
|
||||
@ -101,9 +101,12 @@ class EngineKernel(object):
|
||||
|
||||
|
||||
def _remove_no_grad_loss(loss_dict):
|
||||
need_to_pop = []
|
||||
for k in loss_dict:
|
||||
if not isinstance(loss_dict[k], torch.Tensor):
|
||||
loss_dict.pop(k)
|
||||
need_to_pop.append(k)
|
||||
for k in need_to_pop:
|
||||
loss_dict.pop(k)
|
||||
return loss_dict
|
||||
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||
from loss.gan import GANLoss
|
||||
|
||||
|
||||
@ -13,6 +14,12 @@ def gan_loss(config):
|
||||
return GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
|
||||
|
||||
def perceptual_loss(config):
|
||||
perceptual_loss_cfg = OmegaConf.to_container(config)
|
||||
perceptual_loss_cfg.pop("weight")
|
||||
return PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
|
||||
|
||||
|
||||
def pixel_loss(level):
|
||||
return nn.L1Loss() if level == 1 else nn.MSELoss()
|
||||
|
||||
@ -23,3 +30,19 @@ def mse_loss(x, target_flag):
|
||||
|
||||
def bce_loss(x, target_flag):
|
||||
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
|
||||
|
||||
|
||||
def feature_match_loss(level, weight_policy):
|
||||
compare_loss = pixel_loss(level)
|
||||
assert weight_policy in ["same", "exponential_decline"]
|
||||
|
||||
def fm_loss(generated_features, target_features):
|
||||
num_scale = len(generated_features)
|
||||
loss = torch.zeros(1, device=idist.device())
|
||||
for s_i in range(num_scale):
|
||||
for i in range(len(generated_features[s_i]) - 1):
|
||||
weight = 1 if weight_policy == "same" else 2 ** i
|
||||
loss += weight * compare_loss(generated_features[s_i][i], target_features[s_i][i].detach()) / num_scale
|
||||
return loss
|
||||
|
||||
return fm_loss
|
||||
|
||||
@ -105,10 +105,12 @@ class MGCLoss(nn.Module):
|
||||
Minimal Geometry-Distortion Constraint Loss from https://openreview.net/forum?id=R5M7Mxl1xZ
|
||||
"""
|
||||
|
||||
def __init__(self, beta=0.5, lambda_=0.05, device=idist.device()):
|
||||
def __init__(self, mi_to_loss_way="opposite", beta=0.5, lambda_=0.05, device=idist.device()):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.lambda_ = lambda_
|
||||
assert mi_to_loss_way in ["opposite", "reciprocal"]
|
||||
self.mi_to_loss_way = mi_to_loss_way
|
||||
mu_y, mu_x = torch.meshgrid([torch.arange(-1, 1.25, 0.25), torch.arange(-1, 1.25, 0.25)])
|
||||
self.mu_x = mu_x.flatten().to(device)
|
||||
self.mu_y = mu_y.flatten().to(device)
|
||||
@ -134,6 +136,8 @@ class MGCLoss(nn.Module):
|
||||
|
||||
def forward(self, fake, real):
|
||||
rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_, self.R)
|
||||
if self.mi_to_loss_way == "reciprocal":
|
||||
return 1/rSMI.mean()
|
||||
return -rSMI.mean()
|
||||
|
||||
|
||||
|
||||
20
loss/gan.py
20
loss/gan.py
@ -1,4 +1,5 @@
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
@ -10,7 +11,7 @@ class GANLoss(nn.Module):
|
||||
self.fake_label_val = fake_label_val
|
||||
self.loss_type = loss_type
|
||||
|
||||
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
|
||||
def single_forward(self, prediction, target_is_real: bool, is_discriminator=False):
|
||||
"""
|
||||
gan loss forward
|
||||
:param prediction: network prediction
|
||||
@ -37,3 +38,20 @@ class GANLoss(nn.Module):
|
||||
return loss
|
||||
else:
|
||||
raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')
|
||||
|
||||
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
|
||||
if isinstance(prediction, torch.Tensor):
|
||||
# origin
|
||||
return self.single_forward(prediction, target_is_real, is_discriminator)
|
||||
elif isinstance(prediction, list):
|
||||
# for multi scale discriminator, e.g. MultiScaleDiscriminator
|
||||
loss = 0
|
||||
for p in prediction:
|
||||
loss += self.single_forward(p[-1], target_is_real, is_discriminator)
|
||||
return loss
|
||||
elif isinstance(prediction, tuple):
|
||||
# for single discriminator set `need_intermediate_feature` true
|
||||
return self.single_forward(prediction[-1], target_is_real, is_discriminator)
|
||||
else:
|
||||
raise NotImplementedError(f"not support discriminator output: {prediction}")
|
||||
|
||||
|
||||
@ -2,3 +2,6 @@ from model.registry import MODEL, NORMALIZATION
|
||||
import model.base.normalization
|
||||
import model.image_translation.UGATIT
|
||||
import model.image_translation.CycleGAN
|
||||
import model.image_translation.pix2pixHD
|
||||
import model.image_translation.GauGAN
|
||||
import model.image_translation.TSIT
|
||||
@ -119,6 +119,8 @@ class ResidualBlock(nn.Module):
|
||||
self.conv2 = Conv2dBlock(in_channels, out_channels, **conv_param)
|
||||
|
||||
if self.learn_skip_connection:
|
||||
conv_param['kernel_size'] = 1
|
||||
conv_param['padding'] = 0
|
||||
self.res_conv = Conv2dBlock(in_channels, out_channels, **conv_param)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
@ -7,7 +7,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model.base.module import ResidualBlock, Conv2dBlock, LinearBlock
|
||||
|
||||
from model import MODEL
|
||||
|
||||
class StyleEncoder(nn.Module):
|
||||
def __init__(self, in_channels, style_dim, num_conv, end_size=(4, 4), base_channels=64,
|
||||
@ -122,7 +122,7 @@ class ImprovedSPADEGenerator(nn.Module):
|
||||
def forward(self, seg, style=None):
|
||||
pass
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class SPADEGenerator(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
|
||||
padding_mode='reflect', activation_type="LeakyReLU"):
|
||||
@ -156,11 +156,8 @@ class SPADEGenerator(nn.Module):
|
||||
)
|
||||
))
|
||||
self.sequence = nn.Sequential(*sequence)
|
||||
self.output_converter = nn.Sequential(
|
||||
ReverseConv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
|
||||
padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"),
|
||||
nn.Tanh()
|
||||
)
|
||||
self.output_converter = Conv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
|
||||
padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE")
|
||||
|
||||
def forward(self, seg, z=None):
|
||||
if self.use_vae:
|
||||
|
||||
@ -0,0 +1,98 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
|
||||
from model import MODEL
|
||||
from model.base.module import ResidualBlock, Conv2dBlock
|
||||
|
||||
|
||||
class Interpolation(nn.Module):
|
||||
def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
|
||||
super(Interpolation, self).__init__()
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x):
|
||||
return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners,
|
||||
recompute_scale_factor=False)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Interpolation(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"
|
||||
|
||||
|
||||
@MODEL.register_module("TSIT-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=7,
|
||||
padding_mode='reflect', activation_type="LeakyReLU"):
|
||||
super().__init__()
|
||||
self.input_layer = Conv2dBlock(
|
||||
in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
|
||||
activation_type=activation_type, norm_type="IN",
|
||||
)
|
||||
multiple_now = 1
|
||||
stream_sequence = []
|
||||
for i in range(1, num_blocks + 1):
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** i, 2 ** 4)
|
||||
stream_sequence.append(nn.Sequential(
|
||||
Interpolation(scale_factor=0.5, mode="nearest"),
|
||||
ResidualBlock(
|
||||
multiple_prev * base_channels, out_channels=multiple_now * base_channels,
|
||||
padding_mode=padding_mode, activation_type=activation_type, norm_type="IN")
|
||||
))
|
||||
self.down_sequence = nn.ModuleList(stream_sequence)
|
||||
|
||||
|
||||
sequence = []
|
||||
multiple_now = 16
|
||||
for i in range(num_blocks - 1, -1, -1):
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** i, 2 ** 4)
|
||||
sequence.append(nn.Sequential(
|
||||
ResidualBlock(
|
||||
base_channels * multiple_prev,
|
||||
out_channels=base_channels * multiple_now,
|
||||
padding_mode=padding_mode,
|
||||
activation_type=activation_type,
|
||||
norm_type="FADE",
|
||||
pre_activation=True,
|
||||
additional_norm_kwargs=dict(
|
||||
condition_in_channels=base_channels * multiple_prev, base_norm_type="BN",
|
||||
padding_mode="zeros", gamma_bias=0.0
|
||||
)
|
||||
),
|
||||
Interpolation(scale_factor=2, mode="nearest")
|
||||
))
|
||||
self.up_sequence = nn.Sequential(*sequence)
|
||||
|
||||
self.output_layer = Conv2dBlock(
|
||||
base_channels, out_channels, kernel_size=3, stride=1, padding=1,
|
||||
padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"
|
||||
)
|
||||
|
||||
def forward(self, x, z=None):
|
||||
c = self.input_layer(x)
|
||||
contents = []
|
||||
for blk in self.down_sequence:
|
||||
c = blk(c)
|
||||
contents.append(c)
|
||||
if z is None:
|
||||
# for image 256x256, z size: [batch_size, 1024, 2, 2]
|
||||
z = torch.randn(size=contents[-1].size(), device=contents[-1].device)
|
||||
|
||||
for blk in self.up_sequence:
|
||||
res = blk[0]
|
||||
c = contents.pop()
|
||||
res.conv1.normalization.set_feature(c)
|
||||
res.conv2.normalization.set_feature(c)
|
||||
if res.learn_skip_connection:
|
||||
res.res_conv.normalization.set_feature(c)
|
||||
return self.output_layer(self.up_sequence(z))
|
||||
|
||||
if __name__ == '__main__':
|
||||
g = Generator(3, 3).cuda()
|
||||
img = torch.randn(2, 3, 256, 256).cuda()
|
||||
print(g(img).size())
|
||||
|
||||
|
||||
29
model/image_translation/pix2pixHD.py
Normal file
29
model/image_translation/pix2pixHD.py
Normal file
@ -0,0 +1,29 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model import MODEL
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class MultiScaleDiscriminator(nn.Module):
|
||||
def __init__(self, num_scale, discriminator_cfg, down_sample_method="avg"):
|
||||
super().__init__()
|
||||
assert down_sample_method in ["avg", "bilinear"]
|
||||
self.down_sample_method = down_sample_method
|
||||
|
||||
self.discriminator_list = nn.ModuleList([
|
||||
MODEL.build_with(discriminator_cfg) for _ in range(num_scale)
|
||||
])
|
||||
|
||||
def down_sample(self, x):
|
||||
if self.down_sample_method == "avg":
|
||||
return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
|
||||
if self.down_sample_method == "bilinear":
|
||||
return F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True)
|
||||
|
||||
def forward(self, x):
|
||||
results = []
|
||||
for discriminator in self.discriminator_list:
|
||||
results.append(discriminator(x))
|
||||
x = self.down_sample(x)
|
||||
return results
|
||||
@ -65,7 +65,8 @@ def generation_init_weights(module, init_type='normal', init_gain=0.02):
|
||||
elif classname.find('BatchNorm2d') != -1:
|
||||
# BatchNorm Layer's weight is not a matrix;
|
||||
# only normal distribution applies.
|
||||
normal_init(m, 1.0, init_gain)
|
||||
if m.weight is not None:
|
||||
normal_init(m, 1.0, init_gain)
|
||||
|
||||
assert isinstance(module, nn.Module)
|
||||
module.apply(init_func)
|
||||
|
||||
@ -53,11 +53,9 @@ class _Registry:
|
||||
else:
|
||||
raise TypeError(f'cfg must be a dict or a str, but got {type(cfg)}')
|
||||
|
||||
for k in args:
|
||||
assert isinstance(k, str)
|
||||
if k.startswith("_"):
|
||||
warnings.warn(f"got param start with `_`: {k}, will remove it")
|
||||
args.pop(k)
|
||||
for invalid_key in [k for k in args.keys() if k.startswith("_")]:
|
||||
warnings.warn(f"got param start with `_`: {invalid_key}, will remove it")
|
||||
args.pop(invalid_key)
|
||||
|
||||
if not (isinstance(default_args, dict) or default_args is None):
|
||||
raise TypeError('default_args must be a dict or None, '
|
||||
|
||||
Loading…
Reference in New Issue
Block a user