add MUNIT
This commit is contained in:
parent
f70658eaed
commit
2ff4a91057
@ -1,11 +1,11 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="PublishConfigData" autoUpload="Always" serverName="15d" remoteFilesAllowedToDisappearOnAutoupload="false">
|
<component name="PublishConfigData" autoUpload="Always" serverName="14d" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||||
<serverData>
|
<serverData>
|
||||||
<paths name="14d">
|
<paths name="14d">
|
||||||
<serverdata>
|
<serverdata>
|
||||||
<mappings>
|
<mappings>
|
||||||
<mapping local="$PROJECT_DIR$" web="/" />
|
<mapping deploy="raycv" local="$PROJECT_DIR$" web="/" />
|
||||||
</mappings>
|
</mappings>
|
||||||
</serverdata>
|
</serverdata>
|
||||||
</paths>
|
</paths>
|
||||||
|
|||||||
132
configs/synthesizers/MUNIT.yml
Normal file
132
configs/synthesizers/MUNIT.yml
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
name: MUNIT-edges2shoes
|
||||||
|
engine: MUNIT
|
||||||
|
result_dir: ./result
|
||||||
|
max_pairs: 1000000
|
||||||
|
|
||||||
|
handler:
|
||||||
|
clear_cuda_cache: True
|
||||||
|
set_epoch_for_dist_sampler: True
|
||||||
|
checkpoint:
|
||||||
|
epoch_interval: 1 # save a checkpoint every `epoch_interval` epochs
|
||||||
|
n_saved: 2
|
||||||
|
tensorboard:
|
||||||
|
scalar: 100 # log scalar `scalar` times per epoch
|
||||||
|
image: 2 # log image `image` times per epoch
|
||||||
|
|
||||||
|
|
||||||
|
misc:
|
||||||
|
random_seed: 324
|
||||||
|
|
||||||
|
model:
|
||||||
|
generator:
|
||||||
|
_type: MUNIT-Generator
|
||||||
|
in_channels: 3
|
||||||
|
out_channels: 3
|
||||||
|
base_channels: 64
|
||||||
|
num_sampling: 2
|
||||||
|
num_style_dim: 8
|
||||||
|
num_style_conv: 4
|
||||||
|
num_content_res_blocks: 4
|
||||||
|
num_decoder_res_blocks: 4
|
||||||
|
num_fusion_dim: 256
|
||||||
|
num_fusion_blocks: 3
|
||||||
|
|
||||||
|
discriminator:
|
||||||
|
_type: MultiScaleDiscriminator
|
||||||
|
num_scale: 2
|
||||||
|
discriminator_cfg:
|
||||||
|
_type: PatchDiscriminator
|
||||||
|
in_channels: 3
|
||||||
|
base_channels: 64
|
||||||
|
use_spectral: True
|
||||||
|
need_intermediate_feature: True
|
||||||
|
|
||||||
|
loss:
|
||||||
|
gan:
|
||||||
|
loss_type: lsgan
|
||||||
|
real_label_val: 1.0
|
||||||
|
fake_label_val: 0.0
|
||||||
|
weight: 1.0
|
||||||
|
perceptual:
|
||||||
|
layer_weights:
|
||||||
|
"1": 0.03125
|
||||||
|
"6": 0.0625
|
||||||
|
"11": 0.125
|
||||||
|
"20": 0.25
|
||||||
|
"29": 1
|
||||||
|
criterion: 'L1'
|
||||||
|
style_loss: False
|
||||||
|
perceptual_loss: True
|
||||||
|
weight: 0
|
||||||
|
recon:
|
||||||
|
level: 1
|
||||||
|
style:
|
||||||
|
weight: 1
|
||||||
|
content:
|
||||||
|
weight: 1
|
||||||
|
image:
|
||||||
|
weight: 10
|
||||||
|
cycle:
|
||||||
|
weight: 0
|
||||||
|
|
||||||
|
optimizers:
|
||||||
|
generator:
|
||||||
|
_type: Adam
|
||||||
|
lr: 0.0001
|
||||||
|
betas: [ 0.5, 0.999 ]
|
||||||
|
weight_decay: 0.0001
|
||||||
|
discriminator:
|
||||||
|
_type: Adam
|
||||||
|
lr: 4e-4
|
||||||
|
betas: [ 0.5, 0.999 ]
|
||||||
|
weight_decay: 0.0001
|
||||||
|
|
||||||
|
data:
|
||||||
|
train:
|
||||||
|
scheduler:
|
||||||
|
start_proportion: 0.5
|
||||||
|
target_lr: 0
|
||||||
|
buffer_size: 50
|
||||||
|
dataloader:
|
||||||
|
batch_size: 1
|
||||||
|
shuffle: True
|
||||||
|
num_workers: 1
|
||||||
|
pin_memory: True
|
||||||
|
drop_last: True
|
||||||
|
dataset:
|
||||||
|
_type: GenerationUnpairedDataset
|
||||||
|
root_a: "/data/i2i/edges2shoes/trainA"
|
||||||
|
root_b: "/data/i2i/edges2shoes/trainB"
|
||||||
|
random_pair: True
|
||||||
|
pipeline:
|
||||||
|
- Load
|
||||||
|
- Resize:
|
||||||
|
size: [ 286, 286 ]
|
||||||
|
- RandomCrop:
|
||||||
|
size: [ 256, 256 ]
|
||||||
|
- RandomHorizontalFlip
|
||||||
|
- ToTensor
|
||||||
|
- Normalize:
|
||||||
|
mean: [ 0.5, 0.5, 0.5 ]
|
||||||
|
std: [ 0.5, 0.5, 0.5 ]
|
||||||
|
test:
|
||||||
|
which: dataset
|
||||||
|
dataloader:
|
||||||
|
batch_size: 8
|
||||||
|
shuffle: False
|
||||||
|
num_workers: 1
|
||||||
|
pin_memory: False
|
||||||
|
drop_last: False
|
||||||
|
dataset:
|
||||||
|
_type: GenerationUnpairedDataset
|
||||||
|
root_a: "/data/i2i/edges2shoes/testA"
|
||||||
|
root_b: "/data/i2i/edges2shoes/testB"
|
||||||
|
random_pair: False
|
||||||
|
pipeline:
|
||||||
|
- Load
|
||||||
|
- Resize:
|
||||||
|
size: [ 256, 256 ]
|
||||||
|
- ToTensor
|
||||||
|
- Normalize:
|
||||||
|
mean: [ 0.5, 0.5, 0.5 ]
|
||||||
|
std: [ 0.5, 0.5, 0.5 ]
|
||||||
154
engine/MUNIT.py
Normal file
154
engine/MUNIT.py
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
import ignite.distributed as idist
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
|
||||||
|
from engine.util.build import build_model
|
||||||
|
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||||
|
from loss.gan import GANLoss
|
||||||
|
|
||||||
|
|
||||||
|
def mse_loss(x, target_flag):
    """Mean-squared error between ``x`` and a constant target of the same shape.

    :param x: prediction tensor
    :param target_flag: truthy selects an all-ones target, falsy an all-zeros target
    :return: ``F.mse_loss`` of ``x`` against that constant target
    """
    make_target = torch.ones_like if target_flag else torch.zeros_like
    return F.mse_loss(x, make_target(x))
