Compare commits: ab545843bf ... 7ea9c6d0d5 (3 commits)

| SHA1 | Author | Date |
|---|---|---|
| 7ea9c6d0d5 | | |
| 87cbcf34d3 | | |
| 97ded53b30 | | |
@@ -23,7 +23,8 @@ model:
     _bn_to_sync_bn: False
     style_in_channels: 3
     content_in_channels: 24
-    num_blocks: 8
+    num_adain_blocks: 8
+    num_res_blocks: 0
   discriminator:
     _type: MultiScaleDiscriminator
     num_scale: 2
@@ -47,21 +48,17 @@ loss:
       "11": 0.125
       "20": 0.25
       "29": 1
-    criterion: 'L2'
+    criterion: 'L1'
    style_loss: False
     perceptual_loss: True
-    weight: 0.5
+    weight: 10
   style:
     layer_weights:
-      "1": 0.03125
-      "6": 0.0625
-      "11": 0.125
-      "20": 0.25
-      "29": 1
-    criterion: 'L2'
+      "3": 1
+    criterion: 'L1'
     style_loss: True
     perceptual_loss: False
-    weight: 0
+    weight: 10
   fm:
     level: 1
     weight: 10
@@ -71,6 +68,9 @@ loss:
   style_recon:
     level: 1
     weight: 0
+  edge:
+    weight: 10
+    hed_pretrained_model_path: ./network-bsds500.pytorch
 
 optimizers:
   generator:
@@ -91,7 +91,7 @@ data:
       target_lr: 0
     buffer_size: 50
     dataloader:
-      batch_size: 24
+      batch_size: 8
       shuffle: True
       num_workers: 2
       pin_memory: True
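The new `edge` block above feeds an HED-based edge loss that is wired into the engine further down. For context, a minimal sketch of how an OmegaConf YAML config like this is typically loaded and queried — the file path is illustrative, not taken from this commit:

```python
from omegaconf import OmegaConf

# Hypothetical path; the actual config file name is not shown in this diff.
config = OmegaConf.load("configs/synthesizers/TAFG.yml")

# Nested keys are exposed as attributes, matching accesses like
# `config.loss.edge.hed_pretrained_model_path` in the engine code below.
print(config.loss.edge.weight)                       # -> 10
edge_cfg = OmegaConf.to_container(config.loss.edge)  # plain dict for **kwargs-style use
```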
configs/synthesizers/TSIT.yml (new file, 146 lines)
@@ -0,0 +1,146 @@
+name: self2anime-TSIT
+engine: TSIT
+result_dir: ./result
+max_pairs: 1500000
+
+handler:
+  clear_cuda_cache: True
+  set_epoch_for_dist_sampler: True
+  checkpoint:
+    epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
+    n_saved: 2
+  tensorboard:
+    scalar: 100 # log scalar `scalar` times per epoch
+    image: 2 # log image `image` times per epoch
+
+misc:
+  random_seed: 324
+
+model:
+  generator:
+    _type: TSIT-Generator
+    _bn_to_sync_bn: True
+    style_in_channels: 3
+    content_in_channels: 3
+    num_blocks: 5
+    input_layer_type: "conv7x7"
+  discriminator:
+    _type: MultiScaleDiscriminator
+    num_scale: 2
+    discriminator_cfg:
+      _type: PatchDiscriminator
+      in_channels: 3
+      base_channels: 64
+      use_spectral: True
+      need_intermediate_feature: True
+
+loss:
+  gan:
+    loss_type: hinge
+    real_label_val: 1.0
+    fake_label_val: 0.0
+    weight: 1.0
+  perceptual:
+    layer_weights:
+      "1": 0.03125
+      "6": 0.0625
+      "11": 0.125
+      "20": 0.25
+      "29": 1
+    criterion: 'L1'
+    style_loss: False
+    perceptual_loss: True
+    weight: 1
+  style:
+    layer_weights:
+      "1": 0.03125
+      "6": 0.0625
+      "11": 0.125
+      "20": 0.25
+      "29": 1
+    criterion: 'L2'
+    style_loss: True
+    perceptual_loss: False
+    weight: 0
+  fm:
+    level: 1
+    weight: 1
+
+optimizers:
+  generator:
+    _type: Adam
+    lr: 0.0001
+    betas: [ 0, 0.9 ]
+    weight_decay: 0.0001
+  discriminator:
+    _type: Adam
+    lr: 4e-4
+    betas: [ 0, 0.9 ]
+    weight_decay: 0.0001
+
+data:
+  train:
+    scheduler:
+      start_proportion: 0.5
+      target_lr: 0
+    buffer_size: 50
+    dataloader:
+      batch_size: 1
+      shuffle: True
+      num_workers: 2
+      pin_memory: True
+      drop_last: True
+    dataset:
+      _type: GenerationUnpairedDatasetWithEdge
+      root_a: "/data/i2i/VoxCeleb2Anime/trainA"
+      root_b: "/data/i2i/VoxCeleb2Anime/trainB"
+      edges_path: "/data/i2i/VoxCeleb2Anime/edges"
+      landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
+      edge_type: "landmark_hed"
+      size: [ 128, 128 ]
+      random_pair: True
+      pipeline:
+        - Load
+        - Resize:
+            size: [ 128, 128 ]
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+  test:
+    dataloader:
+      batch_size: 8
+      shuffle: False
+      num_workers: 1
+      pin_memory: False
+      drop_last: False
+    dataset:
+      _type: GenerationUnpairedDatasetWithEdge
+      root_a: "/data/i2i/VoxCeleb2Anime/testA"
+      root_b: "/data/i2i/VoxCeleb2Anime/testB"
+      edges_path: "/data/i2i/VoxCeleb2Anime/edges"
+      landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
+      edge_type: "landmark_hed"
+      random_pair: False
+      size: [ 128, 128 ]
+      pipeline:
+        - Load
+        - Resize:
+            size: [ 128, 128 ]
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+  video_dataset:
+    _type: SingleFolderDataset
+    root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
+    with_path: True
+    pipeline:
+      - Load
+      - Resize:
+          size: [ 256, 256 ]
+      - ToTensor
+      - Normalize:
+          mean: [ 0.5, 0.5, 0.5 ]
+          std: [ 0.5, 0.5, 0.5 ]
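The `_type` keys above select model classes by name. The internals of `build_model` are not part of this diff, but registry dispatch of this kind usually reduces to a sketch like the following (names and behavior here are assumptions, not the repo's actual implementation):

```python
# Hypothetical sketch of `_type`-based dispatch; the real build_model is not shown in this diff.
MODEL = {}  # the repo uses a MODEL registry object; a plain dict stands in here

def register_module(name):
    def deco(cls):
        MODEL[name] = cls
        return cls
    return deco

def build_model(cfg: dict):
    cfg = dict(cfg)                   # don't mutate the config
    model_cls = MODEL[cfg.pop("_type")]
    cfg.pop("_bn_to_sync_bn", None)   # framework-level flags are stripped before construction
    return model_cls(**cfg)           # remaining keys become constructor kwargs
```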
@@ -1,20 +1,17 @@
 from itertools import chain
 
-from omegaconf import OmegaConf
-
+import ignite.distributed as idist
 import torch
 import torch.nn as nn
-import ignite.distributed as idist
 from ignite.engine import Events
 
 from omegaconf import read_write, OmegaConf
 
-from model.weight_init import generation_init_weights
-from loss.I2I.perceptual_loss import PerceptualLoss
-from loss.gan import GANLoss
-
 from engine.base.i2i import EngineKernel, run_kernel
 from engine.util.build import build_model
+from loss.I2I.edge_loss import EdgeLoss
+from loss.I2I.perceptual_loss import PerceptualLoss
+from loss.gan import GANLoss
+from model.weight_init import generation_init_weights
 
 
 class TAFGEngineKernel(EngineKernel):
@@ -24,6 +21,10 @@ class TAFGEngineKernel(EngineKernel):
         perceptual_loss_cfg.pop("weight")
         self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
 
+        style_loss_cfg = OmegaConf.to_container(config.loss.style)
+        style_loss_cfg.pop("weight")
+        self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
+
         gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
         gan_loss_cfg.pop("weight")
         self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
@@ -32,6 +33,9 @@ class TAFGEngineKernel(EngineKernel):
         self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
         self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
 
+        self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
+            idist.device())
+
     def _process_batch(self, batch, inference=False):
         # batch["b"] = batch["b"] if inference else batch["b"][0].expand(batch["a"].size())
         return batch
@@ -74,7 +78,9 @@ class TAFGEngineKernel(EngineKernel):
         batch = self._process_batch(batch)
         loss = dict()
         loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
-        loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
+        _, loss_style = self.style_loss(generated["a"], batch["a"])
+        loss["style"] = self.config.loss.style.weight * loss_style
+        loss["perceptual"] = self.config.loss.perceptual.weight * loss_perceptual
         for phase in "ab":
             pred_fake = self.discriminators[phase](generated[phase])
             loss[f"gan_{phase}"] = 0
@@ -93,10 +99,7 @@ class TAFGEngineKernel(EngineKernel):
                         loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
                 loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
         loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"])
-        # loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
-        #     self.generators["main"].module.style_encoders["b"](batch["b"]),
-        #     self.generators["main"].module.style_encoders["b"](generated["b"])
-        # )
+        loss["edge"] = self.config.loss.edge.weight * self.edge_loss(generated["b"], batch["edge_a"][:, 0:1, :, :])
         return loss
 
     def criterion_discriminators(self, batch, generated) -> dict:
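`PerceptualLoss` is called as `loss_perceptual, _ = self.perceptual_loss(x, gt)` and `_, loss_style = self.style_loss(x, gt)`, i.e. it returns a (perceptual, style) pair gated by the `perceptual_loss`/`style_loss` flags in the config. Its internals are not in this diff; a minimal, self-contained sketch of that contract, assuming VGG19 features and Gram-matrix style terms:

```python
import torch
import torch.nn as nn
from torchvision import models

# Hypothetical sketch of the (perceptual, style) return contract seen above;
# not the repo's loss/I2I/perceptual_loss.py.
class PerceptualLossSketch(nn.Module):
    def __init__(self, layer_weights, criterion='L1', perceptual_loss=True, style_loss=False):
        super().__init__()
        self.layer_weights = {int(k): w for k, w in layer_weights.items()}  # VGG19 feature indices
        self.vgg = models.vgg19(pretrained=True).features[:max(self.layer_weights) + 1].eval()
        self.vgg.requires_grad_(False)
        self.criterion = nn.L1Loss() if criterion == 'L1' else nn.MSELoss()
        self.use_perceptual, self.use_style = perceptual_loss, style_loss

    def _features(self, x):
        feats = {}
        for idx, layer in enumerate(self.vgg):
            x = layer(x)
            if idx in self.layer_weights:
                feats[idx] = x
        return feats

    @staticmethod
    def _gram(f):
        n, c, h, w = f.size()
        f = f.view(n, c, h * w)
        return f.bmm(f.transpose(1, 2)) / (c * h * w)  # normalized Gram matrix

    def forward(self, x, gt):
        fx, fgt = self._features(x), self._features(gt)
        percep = style = None
        if self.use_perceptual:
            percep = sum(w * self.criterion(fx[i], fgt[i]) for i, w in self.layer_weights.items())
        if self.use_style:
            style = sum(w * self.criterion(self._gram(fx[i]), self._gram(fgt[i]))
                        for i, w in self.layer_weights.items())
        return percep, style
```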
engine/TSIT.py (new file, 106 lines)
@@ -0,0 +1,106 @@
+from itertools import chain
+
+import ignite.distributed as idist
+import torch
+import torch.nn as nn
+from omegaconf import OmegaConf
+
+from engine.base.i2i import EngineKernel, run_kernel
+from engine.util.build import build_model
+from loss.I2I.perceptual_loss import PerceptualLoss
+from loss.gan import GANLoss
+from model.weight_init import generation_init_weights
+
+
+class TSITEngineKernel(EngineKernel):
+    def __init__(self, config):
+        super().__init__(config)
+        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
+        perceptual_loss_cfg.pop("weight")
+        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
+
+        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
+        gan_loss_cfg.pop("weight")
+        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
+
+        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
+
+    def build_models(self) -> (dict, dict):
+        generators = dict(
+            main=build_model(self.config.model.generator)
+        )
+        discriminators = dict(
+            b=build_model(self.config.model.discriminator)
+        )
+        self.logger.debug(discriminators["b"])
+        self.logger.debug(generators["main"])
+
+        for m in chain(generators.values(), discriminators.values()):
+            generation_init_weights(m)
+
+        return generators, discriminators
+
+    def setup_after_g(self):
+        for discriminator in self.discriminators.values():
+            discriminator.requires_grad_(True)
+
+    def setup_before_g(self):
+        for discriminator in self.discriminators.values():
+            discriminator.requires_grad_(False)
+
+    def forward(self, batch, inference=False) -> dict:
+        with torch.set_grad_enabled(not inference):
+            fake = dict(
+                b=self.generators["main"](content_img=batch["a"], style_img=batch["b"])
+            )
+        return fake
+
+    def criterion_generators(self, batch, generated) -> dict:
+        loss = dict()
+        loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
+        loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
+        for phase in "b":
+            pred_fake = self.discriminators[phase](generated[phase])
+            loss[f"gan_{phase}"] = 0
+            for sub_pred_fake in pred_fake:
+                # last output is the actual prediction
+                loss[f"gan_{phase}"] += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)
+
+            if self.config.loss.fm.weight > 0 and phase == "b":
+                pred_real = self.discriminators[phase](batch[phase])
+                loss_fm = 0
+                num_scale_discriminator = len(pred_fake)
+                for i in range(num_scale_discriminator):
+                    # last output is the final prediction, so we exclude it
+                    num_intermediate_outputs = len(pred_fake[i]) - 1
+                    for j in range(num_intermediate_outputs):
+                        loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
+                loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
+        return loss
+
+    def criterion_discriminators(self, batch, generated) -> dict:
+        loss = dict()
+        for phase in self.discriminators.keys():
+            pred_real = self.discriminators[phase](batch[phase])
+            pred_fake = self.discriminators[phase](generated[phase].detach())
+            loss[f"gan_{phase}"] = 0
+            for i in range(len(pred_fake)):
+                loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
+                                         + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
+        return loss
+
+    def intermediate_images(self, batch, generated) -> dict:
+        """
+        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
+        :param batch:
+        :param generated: dict of images
+        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
+        """
+        return dict(
+            b=[batch["a"].detach(), batch["b"].detach(), generated["b"].detach()]
+        )
+
+
+def run(task, config, _):
+    kernel = TSITEngineKernel(config)
+    run_kernel(task, config, kernel)
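`GANLoss` is configured with `loss_type: hinge` and called as `self.gan_loss(pred, target_is_real, is_discriminator=...)`. Its implementation is outside this diff, but the hinge objective it names is standard, so here is a minimal sketch under that assumption (not the repo's `loss/gan.py`):

```python
import torch
import torch.nn.functional as F

# Hypothetical hinge GANLoss matching the call sites above.
def hinge_gan_loss(pred: torch.Tensor, target_is_real: bool, is_discriminator: bool = False) -> torch.Tensor:
    if is_discriminator:
        # Discriminator: hinge margin on real/fake logits.
        if target_is_real:
            return F.relu(1.0 - pred).mean()
        return F.relu(1.0 + pred).mean()
    # Generator: non-saturating hinge form; only ever called with target_is_real=True above.
    return -pred.mean()
```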
@@ -1,9 +1,10 @@
 import torch
 import torch.nn as nn
-from .base import ResidualBlock
-from model.registry import MODEL
 from torchvision.models import vgg19
 
 from model.normalization import select_norm_layer
+from model.registry import MODEL
+from .base import ResidualBlock
 
 
 class VGG19StyleEncoder(nn.Module):
@@ -169,25 +170,37 @@ class StyleGenerator(nn.Module):
 
 @MODEL.register_module("TAFG-Generator")
 class Generator(nn.Module):
-    def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, num_blocks=8,
+    def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512,
+                 num_adain_blocks=8, num_res_blocks=4,
                  base_channels=64, padding_mode="reflect"):
         super(Generator, self).__init__()
-        self.num_blocks = num_blocks
+        self.num_adain_blocks = num_adain_blocks
         self.style_encoders = nn.ModuleDict({
-            "a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks,
+            "a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
                                 base_channels=base_channels, padding_mode=padding_mode),
-            "b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks,
+            "b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
                                 base_channels=base_channels, padding_mode=padding_mode),
         })
-        self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
+        self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=8,
                                               padding_mode=padding_mode, norm_type="IN")
         res_block_channels = 2 ** 2 * base_channels
-        self.adain_resnet_a = nn.ModuleList([
-            ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
-        ])
-        self.adain_resnet_b = nn.ModuleList([
-            ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
-        ])
+        self.resnet = nn.ModuleDict({
+            "a": nn.Sequential(*[
+                ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
+            ]),
+            "b": nn.Sequential(*[
+                ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
+            ])
+        })
+        self.adain_resnet = nn.ModuleDict({
+            "a": nn.ModuleList([
+                ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
+            ]),
+            "b": nn.ModuleList([
+                ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
+            ])
+        })
+
         self.decoders = nn.ModuleDict({
             "a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode),
@@ -196,10 +209,10 @@ class Generator(nn.Module):
 
     def forward(self, content_img, style_img, which_decoder: str = "a"):
         x = self.content_encoder(content_img)
+        x = self.resnet[which_decoder](x)
         styles = self.style_encoders[which_decoder](style_img)
-        styles = torch.chunk(styles, self.num_blocks * 2, dim=1)
-        resnet = self.adain_resnet_a if which_decoder == "a" else self.adain_resnet_b
-        for i, ar in enumerate(resnet):
+        styles = torch.chunk(styles, self.num_adain_blocks * 2, dim=1)
+        for i, ar in enumerate(self.adain_resnet[which_decoder]):
             ar.norm1.set_style(styles[2 * i])
             ar.norm2.set_style(styles[2 * i + 1])
             x = ar(x)
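The style path works by slicing one style vector into per-block AdaIN parameters: the style encoder emits `num_adain_blocks * 2` channel groups, and each residual block's two norms receive one group each via `set_style`. A short shape walk-through of the `torch.chunk` call above, using the values from the config diff:

```python
import torch

# num_adain_blocks=8 and style_dim=512 as in the config/constructor defaults above.
num_adain_blocks, style_dim = 8, 512
styles = torch.randn(4, style_dim)                       # style encoder output, (N, style_dim)
chunks = torch.chunk(styles, num_adain_blocks * 2, dim=1)
print(len(chunks), chunks[0].shape)                      # 16 chunks of (4, 32): norm1/norm2 of each block
```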
model/GAN/TSIT.py (new file, 192 lines)
@@ -0,0 +1,192 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from model import MODEL
+from model.normalization import AdaptiveInstanceNorm2d
+from model.normalization import select_norm_layer
+
+
+class ResBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, padding_mode='zeros', norm_type="IN", use_bias=None,
+                 use_spectral=True):
+        super().__init__()
+        self.padding_mode = padding_mode
+        self.use_bias = use_bias
+        self.use_spectral = use_spectral
+        if use_bias is None:
+            # Use bias only for IN, since it does not have affine parameters.
+            self.use_bias = norm_type == "IN"
+        norm_layer = select_norm_layer(norm_type)
+        self.main = nn.Sequential(
+            self.conv_block(in_channels, in_channels),
+            norm_layer(in_channels),
+            nn.LeakyReLU(0.2, inplace=True),
+            self.conv_block(in_channels, out_channels),
+            norm_layer(out_channels),
+            nn.LeakyReLU(0.2, inplace=True),
+        )
+        self.skip = nn.Sequential(
+            self.conv_block(in_channels, out_channels, padding=0, kernel_size=1),
+            norm_layer(out_channels),
+            nn.LeakyReLU(0.2, inplace=True),
+        )
+
+    def conv_block(self, in_channels, out_channels, kernel_size=3, padding=1):
+        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding,
+                         padding_mode=self.padding_mode, bias=self.use_bias)
+        if self.use_spectral:
+            return nn.utils.spectral_norm(conv)
+        else:
+            return conv
+
+    def forward(self, x):
+        return self.main(x) + self.skip(x)
+
+
+class Interpolation(nn.Module):
+    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
+        super(Interpolation, self).__init__()
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners,
+                             recompute_scale_factor=False)
+
+    def __repr__(self):
+        return f"Interpolation(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"
+
+
+class FADE(nn.Module):
+    def __init__(self, use_spectral, features_channels, in_channels, affine=False, track_running_stats=True):
+        super().__init__()
+        self.bn = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
+        self.alpha_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
+                                     padding_mode="zeros")
+        self.beta_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
+                                    padding_mode="zeros")
+
+    def forward(self, x, feature):
+        alpha = self.alpha_conv(feature)
+        beta = self.beta_conv(feature)
+        x = self.bn(x)
+        return alpha * x + beta
+
+
+class FADEResBlock(nn.Module):
+    def __init__(self, use_spectral, features_channels, in_channels, out_channels):
+        super().__init__()
+        self.main = nn.Sequential(
+            FADE(use_spectral, features_channels, in_channels),
+            nn.LeakyReLU(0.2, inplace=True),
+            conv_block(use_spectral, in_channels, in_channels, kernel_size=3, padding=1),
+            FADE(use_spectral, features_channels, in_channels),
+            nn.LeakyReLU(0.2, inplace=True),
+            conv_block(use_spectral, in_channels, out_channels, kernel_size=3, padding=1),
+        )
+        self.skip = nn.Sequential(
+            FADE(use_spectral, features_channels, in_channels),
+            nn.LeakyReLU(0.2, inplace=True),
+            conv_block(use_spectral, in_channels, out_channels, kernel_size=1, padding=0),
+        )
+        self.up_sample = Interpolation(2, mode="nearest")
+
+    @staticmethod
+    def forward_with_fade(module, x, feature):
+        for layer in module:
+            if layer.__class__.__name__ == "FADE":
+                x = layer(x, feature)
+            else:
+                x = layer(x)
+        return x
+
+    def forward(self, x, feature):
+        out = self.forward_with_fade(self.main, x, feature) + self.forward_with_fade(self.skip, x, feature)
+        return self.up_sample(out)
+
+
+def conv_block(use_spectral, in_channels, out_channels, **kwargs):
+    conv = nn.Conv2d(in_channels, out_channels, **kwargs)
+    return nn.utils.spectral_norm(conv) if use_spectral else conv
+
+
+@MODEL.register_module("TSIT-Generator")
+class TSITGenerator(nn.Module):
+    def __init__(self, num_blocks=7, base_channels=64, content_in_channels=3, style_in_channels=3,
+                 out_channels=3, use_spectral=True, input_layer_type="conv1x1"):
+        super().__init__()
+        self.num_blocks = num_blocks
+        self.base_channels = base_channels
+        self.use_spectral = use_spectral
+
+        self.content_input_layer = self.build_input_layer(content_in_channels, base_channels, input_layer_type)
+        self.style_input_layer = self.build_input_layer(style_in_channels, base_channels, input_layer_type)
+        self.content_stream = self.build_stream()
+        self.style_stream = self.build_stream()
+        self.generator = self.build_generator()
+        self.end_conv = nn.Sequential(
+            conv_block(use_spectral, base_channels, out_channels, kernel_size=7, padding=3, padding_mode="zeros"),
+            nn.Tanh()
+        )
+
+    def build_generator(self):
+        stream_sequence = []
+        multiple_now = min(2 ** self.num_blocks, 2 ** 4)
+        for i in range(1, self.num_blocks + 1):
+            m = self.num_blocks - i
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** m, 2 ** 4)
+            stream_sequence.append(nn.Sequential(
+                AdaptiveInstanceNorm2d(multiple_prev * self.base_channels),
+                FADEResBlock(self.use_spectral, multiple_prev * self.base_channels, multiple_prev * self.base_channels,
+                             multiple_now * self.base_channels)
+            ))
+        return nn.ModuleList(stream_sequence)
+
+    def build_input_layer(self, in_channels, out_channels, input_layer_type="conv7x7"):
+        if input_layer_type == "conv7x7":
+            return nn.Sequential(
+                conv_block(self.use_spectral, in_channels, out_channels, kernel_size=7, stride=1,
+                           padding_mode="zeros", padding=3, bias=True),
+                select_norm_layer("IN")(out_channels),
+                nn.ReLU(inplace=True)
+            )
+        elif input_layer_type == "conv1x1":
+            return conv_block(self.use_spectral, in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+        else:
+            raise NotImplementedError
+
+    def build_stream(self):
+        multiple_now = 1
+        stream_sequence = []
+        for i in range(1, self.num_blocks + 1):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 2 ** 4)
+            stream_sequence.append(nn.Sequential(
+                Interpolation(scale_factor=0.5, mode="nearest"),
+                ResBlock(multiple_prev * self.base_channels, multiple_now * self.base_channels,
+                         use_spectral=self.use_spectral)
+            ))
+        return nn.ModuleList(stream_sequence)
+
+    def forward(self, content_img, style_img):
+        c = self.content_input_layer(content_img)
+        s = self.style_input_layer(style_img)
+        content_features = []
+        style_features = []
+        for i in range(self.num_blocks):
+            s = self.style_stream[i](s)
+            c = self.content_stream[i](c)
+            content_features.append(c)
+            style_features.append(s)
+        z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
+
+        for i in range(self.num_blocks):
+            m = - i - 1
+            layer = self.generator[i]
+            layer[0].set_style(torch.cat(torch.std_mean(style_features[m], dim=[2, 3]), dim=1))
+            z = layer[0](z)
+            z = layer[1](z, content_features[m])
+        return self.end_conv(z)
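In `forward`, the style injected into each `AdaptiveInstanceNorm2d` is `torch.cat(torch.std_mean(style_feature, dim=[2, 3]), dim=1)`: the per-channel std and mean of the style stream, giving an `(N, 2C)` vector for a `C`-channel norm. The norm class itself is outside this diff; a minimal sketch consistent with that call pattern (an assumption, not the repo's implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical AdaIN sketch matching set_style(torch.cat(torch.std_mean(f, dim=[2, 3]), dim=1));
# the repo's AdaptiveInstanceNorm2d may differ in detail.
class AdaINSketch(nn.Module):
    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features
        self.style = None  # set externally before forward()

    def set_style(self, style):
        self.style = style  # (N, 2 * num_features): [std, mean] concatenated along dim 1

    def forward(self, x):
        std, mean = torch.chunk(self.style, 2, dim=1)
        x = F.instance_norm(x)  # zero-mean, unit-std per sample and channel
        return x * std[:, :, None, None] + mean[:, :, None, None]
```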
@@ -4,3 +4,4 @@ import model.GAN.TAFG
 import model.GAN.UGATIT
 import model.GAN.wrapper
 import model.GAN.base
+import model.GAN.TSIT
@@ -1,6 +1,7 @@
-import torch.nn as nn
 import functools
 
 import torch
+import torch.nn as nn
 
 
 def select_norm_layer(norm_type):
run.sh
@@ -5,16 +5,18 @@ TASK=$2
 GPUS=$3
 MORE_ARG=${*:4}
 
+RANDOM_MASTER=$(shuf -i 2000-65000 -n 1)
 
 _command="print(len('${GPUS}'.split(',')))"
 GPU_COUNT=$(python3 -c "${_command}")
 
 echo "GPU_COUNT:${GPU_COUNT}"
 
 echo CUDA_VISIBLE_DEVICES=$GPUS \
 PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
 main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed "$MORE_ARG"
 
 CUDA_VISIBLE_DEVICES=$GPUS \
 PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
+--master_port=${RANDOM_MASTER} \
 main.py "$TASK" "$CONFIG" $MORE_ARG --backup_config --setup_output_dir --setup_random_seed
tool/dump_tensorboard.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+
+# edit from https://gist.github.com/hysts/81a0d30ac4f33dfa0c8859383aec42c2
+
+import argparse
+import pathlib
+
+import cv2
+import numpy as np
+from tensorboard.backend.event_processing import event_accumulator
+
+
+def save(outdir: pathlib.Path, tag, event_acc):
+    events = event_acc.Images(tag)
+
+    for index, event in enumerate(events):
+        s = np.frombuffer(event.encoded_image_string, dtype=np.uint8)
+        image = cv2.imdecode(s, cv2.IMREAD_COLOR)
+        outpath = outdir / f"{tag.replace('/', '_')}@{index}.png"
+        cv2.imwrite(outpath.as_posix(), image)
+
+
+# ffmpeg -framerate 1 -i ./tmp/test_b/%04d.jpg -vcodec mpeg4 test_b.mp4
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--path', type=str, required=True)
+    parser.add_argument('--outdir', type=str, required=True)
+    parser.add_argument("--tag", type=str, required=False)
+    args = parser.parse_args()
+
+    event_acc = event_accumulator.EventAccumulator(args.path, size_guidance={'images': 0})
+    event_acc.Reload()
+
+    outdir = pathlib.Path(args.outdir)
+    outdir.mkdir(exist_ok=True, parents=True)
+
+    if args.tag is None:
+        for tag in event_acc.Tags()['images']:
+            save(outdir, tag, event_acc)
+    else:
+        assert args.tag in event_acc.Tags()['images'], f"{args.tag} not in {event_acc.Tags()['images']}"
+        save(outdir, args.tag, event_acc)
+
+
+if __name__ == '__main__':
+    main()
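A quick way to discover valid `--tag` values before running the tool, using the same public `EventAccumulator` API the script imports (the event-log path is illustrative):

```python
from tensorboard.backend.event_processing import event_accumulator

# Path is an example; point it at an actual event file or log directory.
acc = event_accumulator.EventAccumulator("./result/tb_log", size_guidance={"images": 0})
acc.Reload()
print(acc.Tags()["images"])  # image tags you can pass as --tag
```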