Compare commits

...

3 Commits

SHA1 Message Date
7ea9c6d0d5 TAFG good result 2020-09-09 14:46:07 +08:00
87cbcf34d3 add tool to dump images in tensorboard event file 2020-09-09 09:08:11 +08:00
97ded53b30 update a lot 2020-09-07 21:38:10 +08:00
10 changed files with 554 additions and 44 deletions

Changed file (path not captured): TAFG experiment config

@@ -23,7 +23,8 @@ model:
     _bn_to_sync_bn: False
     style_in_channels: 3
     content_in_channels: 24
-    num_blocks: 8
+    num_adain_blocks: 8
+    num_res_blocks: 0
   discriminator:
     _type: MultiScaleDiscriminator
     num_scale: 2
@@ -47,21 +48,17 @@ loss:
       "11": 0.125
       "20": 0.25
       "29": 1
-    criterion: 'L2'
+    criterion: 'L1'
     style_loss: False
     perceptual_loss: True
-    weight: 0.5
+    weight: 10
   style:
     layer_weights:
-      "1": 0.03125
-      "6": 0.0625
-      "11": 0.125
-      "20": 0.25
-      "29": 1
-    criterion: 'L2'
+      "3": 1
+    criterion: 'L1'
     style_loss: True
     perceptual_loss: False
-    weight: 0
+    weight: 10
   fm:
     level: 1
     weight: 10
@@ -71,6 +68,9 @@ loss:
   style_recon:
     level: 1
     weight: 0
+  edge:
+    weight: 10
+    hed_pretrained_model_path: ./network-bsds500.pytorch

 optimizers:
   generator:
@@ -91,7 +91,7 @@ data:
       target_lr: 0
     buffer_size: 50
     dataloader:
-      batch_size: 24
+      batch_size: 8
       shuffle: True
       num_workers: 2
       pin_memory: True

New file (path not captured): TSIT experiment config

@@ -0,0 +1,146 @@
name: self2anime-TSIT
engine: TSIT
result_dir: ./result
max_pairs: 1500000

handler:
  clear_cuda_cache: True
  set_epoch_for_dist_sampler: True
  checkpoint:
    epoch_interval: 1  # checkpoint once per `epoch_interval` epoch
    n_saved: 2
  tensorboard:
    scalar: 100  # log scalar `scalar` times per epoch
    image: 2  # log image `image` times per epoch

misc:
  random_seed: 324

model:
  generator:
    _type: TSIT-Generator
    _bn_to_sync_bn: True
    style_in_channels: 3
    content_in_channels: 3
    num_blocks: 5
    input_layer_type: "conv7x7"
  discriminator:
    _type: MultiScaleDiscriminator
    num_scale: 2
    discriminator_cfg:
      _type: PatchDiscriminator
      in_channels: 3
      base_channels: 64
      use_spectral: True
      need_intermediate_feature: True

loss:
  gan:
    loss_type: hinge
    real_label_val: 1.0
    fake_label_val: 0.0
    weight: 1.0
  perceptual:
    layer_weights:
      "1": 0.03125
      "6": 0.0625
      "11": 0.125
      "20": 0.25
      "29": 1
    criterion: 'L1'
    style_loss: False
    perceptual_loss: True
    weight: 1
  style:
    layer_weights:
      "1": 0.03125
      "6": 0.0625
      "11": 0.125
      "20": 0.25
      "29": 1
    criterion: 'L2'
    style_loss: True
    perceptual_loss: False
    weight: 0
  fm:
    level: 1
    weight: 1

optimizers:
  generator:
    _type: Adam
    lr: 0.0001
    betas: [ 0, 0.9 ]
    weight_decay: 0.0001
  discriminator:
    _type: Adam
    lr: 4e-4
    betas: [ 0, 0.9 ]
    weight_decay: 0.0001

data:
  train:
    scheduler:
      start_proportion: 0.5
      target_lr: 0
    buffer_size: 50
    dataloader:
      batch_size: 1
      shuffle: True
      num_workers: 2
      pin_memory: True
      drop_last: True
    dataset:
      _type: GenerationUnpairedDatasetWithEdge
      root_a: "/data/i2i/VoxCeleb2Anime/trainA"
      root_b: "/data/i2i/VoxCeleb2Anime/trainB"
      edges_path: "/data/i2i/VoxCeleb2Anime/edges"
      landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
      edge_type: "landmark_hed"
      size: [ 128, 128 ]
      random_pair: True
      pipeline:
        - Load
        - Resize:
            size: [ 128, 128 ]
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
  test:
    dataloader:
      batch_size: 8
      shuffle: False
      num_workers: 1
      pin_memory: False
      drop_last: False
    dataset:
      _type: GenerationUnpairedDatasetWithEdge
      root_a: "/data/i2i/VoxCeleb2Anime/testA"
      root_b: "/data/i2i/VoxCeleb2Anime/testB"
      edges_path: "/data/i2i/VoxCeleb2Anime/edges"
      landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
      edge_type: "landmark_hed"
      random_pair: False
      size: [ 128, 128 ]
      pipeline:
        - Load
        - Resize:
            size: [ 128, 128 ]
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
  video_dataset:
    _type: SingleFolderDataset
    root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
    with_path: True
    pipeline:
      - Load
      - Resize:
          size: [ 256, 256 ]
      - ToTensor
      - Normalize:
          mean: [ 0.5, 0.5, 0.5 ]
          std: [ 0.5, 0.5, 0.5 ]
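
The engines consume these configs through OmegaConf. A minimal sketch of the loading side, mirroring the OmegaConf.to_container(...) / pop("weight") pattern used in the engine code below; the on-disk path is an assumption:

    from omegaconf import OmegaConf

    config = OmegaConf.load("config/self2anime-TSIT.yaml")  # hypothetical path
    assert config.engine == "TSIT"
    perceptual_cfg = OmegaConf.to_container(config.loss.perceptual)
    weight = perceptual_cfg.pop("weight")  # the weight is applied outside the loss module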

Changed file (path not captured): TAFG engine kernel

@@ -1,20 +1,17 @@
 from itertools import chain
-from omegaconf import OmegaConf
+import ignite.distributed as idist
 import torch
 import torch.nn as nn
-import ignite.distributed as idist
 from ignite.engine import Events
 from omegaconf import read_write, OmegaConf
-from model.weight_init import generation_init_weights
-from loss.I2I.perceptual_loss import PerceptualLoss
-from loss.gan import GANLoss
 from engine.base.i2i import EngineKernel, run_kernel
 from engine.util.build import build_model
+from loss.I2I.edge_loss import EdgeLoss
+from loss.I2I.perceptual_loss import PerceptualLoss
+from loss.gan import GANLoss
+from model.weight_init import generation_init_weights


 class TAFGEngineKernel(EngineKernel):
@@ -24,6 +21,10 @@ class TAFGEngineKernel(EngineKernel):
         perceptual_loss_cfg.pop("weight")
         self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())

+        style_loss_cfg = OmegaConf.to_container(config.loss.style)
+        style_loss_cfg.pop("weight")
+        self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
+
         gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
         gan_loss_cfg.pop("weight")
         self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
@@ -32,6 +33,9 @@ class TAFGEngineKernel(EngineKernel):
         self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
         self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
+        self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
+            idist.device())

     def _process_batch(self, batch, inference=False):
         # batch["b"] = batch["b"] if inference else batch["b"][0].expand(batch["a"].size())
         return batch
@@ -74,7 +78,9 @@ class TAFGEngineKernel(EngineKernel):
         batch = self._process_batch(batch)
         loss = dict()
         loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
-        loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
+        _, loss_style = self.style_loss(generated["a"], batch["a"])
+        loss["style"] = self.config.loss.style.weight * loss_style
+        loss["perceptual"] = self.config.loss.perceptual.weight * loss_perceptual
         for phase in "ab":
             pred_fake = self.discriminators[phase](generated[phase])
             loss[f"gan_{phase}"] = 0
@@ -93,10 +99,7 @@ class TAFGEngineKernel(EngineKernel):
                     loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
                 loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
         loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"])
-        # loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
-        #     self.generators["main"].module.style_encoders["b"](batch["b"]),
-        #     self.generators["main"].module.style_encoders["b"](generated["b"])
-        # )
+        loss["edge"] = self.config.loss.edge.weight * self.edge_loss(generated["b"], batch["edge_a"][:, 0:1, :, :])
         return loss

     def criterion_discriminators(self, batch, generated) -> dict:
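
The EdgeLoss implementation itself is not part of this diff. A hypothetical sketch of what the new edge term plausibly computes, assuming EdgeLoss wraps a frozen pretrained HED detector and compares its response on the generated image against the precomputed single-channel edge map batch["edge_a"][:, 0:1, :, :]:

    import torch.nn as nn

    class EdgeLossSketch(nn.Module):  # hypothetical stand-in for loss.I2I.edge_loss.EdgeLoss
        def __init__(self, hed_net):
            super().__init__()
            self.hed = hed_net.eval()  # frozen, pretrained HED edge detector
            for p in self.hed.parameters():
                p.requires_grad_(False)
            self.criterion = nn.L1Loss()

        def forward(self, generated, edge_target):
            # penalize generated-image edges that deviate from the reference edge map
            return self.criterion(self.hed(generated), edge_target)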

engine/TSIT.py (new file, 106 lines)

@@ -0,0 +1,106 @@
from itertools import chain

import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf

from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from model.weight_init import generation_init_weights


class TSITEngineKernel(EngineKernel):
    def __init__(self, config):
        super().__init__(config)
        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
        perceptual_loss_cfg.pop("weight")
        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())

        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())

        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()

    def build_models(self) -> (dict, dict):
        generators = dict(
            main=build_model(self.config.model.generator)
        )
        discriminators = dict(
            b=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["b"])
        self.logger.debug(generators["main"])
        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)
        return generators, discriminators

    def setup_after_g(self):
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        with torch.set_grad_enabled(not inference):
            fake = dict(
                b=self.generators["main"](content_img=batch["a"], style_img=batch["b"])
            )
        return fake

    def criterion_generators(self, batch, generated) -> dict:
        loss = dict()
        loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
        loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
        for phase in "b":
            pred_fake = self.discriminators[phase](generated[phase])
            loss[f"gan_{phase}"] = 0
            for sub_pred_fake in pred_fake:
                # last output is actual prediction
                loss[f"gan_{phase}"] += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)
            if self.config.loss.fm.weight > 0 and phase == "b":
                pred_real = self.discriminators[phase](batch[phase])
                loss_fm = 0
                num_scale_discriminator = len(pred_fake)
                for i in range(num_scale_discriminator):
                    # last output is the final prediction, so we exclude it
                    num_intermediate_outputs = len(pred_fake[i]) - 1
                    for j in range(num_intermediate_outputs):
                        loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
                loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        loss = dict()
        for phase in self.discriminators.keys():
            pred_real = self.discriminators[phase](batch[phase])
            pred_fake = self.discriminators[phase](generated[phase].detach())
            loss[f"gan_{phase}"] = 0
            for i in range(len(pred_fake)):
                loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
                                         + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        return dict(
            b=[batch["a"].detach(), batch["b"].detach(), generated["b"].detach()]
        )


def run(task, config, _):
    kernel = TSITEngineKernel(config)
    run_kernel(task, config, kernel)

Changed file (path not captured): TAFG generator model

@@ -1,9 +1,10 @@
 import torch
 import torch.nn as nn
-from .base import ResidualBlock
-from model.registry import MODEL
 from torchvision.models import vgg19
+
 from model.normalization import select_norm_layer
+from model.registry import MODEL
+from .base import ResidualBlock


 class VGG19StyleEncoder(nn.Module):
@@ -169,25 +170,37 @@
 @MODEL.register_module("TAFG-Generator")
 class Generator(nn.Module):
-    def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, num_blocks=8,
+    def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512,
+                 num_adain_blocks=8, num_res_blocks=4,
                  base_channels=64, padding_mode="reflect"):
         super(Generator, self).__init__()
-        self.num_blocks = num_blocks
+        self.num_adain_blocks = num_adain_blocks
         self.style_encoders = nn.ModuleDict({
-            "a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks,
+            "a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
                                 base_channels=base_channels, padding_mode=padding_mode),
-            "b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks,
+            "b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
                                 base_channels=base_channels, padding_mode=padding_mode),
         })
-        self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
+        self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=8,
                                               padding_mode=padding_mode, norm_type="IN")
         res_block_channels = 2 ** 2 * base_channels
-        self.adain_resnet_a = nn.ModuleList([
-            ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
-        ])
-        self.adain_resnet_b = nn.ModuleList([
-            ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
-        ])
+        self.resnet = nn.ModuleDict({
+            "a": nn.Sequential(*[
+                ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
+            ]),
+            "b": nn.Sequential(*[
+                ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
+            ])
+        })
+        self.adain_resnet = nn.ModuleDict({
+            "a": nn.ModuleList([
+                ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
+            ]),
+            "b": nn.ModuleList([
+                ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
+            ])
+        })

         self.decoders = nn.ModuleDict({
             "a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode),
@@ -196,10 +209,10 @@ class Generator(nn.Module):
     def forward(self, content_img, style_img, which_decoder: str = "a"):
         x = self.content_encoder(content_img)
+        x = self.resnet[which_decoder](x)
         styles = self.style_encoders[which_decoder](style_img)
-        styles = torch.chunk(styles, self.num_blocks * 2, dim=1)
-        resnet = self.adain_resnet_a if which_decoder == "a" else self.adain_resnet_b
-        for i, ar in enumerate(resnet):
+        styles = torch.chunk(styles, self.num_adain_blocks * 2, dim=1)
+        for i, ar in enumerate(self.adain_resnet[which_decoder]):
             ar.norm1.set_style(styles[2 * i])
             ar.norm2.set_style(styles[2 * i + 1])
             x = ar(x)
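
A small illustration of the style chunking in the rewritten forward, with hypothetical tensor sizes (assuming the style encoder concatenates one pair of AdaIN parameters per block along dim 1):

    import torch

    num_adain_blocks, per_chunk = 8, 256  # hypothetical sizes
    styles = torch.randn(4, num_adain_blocks * 2 * per_chunk)
    chunks = torch.chunk(styles, num_adain_blocks * 2, dim=1)
    assert len(chunks) == 16 and chunks[0].shape == (4, per_chunk)
    # chunks[2 * i] and chunks[2 * i + 1] parameterize norm1 and norm2 of AdaIN block i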

model/GAN/TSIT.py (new file, 192 lines)

@@ -0,0 +1,192 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from model import MODEL
from model.normalization import AdaptiveInstanceNorm2d
from model.normalization import select_norm_layer


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, padding_mode='zeros', norm_type="IN", use_bias=None,
                 use_spectral=True):
        super().__init__()
        self.padding_mode = padding_mode
        self.use_bias = use_bias
        self.use_spectral = use_spectral
        if use_bias is None:
            # Only for IN, use bias since it does not have affine parameters.
            self.use_bias = norm_type == "IN"
        norm_layer = select_norm_layer(norm_type)
        self.main = nn.Sequential(
            self.conv_block(in_channels, in_channels),
            norm_layer(in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            self.conv_block(in_channels, out_channels),
            norm_layer(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.skip = nn.Sequential(
            self.conv_block(in_channels, out_channels, padding=0, kernel_size=1),
            norm_layer(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def conv_block(self, in_channels, out_channels, kernel_size=3, padding=1):
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding,
                         padding_mode=self.padding_mode, bias=self.use_bias)
        if self.use_spectral:
            return nn.utils.spectral_norm(conv)
        else:
            return conv

    def forward(self, x):
        return self.main(x) + self.skip(x)


class Interpolation(nn.Module):
    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
        super(Interpolation, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners,
                             recompute_scale_factor=False)

    def __repr__(self):
        return f"Interpolation(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"


class FADE(nn.Module):
    def __init__(self, use_spectral, features_channels, in_channels, affine=False, track_running_stats=True):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
        self.alpha_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
                                     padding_mode="zeros")
        self.beta_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
                                    padding_mode="zeros")

    def forward(self, x, feature):
        # feature-adaptive denormalization: scale and shift are predicted from the content feature
        alpha = self.alpha_conv(feature)
        beta = self.beta_conv(feature)
        x = self.bn(x)
        return alpha * x + beta


class FADEResBlock(nn.Module):
    def __init__(self, use_spectral, features_channels, in_channels, out_channels):
        super().__init__()
        self.main = nn.Sequential(
            FADE(use_spectral, features_channels, in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            conv_block(use_spectral, in_channels, in_channels, kernel_size=3, padding=1),
            FADE(use_spectral, features_channels, in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            conv_block(use_spectral, in_channels, out_channels, kernel_size=3, padding=1),
        )
        self.skip = nn.Sequential(
            FADE(use_spectral, features_channels, in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            conv_block(use_spectral, in_channels, out_channels, kernel_size=1, padding=0),
        )
        self.up_sample = Interpolation(2, mode="nearest")

    @staticmethod
    def forward_with_fade(module, x, feature):
        for layer in module:
            if layer.__class__.__name__ == "FADE":
                x = layer(x, feature)
            else:
                x = layer(x)
        return x

    def forward(self, x, feature):
        # residual sum of the main branch and the 1x1 skip branch
        out = self.forward_with_fade(self.main, x, feature) + self.forward_with_fade(self.skip, x, feature)
        return self.up_sample(out)


def conv_block(use_spectral, in_channels, out_channels, **kwargs):
    conv = nn.Conv2d(in_channels, out_channels, **kwargs)
    return nn.utils.spectral_norm(conv) if use_spectral else conv


@MODEL.register_module("TSIT-Generator")
class TSITGenerator(nn.Module):
    def __init__(self, num_blocks=7, base_channels=64, content_in_channels=3, style_in_channels=3,
                 out_channels=3, use_spectral=True, input_layer_type="conv1x1"):
        super().__init__()
        self.num_blocks = num_blocks
        self.base_channels = base_channels
        self.use_spectral = use_spectral
        self.content_input_layer = self.build_input_layer(content_in_channels, base_channels, input_layer_type)
        self.style_input_layer = self.build_input_layer(style_in_channels, base_channels, input_layer_type)
        self.content_stream = self.build_stream()
        self.style_stream = self.build_stream()
        self.generator = self.build_generator()
        self.end_conv = nn.Sequential(
            conv_block(use_spectral, base_channels, out_channels, kernel_size=7, padding=3, padding_mode="zeros"),
            nn.Tanh()
        )

    def build_generator(self):
        stream_sequence = []
        multiple_now = min(2 ** self.num_blocks, 2 ** 4)
        for i in range(1, self.num_blocks + 1):
            m = self.num_blocks - i
            multiple_prev = multiple_now
            multiple_now = min(2 ** m, 2 ** 4)
            stream_sequence.append(nn.Sequential(
                AdaptiveInstanceNorm2d(multiple_prev * self.base_channels),
                FADEResBlock(self.use_spectral, multiple_prev * self.base_channels, multiple_prev * self.base_channels,
                             multiple_now * self.base_channels)
            ))
        return nn.ModuleList(stream_sequence)

    def build_input_layer(self, in_channels, out_channels, input_layer_type="conv7x7"):
        if input_layer_type == "conv7x7":
            return nn.Sequential(
                conv_block(self.use_spectral, in_channels, out_channels, kernel_size=7, stride=1,
                           padding_mode="zeros", padding=3, bias=True),
                select_norm_layer("IN")(out_channels),
                nn.ReLU(inplace=True)
            )
        elif input_layer_type == "conv1x1":
            return conv_block(self.use_spectral, in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            raise NotImplementedError(f"unknown input_layer_type: {input_layer_type}")

    def build_stream(self):
        multiple_now = 1
        stream_sequence = []
        for i in range(1, self.num_blocks + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 4)
            stream_sequence.append(nn.Sequential(
                Interpolation(scale_factor=0.5, mode="nearest"),
                ResBlock(multiple_prev * self.base_channels, multiple_now * self.base_channels,
                         use_spectral=self.use_spectral)
            ))
        return nn.ModuleList(stream_sequence)

    def forward(self, content_img, style_img):
        c = self.content_input_layer(content_img)
        s = self.style_input_layer(style_img)
        content_features = []
        style_features = []
        for i in range(self.num_blocks):
            s = self.style_stream[i](s)
            c = self.content_stream[i](c)
            content_features.append(c)
            style_features.append(s)
        z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
        for i in range(self.num_blocks):
            m = -i - 1
            layer = self.generator[i]
            # AdaIN consumes the style statistics; FADEResBlock consumes the content feature
            layer[0].set_style(torch.cat(torch.std_mean(style_features[m], dim=[2, 3]), dim=1))
            z = layer[0](z)
            z = layer[1](z, content_features[m])
        return self.end_conv(z)
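
A hedged smoke test for the generator above; it assumes model.MODEL, AdaptiveInstanceNorm2d.set_style and select_norm_layer behave as used in this file, and is not part of the commit:

    import torch

    net = TSITGenerator(num_blocks=5, base_channels=16)  # small sizes for a quick check
    content, style = torch.randn(1, 3, 128, 128), torch.randn(1, 3, 128, 128)
    out = net(content, style)
    print(out.shape)  # expected torch.Size([1, 3, 128, 128]): five stream downsamplings, five generator upsamplings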

Changed file (path not captured): model.GAN package imports

@@ -4,3 +4,4 @@ import model.GAN.TAFG
 import model.GAN.UGATIT
 import model.GAN.wrapper
 import model.GAN.base
+import model.GAN.TSIT

Changed file (path not captured): normalization helpers

@@ -1,6 +1,7 @@
-import torch.nn as nn
 import functools
+
 import torch
+import torch.nn as nn


 def select_norm_layer(norm_type):

run.sh (8 lines changed)

@@ -5,16 +5,18 @@ TASK=$2
 GPUS=$3
 MORE_ARG=${*:4}

+RANDOM_MASTER=$(shuf -i 2000-65000 -n 1)
+
 _command="print(len('${GPUS}'.split(',')))"
 GPU_COUNT=$(python3 -c "${_command}")
 echo "GPU_COUNT:${GPU_COUNT}"

 echo CUDA_VISIBLE_DEVICES=$GPUS \
     PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
     main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed "$MORE_ARG"

 CUDA_VISIBLE_DEVICES=$GPUS \
     PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
+    --master_port=${RANDOM_MASTER} \
     main.py "$TASK" "$CONFIG" $MORE_ARG --backup_config --setup_output_dir --setup_random_seed

tool/dump_tensorboard.py (new file, 46 lines)

@@ -0,0 +1,46 @@
#!/usr/bin/env python
# edit from https://gist.github.com/hysts/81a0d30ac4f33dfa0c8859383aec42c2
import argparse
import pathlib

import cv2
import numpy as np
from tensorboard.backend.event_processing import event_accumulator


def save(outdir: pathlib.Path, tag, event_acc):
    events = event_acc.Images(tag)
    for index, event in enumerate(events):
        s = np.frombuffer(event.encoded_image_string, dtype=np.uint8)
        image = cv2.imdecode(s, cv2.IMREAD_COLOR)
        outpath = outdir / f"{tag.replace('/', '_')}@{index}.png"
        cv2.imwrite(outpath.as_posix(), image)


# ffmpeg -framerate 1 -i ./tmp/test_b/%04d.jpg -vcodec mpeg4 test_b.mp4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, required=True)
    parser.add_argument('--outdir', type=str, required=True)
    parser.add_argument("--tag", type=str, required=False)
    args = parser.parse_args()

    event_acc = event_accumulator.EventAccumulator(args.path, size_guidance={'images': 0})
    event_acc.Reload()

    outdir = pathlib.Path(args.outdir)
    outdir.mkdir(exist_ok=True, parents=True)

    if args.tag is None:
        for tag in event_acc.Tags()['images']:
            save(outdir, tag, event_acc)
    else:
        assert args.tag in event_acc.Tags()['images'], f"{args.tag} not in {event_acc.Tags()['images']}"
        save(outdir, args.tag, event_acc)


if __name__ == '__main__':
    main()
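
Typical usage, with a hypothetical event-file path; omitting --tag dumps every image tag found in the event file:

    python tool/dump_tensorboard.py --path ./result/events.out.tfevents.example --outdir ./tmp/test_b --tag test_b

The dumped frames can then be assembled into a video, e.g. with the ffmpeg command quoted in the script.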