|
||||||
|
|
||||||
|
|
||||||
|
def bce_loss(x, target_flag):
    """Binary cross-entropy (on logits) between ``x`` and a constant target.

    :param x: tensor of logits
    :param target_flag: truthy selects an all-ones target, falsy an all-zeros target
    :return: ``F.binary_cross_entropy_with_logits`` of ``x`` against that constant target
    """
    make_target = torch.ones_like if target_flag else torch.zeros_like
    return F.binary_cross_entropy_with_logits(x, make_target(x))
|
||||||
|
|
||||||
|
|
||||||
|
class MUNITEngineKernel(EngineKernel):
    """Training kernel for MUNIT: one generator and one discriminator per domain ("a" and "b")."""

    def __init__(self, config):
        """Create the perceptual, GAN and reconstruction criteria described by ``config.loss``.

        :param config: OmegaConf config; ``loss.perceptual``, ``loss.gan`` and ``loss.recon.level`` are read here
        """
        super().__init__(config)

        # "weight" is removed before the remaining keys are forwarded to the loss constructors;
        # the weights themselves are applied in the criterion_* methods.
        perceptual_kwargs = OmegaConf.to_container(config.loss.perceptual)
        del perceptual_kwargs["weight"]
        self.perceptual_loss = PerceptualLoss(**perceptual_kwargs).to(idist.device())

        gan_kwargs = OmegaConf.to_container(config.loss.gan)
        del gan_kwargs["weight"]
        self.gan_loss = GANLoss(**gan_kwargs).to(idist.device())

        # level 1 -> L1 reconstruction, anything else -> L2 (MSE) reconstruction
        if config.loss.recon.level == 1:
            self.recon_loss = nn.L1Loss()
        else:
            self.recon_loss = nn.MSELoss()
        self.train_generator_first = False

    def build_models(self) -> (dict, dict):
        """Build the per-domain generators and discriminators.

        :return: ``(generators, discriminators)``, each a dict keyed by "a" and "b"
        """
        generators = {domain: build_model(self.config.model.generator) for domain in "ab"}
        discriminators = {domain: build_model(self.config.model.discriminator) for domain in "ab"}
        self.logger.debug(discriminators["a"])
        self.logger.debug(generators["a"])

        return generators, discriminators

    def setup_after_g(self):
        """Re-enable gradients on every discriminator."""
        for critic in self.discriminators.values():
            critic.requires_grad_(True)

    def setup_before_g(self):
        """Disable gradients on every discriminator."""
        for critic in self.discriminators.values():
            critic.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run within-domain reconstruction and cross-domain translation for both domains.

        :param batch: mapping with the input images under keys "a" and "b"
        :param inference: not used by this implementation
        :return: dict with keys "styles", "contents" and "images"
        """
        styles, contents, images = {}, {}, {}

        for domain in "ab":
            gen = self.generators[domain]
            contents[domain], styles[domain] = gen.encode(batch[domain])
            images[f"{domain}2{domain}"] = gen.decode(contents[domain], styles[domain])
            styles[f"random_{domain}"] = torch.randn_like(styles[domain]).to(idist.device())

        for translation in ("a2b", "b2a"):
            src, dst = translation[0], translation[-1]
            dst_gen = self.generators[dst]
            # e.g. images["a2b"] = Gb.decode(content_a, random_style_b)
            images[translation] = dst_gen.decode(contents[src], styles[f"random_{dst}"])
            # e.g. contents["a2b"], styles["a2b"] = Gb.encode(images["a2b"])
            contents[translation], styles[translation] = dst_gen.encode(images[translation])
            if self.config.loss.recon.cycle.weight > 0:
                # translate back with the source generator and the source image's own style
                images[f"{translation}2{src}"] = self.generators[src].decode(contents[translation], styles[src])
        return dict(styles=styles, contents=contents, images=images)

    def criterion_generators(self, batch, generated) -> dict:
        """Compute the weighted generator losses for both domains.

        :param batch: mapping with the input images under keys "a" and "b"
        :param generated: the dict returned by ``forward``
        :return: dict mapping loss names to loss values
        """
        recon_cfg = self.config.loss.recon
        losses = {}
        for domain in "ab":
            outgoing = "a2b" if domain == "a" else "b2a"  # translation that starts in this domain
            incoming = "b2a" if domain == "a" else "a2b"  # translation that ends in this domain

            losses[f"recon_image_{domain}"] = recon_cfg.image.weight * self.recon_loss(
                batch[domain], generated["images"][f"{domain}2{domain}"])
            losses[f"recon_content_{domain}"] = recon_cfg.content.weight * self.recon_loss(
                generated["contents"][domain], generated["contents"][outgoing])
            losses[f"recon_style_{domain}"] = recon_cfg.style.weight * self.recon_loss(
                generated["styles"][f"random_{domain}"], generated["styles"][incoming])

            pred_fake = self.discriminators[domain](generated["images"][incoming])
            gan_key = f"gan_{domain}"
            losses[gan_key] = 0
            for scale_outputs in pred_fake:
                # last output is actual prediction
                losses[gan_key] += self.gan_loss(scale_outputs[-1], True)

            if self.config.loss.recon.cycle.weight > 0:
                cycle_key = "a2b2a" if domain == "a" else "b2a2b"
                losses[f"recon_cycle_{domain}"] = self.config.loss.recon.cycle.weight * self.recon_loss(
                    batch[domain], generated["images"][cycle_key])
            if self.config.loss.perceptual.weight > 0:
                losses[f"perceptual_{domain}"] = self.config.loss.perceptual.weight * self.perceptual_loss(
                    batch[domain], generated["images"][outgoing])
        return losses

    def criterion_discriminators(self, batch, generated) -> dict:
        """Compute the discriminator GAN loss for both domains.

        :param batch: mapping with the real images under keys "a" and "b"
        :param generated: the dict returned by ``forward``; translated images are detached before use
        :return: dict with keys "gan_b" and "gan_a"
        """
        losses = {}
        for translation in ("a2b", "b2a"):
            domain = translation[-1]
            critic = self.discriminators[domain]
            pred_real = critic(batch[domain])
            pred_fake = critic(generated["images"][translation].detach())
            gan_key = f"gan_{domain}"
            losses[gan_key] = 0
            for scale, fake_outputs in enumerate(pred_fake):
                fake_term = self.gan_loss(fake_outputs[-1], False, is_discriminator=True)
                real_term = self.gan_loss(pred_real[scale][-1], True, is_discriminator=True)
                losses[gan_key] += (fake_term + real_term) / 2
        return losses

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch: mapping with the input images under keys "a" and "b"
        :param generated: the dict returned by ``forward``; only its "images" entry is used
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        detached = {name: tensor.detach() for name, tensor in generated["images"].items()}
        with_cycle = self.config.loss.recon.cycle.weight > 0
        images = {}
        for domain in "ab":
            translated_key = "a2b" if domain == "a" else "b2a"
            row = [batch[domain].detach(), detached[f"{domain}2{domain}"], detached[translated_key]]
            if with_cycle:
                row.append(detached["a2b2a" if domain == "a" else "b2a2b"])
            images[domain] = row
        return images
|
||||||
|
|
||||||
|
|
||||||
|
class MUNITTestEngineKernel(TestEngineKernel):
    """Test-time kernel for MUNIT: translates domain-"a" images into domain "b"."""

    def __init__(self, config):
        super().__init__(config)

    def build_generators(self) -> dict:
        """Build one generator per domain.

        :return: dict of generators keyed by "a" and "b"
        """
        generators = dict(
            a=build_model(self.config.model.generator),
            b=build_model(self.config.model.generator)
        )
        return generators

    def to_load(self):
        """Map checkpoint entry names ("generator_a", "generator_b") to the generator modules."""
        return {f"generator_{k}": self.generators[k] for k in self.generators}

    def inference(self, batch):
        """Translate ``batch[0]`` from domain "a" to domain "b" using a randomly sampled style.

        The generators are keyed by "a" and "b" (see ``build_generators``), so there is no
        "a2b" entry to call. The translation mirrors the training kernel's a2b path instead:
        encode with generator "a", then decode with generator "b" and a random style code.

        :param batch: sequence whose first element is the batch of domain-"a" images
        :return: the translated images, detached from the graph
        """
        with torch.no_grad():
            content, style = self.generators["a"].encode(batch[0])
            # both generators are built from the same config, so a style code from
            # generator "a" has the shape generator "b" expects
            random_style = torch.randn_like(style)
            fake = self.generators["b"].decode(content, random_style)
        return fake.detach()
|
||||||
|
|
||||||
|
|
||||||
|
def run(task, config, _):
    """Entry point of the MUNIT engine: build the kernel for ``task`` and run it.

    :param task: "train" or "test"
    :param config: configuration passed to the kernel and to ``run_kernel``
    :param _: unused
    :raises NotImplementedError: if ``task`` is neither "train" nor "test"
    """
    if task == "train":
        kernel = MUNITEngineKernel(config)
    elif task == "test":
        kernel = MUNITTestEngineKernel(config)
    else:
        # `NotImplemented` is a comparison sentinel, not an exception; raising it yields a TypeError
        raise NotImplementedError(f"unsupported task: {task!r}")
    run_kernel(task, config, kernel)
|
||||||
@ -6,7 +6,6 @@ channels:
|
|||||||
dependencies:
|
dependencies:
|
||||||
- python=3.8
|
- python=3.8
|
||||||
- numpy
|
- numpy
|
||||||
- ipython
|
|
||||||
- tqdm
|
- tqdm
|
||||||
- pyyaml
|
- pyyaml
|
||||||
- pytorch=1.6.*
|
- pytorch=1.6.*
|
||||||
|
|||||||
154
model/GAN/MUNIT.py
Normal file
154
model/GAN/MUNIT.py
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from model import MODEL
|
||||||
|
from model.GAN.base import Conv2dBlock, ResBlock
|
||||||
|
from model.normalization import select_norm_layer
|
||||||
|
|
||||||
|
|
||||||
|
class StyleEncoder(nn.Module):
    """Encode an image into a flat style vector with ``out_dim`` entries per sample."""

    def __init__(self, in_channels, out_dim, num_conv, base_channels=64, use_spectral_norm=False,
                 padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
        """
        :param in_channels: channels of the input image
        :param out_dim: length of the produced style vector
        :param num_conv: number of stride-2 convolution blocks after the stem
        :param base_channels: channels produced by the stem convolution
        :param use_spectral_norm: forwarded to every ``Conv2dBlock``
        :param padding_mode: forwarded to every ``Conv2dBlock``
        :param activation_type: forwarded to every ``Conv2dBlock``
        :param norm_type: forwarded to every ``Conv2dBlock``
        """
        super().__init__()
        shared = dict(padding_mode=padding_mode, use_spectral_norm=use_spectral_norm,
                      activation_type=activation_type, norm_type=norm_type)

        layers = [Conv2dBlock(in_channels, base_channels, kernel_size=7, stride=1, padding=3, **shared)]

        # channel multiplier doubles per stride-2 block and is capped at 4
        in_mult = 1
        for level in range(1, num_conv + 1):
            out_mult = min(2 ** level, 2 ** 2)
            layers.append(Conv2dBlock(in_mult * base_channels, out_mult * base_channels,
                                      kernel_size=4, stride=2, padding=1, **shared))
            in_mult = out_mult

        layers.append(nn.AdaptiveAvgPool2d(1))
        # conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code
        layers.append(nn.Conv2d(in_mult * base_channels, out_dim, kernel_size=1, stride=1, padding=0))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the style code of ``x`` flattened to ``(x.size(0), -1)``."""
        batch_size = x.size(0)
        return self.model(x).view(batch_size, -1)
|
||||||
|
|
||||||
|
|
||||||
|
class ContentEncoder(nn.Module):
    """Encode an image into a content feature map: stem conv, down-sampling convs, residual blocks."""

    def __init__(self, in_channels, num_down_sampling, num_res_blocks, base_channels=64, use_spectral_norm=False,
                 padding_mode='reflect', activation_type="ReLU", norm_type="IN"):
        """
        :param in_channels: channels of the input image
        :param num_down_sampling: number of stride-2 blocks; each one doubles the channel count
        :param num_res_blocks: number of residual blocks appended after down-sampling
        :param base_channels: channels produced by the stem convolution
        :param use_spectral_norm: forwarded to every block
        :param padding_mode: forwarded to every block
        :param activation_type: forwarded to every block
        :param norm_type: forwarded to every block
        """
        super().__init__()
        shared = dict(padding_mode=padding_mode, use_spectral_norm=use_spectral_norm,
                      activation_type=activation_type, norm_type=norm_type)

        layers = [Conv2dBlock(in_channels, base_channels, kernel_size=7, stride=1, padding=3, **shared)]

        layers.extend(
            Conv2dBlock(base_channels * (2 ** level), base_channels * (2 ** (level + 1)),
                        kernel_size=4, stride=2, padding=1, **shared)
            for level in range(num_down_sampling)
        )

        res_channels = base_channels * (2 ** num_down_sampling)
        layers.extend(
            ResBlock(res_channels, use_spectral_norm, padding_mode, norm_type, activation_type)
            for _ in range(num_res_blocks)
        )

        self.sequence = nn.Sequential(*layers)

    def forward(self, x):
        """Return the content feature map of ``x``."""
        return self.sequence(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
    """Decode a content feature map into an image: residual blocks, up-sampling stages, Tanh output conv."""

    def __init__(self, in_channels, out_channels, num_up_sampling, num_res_blocks,
                 use_spectral_norm=False, res_norm_type="AdaIN", norm_type="LN", activation_type="ReLU",
                 padding_mode='reflect'):
        """
        :param in_channels: channels of the incoming content feature map
        :param out_channels: channels of the produced image
        :param num_up_sampling: number of x2 up-sampling stages; each one halves the channel count
        :param num_res_blocks: number of residual blocks applied before up-sampling
        :param use_spectral_norm: forwarded to every block
        :param res_norm_type: normalization used inside the residual blocks
        :param norm_type: normalization used inside the up-sampling blocks
        :param activation_type: activation used by residual and up-sampling blocks
        :param padding_mode: padding mode of residual and up-sampling blocks
        """
        super().__init__()
        self.res_norm_type = res_norm_type
        self.res_blocks = nn.ModuleList(
            ResBlock(in_channels, use_spectral_norm, padding_mode, res_norm_type, activation_type=activation_type)
            for _ in range(num_res_blocks)
        )

        stages = []
        width = in_channels
        for _ in range(num_up_sampling):
            halved = width // 2
            up_stage = nn.Sequential(
                nn.Upsample(scale_factor=2),
                Conv2dBlock(width, halved,
                            kernel_size=5, stride=1, padding=2, padding_mode=padding_mode,
                            use_spectral_norm=use_spectral_norm, activation_type=activation_type,
                            norm_type=norm_type),
            )
            stages.append(up_stage)
            width = halved

        # output layer: no normalization, Tanh activation
        stages.append(
            Conv2dBlock(width, out_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect",
                        use_spectral_norm=use_spectral_norm, activation_type="Tanh", norm_type="NONE"))
        self.sequence = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the residual blocks in order, then the up-sampling/output stack."""
        features = x
        for res_block in self.res_blocks:
            features = res_block(features)
        return self.sequence(features)
|
||||||
|
|
||||||
|
|
||||||
|
class Fusion(nn.Module):
    """Fully connected stack: an input layer, ``n_blocks - 2`` hidden layers and a linear output layer."""

    def __init__(self, in_features, out_features, base_features, n_blocks, norm_type="NONE"):
        """
        :param in_features: size of the input vector
        :param out_features: size of the output vector
        :param base_features: width of the input and hidden layers
        :param n_blocks: total layer count; ``n_blocks - 2`` hidden layers are created
        :param norm_type: key passed to ``select_norm_layer`` for the input and hidden layers
        """
        super().__init__()
        make_norm = select_norm_layer(norm_type)

        def fc_block(num_inputs):
            # Linear -> norm -> in-place ReLU
            return nn.Sequential(
                nn.Linear(num_inputs, base_features),
                make_norm(base_features),
                nn.ReLU(True),
            )

        self.start_fc = fc_block(in_features)
        self.fcs = nn.Sequential(*(fc_block(base_features) for _ in range(n_blocks - 2)))
        self.end_fc = nn.Sequential(
            nn.Linear(base_features, out_features),
        )

    def forward(self, x):
        """Pass ``x`` through the input, hidden and output layers in that order."""
        hidden = self.fcs(self.start_fc(x))
        return self.end_fc(hidden)
|
||||||
|
|
||||||
|
|
||||||
|
@MODEL.register_module("MUNIT-Generator")
class Generator(nn.Module):
    """MUNIT generator: content encoder, style encoder, decoder, and a fusion MLP.

    The fusion MLP turns a style code into the parameters handed to the normalization
    layers of the decoder's residual blocks via ``set_style``.
    """

    def __init__(self, in_channels, out_channels, base_channels, num_sampling, num_style_dim, num_style_conv,
                 num_content_res_blocks, num_decoder_res_blocks, num_fusion_dim, num_fusion_blocks,
                 use_spectral_norm=False, activation_type="ReLU", padding_mode='reflect'):
        """
        :param in_channels: channels of the input image
        :param out_channels: channels of the generated image
        :param base_channels: channels after the first convolution of both encoders
        :param num_sampling: number of down-sampling steps in the content encoder and up-sampling steps in the decoder
        :param num_style_dim: length of the style code
        :param num_style_conv: number of stride-2 convolutions in the style encoder
        :param num_content_res_blocks: residual blocks in the content encoder
        :param num_decoder_res_blocks: residual blocks in the decoder
        :param num_fusion_dim: hidden width of the fusion MLP
        :param num_fusion_blocks: number of layers of the fusion MLP
        :param use_spectral_norm: forwarded to encoders and decoder
        :param activation_type: forwarded to encoders and decoder
        :param padding_mode: forwarded to encoders and decoder
        """
        super().__init__()
        self.num_decoder_res_blocks = num_decoder_res_blocks
        self.content_encoder = ContentEncoder(in_channels, num_sampling, num_content_res_blocks, base_channels,
                                              use_spectral_norm, padding_mode, activation_type, norm_type="IN")
        self.style_encoder = StyleEncoder(in_channels, num_style_dim, num_style_conv, base_channels, use_spectral_norm,
                                          padding_mode, activation_type, norm_type="NONE")
        # The content encoder doubles the channel count at each of its `num_sampling` down-sampling
        # steps, so the decoder input width must follow `num_sampling` rather than a fixed 2 ** 2.
        content_channels = base_channels * (2 ** num_sampling)
        self.decoder = Decoder(content_channels, out_channels, num_sampling,
                               num_decoder_res_blocks, use_spectral_norm, "AdaIN", norm_type="LN",
                               activation_type=activation_type, padding_mode=padding_mode)
        # one chunk of 2 * content_channels values for each of the two convs of every decoder res block
        self.fusion = Fusion(num_style_dim, num_decoder_res_blocks * 2 * content_channels * 2,
                             base_features=num_fusion_dim, n_blocks=num_fusion_blocks, norm_type="NONE")

    def encode(self, x):
        """Return ``(content, style)`` computed from ``x`` by the two encoders."""
        return self.content_encoder(x), self.style_encoder(x)

    def decode(self, content, style):
        """Generate an image from a content feature map and a style code.

        :param content: content feature map, as produced by ``encode``
        :param style: style code, as produced by ``encode`` or sampled randomly
        :return: the decoder output
        """
        as_param_style = torch.chunk(self.fusion(style), self.num_decoder_res_blocks * 2, dim=1)
        # set style for decoder: two consecutive chunks per residual block (conv1, then conv2)
        for i, blk in enumerate(self.decoder.res_blocks):
            blk.conv1.normalization.set_style(as_param_style[2 * i])
            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
        return self.decoder(content)

    def forward(self, x):
        """Reconstruct ``x`` from its own content and style codes."""
        content, style = self.encode(x)
        return self.decode(content, style)
|
||||||
@ -1,10 +1,11 @@
|
|||||||
import math
|
from functools import partial
|
||||||
|
|
||||||
|
import math
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from model.normalization import select_norm_layer
|
|
||||||
from model import MODEL
|
from model import MODEL
|
||||||
|
from model.normalization import select_norm_layer
|
||||||
|
|
||||||
|
|
||||||
class GANImageBuffer(object):
|
class GANImageBuffer(object):
|
||||||
@ -137,3 +138,66 @@ class ResidualBlock(nn.Module):
|
|||||||
x = self.relu1(self.norm1(self.conv1(x)))
|
x = self.relu1(self.norm1(self.conv1(x)))
|
||||||
x = self.norm2(self.conv2(x))
|
x = self.norm2(self.conv2(x))
|
||||||
return x + res
|
return x + res
|
||||||
|
|
||||||
|
|
||||||
|
# Identity function returned by select_activation for the "NONE" type.
_DO_NO_THING_FUNC = lambda x: x


def select_activation(t):
    """Look up an activation by name.

    For "ReLU", "LeakyReLU" and "Tanh" the result is a zero-argument factory that builds
    the ``nn`` module. For "NONE" the result is an identity function that takes one argument.

    :param t: one of "ReLU", "LeakyReLU", "Tanh", "NONE" (case-sensitive)
    :return: a module factory, or the identity function for "NONE"
    :raises NotImplementedError: if ``t`` is not one of the supported names
    """
    if t == "ReLU":
        return partial(nn.ReLU, inplace=True)
    elif t == "LeakyReLU":
        return partial(nn.LeakyReLU, negative_slope=0.2, inplace=True)
    elif t == "Tanh":
        return partial(nn.Tanh)
    elif t == "NONE":
        # NOTE(review): unlike the other branches this is not a zero-argument factory;
        # calling the result with no arguments fails.
        return _DO_NO_THING_FUNC
    else:
        # `NotImplemented` is a comparison sentinel, not an exception; raising it yields a TypeError
        raise NotImplementedError(f"unsupported activation type: {t!r}")
|
||||||
|
|
||||||
|
|
||||||
|
def _use_bias_checker(norm_type):
    """Tell whether a convolution followed by ``norm_type`` should have a bias term.

    :param norm_type: normalization key
    :return: ``False`` for "IN", "BN" and "AdaIN", ``True`` for anything else
    """
    bias_free_norms = ("IN", "BN", "AdaIN")
    return not (norm_type in bias_free_norms)
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2dBlock(nn.Module):
    """Convolution followed by an optional normalization and an optional activation."""

    def __init__(self, in_channels: int, out_channels: int, use_spectral_norm=False, activation_type="ReLU",
                 bias=None, norm_type="NONE", **conv_kwargs):
        """
        :param in_channels: input channels of the convolution
        :param out_channels: output channels of the convolution
        :param use_spectral_norm: wrap the convolution with ``nn.utils.spectral_norm`` when true
        :param activation_type: key for ``select_activation``; "NONE" skips the activation
        :param bias: explicit bias switch; when ``None`` it is derived from ``norm_type``
        :param norm_type: key for ``select_norm_layer``; "NONE" skips the normalization
        :param conv_kwargs: remaining keyword arguments, forwarded to ``nn.Conv2d``
        """
        super().__init__()
        self.norm_type = norm_type
        self.activation_type = activation_type

        use_bias = bias if bias is not None else _use_bias_checker(norm_type)
        conv = nn.Conv2d(in_channels, out_channels, **{**conv_kwargs, "bias": use_bias})
        if use_spectral_norm:
            conv = nn.utils.spectral_norm(conv)
        self.convolution = conv

        # the optional sub-modules only exist when they are enabled
        if norm_type != "NONE":
            self.normalization = select_norm_layer(norm_type)(out_channels)
        if activation_type != "NONE":
            self.activation = select_activation(activation_type)()

    def forward(self, x):
        """Apply convolution, then normalization and activation when they are enabled."""
        out = self.convolution(x)
        if self.norm_type != "NONE":
            out = self.normalization(out)
        if self.activation_type == "NONE":
            return out
        return self.activation(out)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
    """Residual block: two 3x3 ``Conv2dBlock`` layers whose output is added to the input."""

    def __init__(self, num_channels, use_spectral_norm=False, padding_mode='reflect',
                 norm_type="IN", activation_type="ReLU", use_bias=None):
        """
        :param num_channels: channels of both the input and the output
        :param use_spectral_norm: forwarded to both convolution blocks
        :param padding_mode: padding mode of both convolutions
        :param norm_type: normalization used by both convolution blocks
        :param activation_type: activation of the first block (the second block has none);
            the default is "ReLU" because ``select_activation`` keys are case-sensitive and
            the former default "relu" was not an accepted key
        :param use_bias: explicit bias switch; when ``None`` it is derived from ``norm_type``
        """
        super().__init__()
        self.norm_type = norm_type
        if use_bias is None:
            # bias will be canceled after channel wise normalization
            use_bias = _use_bias_checker(norm_type)

        self.conv1 = Conv2dBlock(num_channels, num_channels, use_spectral_norm,
                                 kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias,
                                 norm_type=norm_type, activation_type=activation_type)
        self.conv2 = Conv2dBlock(num_channels, num_channels, use_spectral_norm,
                                 kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias,
                                 norm_type=norm_type, activation_type="NONE")

    def forward(self, x):
        """Return ``conv2(conv1(x)) + x``."""
        return self.conv2(self.conv1(x)) + x
|
||||||
|
|||||||
@ -5,3 +5,4 @@ import model.GAN.UGATIT
|
|||||||
import model.GAN.wrapper
|
import model.GAN.wrapper
|
||||||
import model.GAN.base
|
import model.GAN.base
|
||||||
import model.GAN.TSIT
|
import model.GAN.TSIT
|
||||||
|
import model.GAN.MUNIT
|
||||||
Loading…
Reference in New Issue
Block a user