Ray Wong 2020-09-17 09:34:53 +08:00
parent 2ff4a91057
commit 61e04de8a5
9 changed files with 168 additions and 288 deletions

View File

@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="Always" serverName="14d" remoteFilesAllowedToDisappearOnAutoupload="false">
<component name="PublishConfigData" autoUpload="Always" serverName="15d" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="14d">
<serverdata>

View File

@@ -1,4 +1,4 @@
name: TAFG
name: TAFG-vox2
engine: TAFG
result_dir: ./result
max_pairs: 1500000
@@ -11,11 +11,11 @@ handler:
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 2 # log image `image` times per epoch
image: 4 # log image `image` times per epoch
misc:
random_seed: 324
random_seed: 123
model:
generator:
@@ -24,7 +24,9 @@ model:
style_in_channels: 3
content_in_channels: 24
num_adain_blocks: 8
num_res_blocks: 0
num_res_blocks: 8
use_spectral_norm: True
style_use_fc: False
discriminator:
_type: MultiScaleDiscriminator
num_scale: 2
@@ -51,26 +53,22 @@ loss:
criterion: 'L1'
style_loss: False
perceptual_loss: True
weight: 10
style:
layer_weights:
"3": 1
criterion: 'L1'
style_loss: True
perceptual_loss: False
weight: 10
fm:
level: 1
weight: 10
weight: 0
recon:
level: 1
weight: 10
style_recon:
level: 1
weight: 0
weight: 5
content_recon:
level: 1
weight: 10
edge:
weight: 10
hed_pretrained_model_path: ./network-bsds500.pytorch
cycle:
level: 1
weight: 10
optimizers:
generator:
@@ -91,9 +89,9 @@ data:
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 8
batch_size: 1
shuffle: True
num_workers: 2
num_workers: 1
pin_memory: True
drop_last: True
dataset:
@@ -116,7 +114,7 @@ data:
test:
which: video_dataset
dataloader:
batch_size: 8
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
@@ -145,7 +143,7 @@ data:
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
size: [ 128, 128 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
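This YAML is loaded through OmegaConf, and the engine code later in this commit reads it with attribute access (self.config.loss.fm.weight, self.config.loss.cycle.weight, and so on). A minimal sketch of loading and inspecting the values this commit changes; the config filename below is a placeholder, since the diff does not show the file's path:

from omegaconf import OmegaConf

# Placeholder path; the real config filename is not visible in this diff.
config = OmegaConf.load("configs/tafg_vox2.yaml")

# Attribute access mirrors the nested YAML keys used by the kernels below.
print(config.loss.fm.weight)            # 0 after this commit: feature matching is disabled
print(config.loss.style_recon.weight)   # 5 after this commit
print(config.loss.cycle.weight)         # 10, so the cycle branch in forward() stays active
print(config.model.generator.num_res_blocks)  # 8 instead of 0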

View File

@@ -203,7 +203,7 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
op = Path(origin_path)
if self.edge_type.startswith("landmark_"):
edge_type = self.edge_type.lstrip("landmark_")
use_landmark = True
use_landmark = op.parent.name.endswith("A")
else:
edge_type = self.edge_type
use_landmark = False
@@ -225,14 +225,11 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
def __getitem__(self, idx):
a_idx = idx % len(self.A)
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
if self.with_path:
output = {"a": self.A[a_idx], "b": self.B[b_idx]}
output["edge_a"] = output["a"][1]
return output
output = dict()
output["a"], path_a = self.A[a_idx]
output["b"], path_b = self.B[b_idx]
output["edge_a"] = self.get_edge(path_a)
output = dict(a={}, b={})
output["a"]["img"], output["a"]["path"] = self.A[a_idx]
output["b"]["img"], output["b"]["path"] = self.B[b_idx]
for p in "ab":
output[p]["edge"] = self.get_edge(output[p]["path"])
return output
def __len__(self):
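With this change each dataset item is a nested dict, so downstream code indexes batch["a"]["img"], batch["a"]["edge"], and batch["a"]["path"] instead of the old flat keys. A rough sketch of one item's layout with dummy tensors and placeholder paths (the real data comes from the loading pipeline above). As an aside, the unchanged context line self.edge_type.lstrip("landmark_") strips a character set rather than a prefix, so it can over-strip for some edge_type values; slicing off len("landmark_") or str.removeprefix (Python 3.9+) is the safer idiom.

import torch

# Dummy tensors and placeholder paths, only to illustrate the new nested layout:
sample = {
    "a": {"img": torch.rand(3, 128, 128), "edge": torch.rand(3, 128, 128), "path": "trainA/0001.png"},
    "b": {"img": torch.rand(3, 128, 128), "edge": torch.rand(3, 128, 128), "path": "trainB/0001.png"},
}
print(sample["a"]["img"].shape)

Because each item now mixes tensors and strings inside nested dicts, the trainer's test-image code further down has to unsqueeze tensors selectively rather than viewing every value.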

View File

@@ -58,20 +58,20 @@ class MUNITEngineKernel(EngineKernel):
styles = dict()
contents = dict()
images = dict()
with torch.set_grad_enabled(not inference):
for phase in "ab":
contents[phase], styles[phase] = self.generators[phase].encode(batch[phase])
images["{0}2{0}".format(phase)] = self.generators[phase].decode(contents[phase], styles[phase])
styles[f"random_{phase}"] = torch.randn_like(styles[phase]).to(idist.device())
for phase in "ab":
contents[phase], styles[phase] = self.generators[phase].encode(batch[phase])
images["{0}2{0}".format(phase)] = self.generators[phase].decode(contents[phase], styles[phase])
styles[f"random_{phase}"] = torch.randn_like(styles[phase]).to(idist.device())
for phase in ("a2b", "b2a"):
# images["a2b"] = Gb.decode(content_a, random_style_b)
images[phase] = self.generators[phase[-1]].decode(contents[phase[0]], styles[f"random_{phase[-1]}"])
# contents["a2b"], styles["a2b"] = Gb.encode(images["a2b"])
contents[phase], styles[phase] = self.generators[phase[-1]].encode(images[phase])
if self.config.loss.recon.cycle.weight > 0:
images[f"{phase}2{phase[0]}"] = self.generators[phase[0]].decode(contents[phase], styles[phase[0]])
return dict(styles=styles, contents=contents, images=images)
for phase in ("a2b", "b2a"):
# images["a2b"] = Gb.decode(content_a, random_style_b)
images[phase] = self.generators[phase[-1]].decode(contents[phase[0]], styles[f"random_{phase[-1]}"])
# contents["a2b"], styles["a2b"] = Gb.encode(images["a2b"])
contents[phase], styles[phase] = self.generators[phase[-1]].encode(images[phase])
if self.config.loss.recon.cycle.weight > 0:
images[f"{phase}2{phase[0]}"] = self.generators[phase[0]].decode(contents[phase], styles[phase[0]])
return dict(styles=styles, contents=contents, images=images)
def criterion_generators(self, batch, generated) -> dict:
loss = dict()
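The refactor above also drops the torch.set_grad_enabled(not inference) wrapper, leaving gradient bookkeeping to whatever context the caller establishes. A tiny self-contained illustration of that call-site control (generic model and data, not the repo's kernel):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
x = torch.rand(2, 4)

y_train = model(x)            # autograd tracked, as in a training step
with torch.no_grad():         # the caller disables autograd for inference / logging
    y_eval = model(x)

print(y_train.requires_grad, y_eval.requires_grad)  # True False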

View File

@@ -3,8 +3,7 @@ from itertools import chain
import ignite.distributed as idist
import torch
import torch.nn as nn
from ignite.engine import Events
from omegaconf import read_write, OmegaConf
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
@@ -21,17 +20,14 @@ class TAFGEngineKernel(EngineKernel):
perceptual_loss_cfg.pop("weight")
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
style_loss_cfg = OmegaConf.to_container(config.loss.style)
style_loss_cfg.pop("weight")
self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
self.content_recon_loss = nn.L1Loss() if config.loss.content_recon.level == 1 else nn.MSELoss()
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
idist.device())
@@ -67,47 +63,67 @@ class TAFGEngineKernel(EngineKernel):
def forward(self, batch, inference=False) -> dict:
generator = self.generators["main"]
batch = self._process_batch(batch, inference)
styles = dict()
contents = dict()
images = dict()
with torch.set_grad_enabled(not inference):
fake = dict(
a=generator(content_img=batch["edge_a"], style_img=batch["a"], which_decoder="a"),
b=generator(content_img=batch["edge_a"], style_img=batch["b"], which_decoder="b"),
)
return fake
for ph in "ab":
contents[ph], styles[ph] = generator.encode(batch[ph]["edge"], batch[ph]["img"], ph, ph)
for ph in ("a2b", "b2a"):
images[f"fake_{ph[-1]}"] = generator.decode(contents[ph[0]], styles[ph[-1]], ph[-1])
contents["recon_a"], styles["recon_b"] = generator.encode(
self.edge_loss.edge_extractor(images["fake_b"]), images["fake_b"], "b", "b")
images["a2a"] = generator.decode(contents["a"], styles["a"], "a")
images["b2b"] = generator.decode(contents["b"], styles["recon_b"], "b")
images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
return dict(styles=styles, contents=contents, images=images)
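The new forward() encodes both domains, decodes the cross-domain fakes, re-encodes fake_b to obtain reconstruction targets, and finally produces the identity (a2a, b2b) and cycle (cycle_a) images. A compact trace of the resulting key names using stand-in encode/decode functions (placeholders, not the repo's Generator; in the real kernel the content input for the re-encoding step is the HED edge map extracted from fake_b):

import torch

def encode(content_img, style_img, which_content, which_style):
    return torch.rand(1, 256, 32, 32), torch.rand(1, 512)   # (content, style) stand-ins

def decode(content, style, which):
    return torch.rand(1, 3, 128, 128)                        # fake-image stand-in

batch = {p: {"img": torch.rand(1, 3, 128, 128), "edge": torch.rand(1, 3, 128, 128)} for p in "ab"}
contents, styles, images = {}, {}, {}
for ph in "ab":
    contents[ph], styles[ph] = encode(batch[ph]["edge"], batch[ph]["img"], ph, ph)
for ph in ("a2b", "b2a"):
    images[f"fake_{ph[-1]}"] = decode(contents[ph[0]], styles[ph[-1]], ph[-1])
contents["recon_a"], styles["recon_b"] = encode(images["fake_b"], images["fake_b"], "b", "b")
images["a2a"] = decode(contents["a"], styles["a"], "a")
images["b2b"] = decode(contents["b"], styles["recon_b"], "b")
images["cycle_a"] = decode(contents["recon_a"], styles["a"], "a")
print(sorted(images))  # ['a2a', 'b2b', 'cycle_a', 'fake_a', 'fake_b']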
def criterion_generators(self, batch, generated) -> dict:
batch = self._process_batch(batch)
loss = dict()
loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
_, loss_style = self.style_loss(generated["a"], batch["a"])
loss["style"] = self.config.loss.style.weight * loss_style
loss["perceptual"] = self.config.loss.perceptual.weight * loss_perceptual
for phase in "ab":
pred_fake = self.discriminators[phase](generated[phase])
loss[f"gan_{phase}"] = 0
for ph in "ab":
loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
generated["images"][f"{ph}2{ph}"], batch[ph]["img"])
pred_fake = self.discriminators[ph](generated["images"][f"fake_{ph}"])
loss[f"gan_{ph}"] = 0
for sub_pred_fake in pred_fake:
# last output is actual prediction
loss[f"gan_{phase}"] += self.gan_loss(sub_pred_fake[-1], True)
loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
loss[f"recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
generated["contents"]["a"], generated["contents"]["recon_a"]
)
loss[f"recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
generated["styles"]["b"], generated["styles"]["recon_b"]
)
if self.config.loss.fm.weight > 0 and phase == "b":
pred_real = self.discriminators[phase](batch[phase])
loss_fm = 0
num_scale_discriminator = len(pred_fake)
for i in range(num_scale_discriminator):
# last output is the final prediction, so we exclude it
num_intermediate_outputs = len(pred_fake[i]) - 1
for j in range(num_intermediate_outputs):
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"])
loss["edge"] = self.config.loss.edge.weight * self.edge_loss(generated["b"], batch["edge_a"][:, 0:1, :, :])
for ph in ("a2b", "b2a"):
if self.config.loss.perceptual.weight > 0:
loss[f"perceptual_{ph}"] = self.config.loss.perceptual.weight * self.perceptual_loss(
batch[ph[0]]["img"], generated["images"][f"fake_{ph[-1]}"]
)
if self.config.loss.edge.weight > 0:
loss[f"edge_a"] = self.config.loss.edge.weight * self.edge_loss(
generated["images"]["fake_b"], batch["a"]["edge"][:, 0:1, :, :]
)
loss[f"edge_b"] = self.config.loss.edge.weight * self.edge_loss(
generated["images"]["fake_a"], batch["b"]["edge"]
)
if self.config.loss.cycle.weight > 0:
loss[f"cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
batch["a"]["img"], generated["images"]["cycle_a"]
)
return loss
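Both the GAN term here and the (now zero-weighted) feature-matching block assume the multi-scale discriminator returns a list with one entry per scale, each entry being a list of intermediate feature maps whose last element is the real/fake prediction. A self-contained sketch of that convention with dummy tensors:

import torch
import torch.nn as nn

# One entry per discriminator scale; within each entry the last tensor is the prediction.
pred_fake = [
    [torch.rand(1, 64, 64, 64), torch.rand(1, 128, 32, 32), torch.rand(1, 1, 31, 31)],
    [torch.rand(1, 64, 32, 32), torch.rand(1, 128, 16, 16), torch.rand(1, 1, 15, 15)],
]
pred_real = [[t.clone() for t in scale] for scale in pred_fake]

gan_inputs = [scale[-1] for scale in pred_fake]   # only the final predictions feed the GAN loss

# Feature matching (disabled in the new config via weight 0) averages L1 over the
# intermediate features, excluding the final prediction at each scale.
fm_criterion = nn.L1Loss()
loss_fm = 0.0
for fake_scale, real_scale in zip(pred_fake, pred_real):
    for f, r in zip(fake_scale[:-1], real_scale[:-1]):
        loss_fm = loss_fm + fm_criterion(f, r.detach()) / len(pred_fake)
print(float(loss_fm))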
def criterion_discriminators(self, batch, generated) -> dict:
loss = dict()
# batch = self._process_batch(batch)
for phase in self.discriminators.keys():
pred_real = self.discriminators[phase](batch[phase])
pred_fake = self.discriminators[phase](generated[phase].detach())
pred_real = self.discriminators[phase](batch[phase]["img"])
pred_fake = self.discriminators[phase](generated["images"][f"fake_{phase}"].detach())
loss[f"gan_{phase}"] = 0
for i in range(len(pred_fake)):
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
@@ -122,17 +138,25 @@ class TAFGEngineKernel(EngineKernel):
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
batch = self._process_batch(batch)
edge = batch["edge_a"][:, 0:1, :, :]
return dict(
a=[edge.expand(-1, 3, -1, -1).detach(), batch["a"].detach(), batch["b"].detach(), generated["a"].detach(),
generated["b"].detach()]
a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
batch["a"]["img"].detach(),
generated["images"]["a2a"].detach(),
generated["images"]["fake_b"].detach(),
generated["images"]["cycle_a"].detach(),
],
b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
batch["b"]["img"].detach(),
generated["images"]["b2b"].detach(),
generated["images"]["fake_a"].detach()]
)
def change_engine(self, config, trainer):
@trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
def change_config(engine):
with read_write(config):
config.loss.perceptual.weight = 5
pass
# @trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
# def change_config(engine):
# with read_write(config):
# config.loss.perceptual.weight = 5
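For reference, re-enabling the one-shot schedule change that this commit comments out would need the imports removed at the top of the file (Events from ignite, read_write from omegaconf); a sketch of the restored hook, matching the commented code above:

from ignite.engine import Events
from omegaconf import read_write

def change_engine(self, config, trainer):
    # once= makes ignite fire this handler a single time, at one third of training.
    @trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
    def change_config(engine):
        # OmegaConf configs are typically read-only here; read_write() temporarily
        # allows mutating the perceptual weight mid-training.
        with read_write(config):
            config.loss.perceptual.weight = 5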
def run(task, config, _):

View File

@@ -132,9 +132,12 @@ def get_trainer(config, kernel: EngineKernel):
generated = kernel.forward(batch)
if kernel.train_generator_first:
# simultaneous, train G with simultaneous D
loss_g = train_generators(batch, generated)
loss_d = train_discriminators(batch, generated)
else:
# update discriminators first, not simultaneous.
# train G with updated discriminators
loss_d = train_discriminators(batch, generated)
loss_g = train_generators(batch, generated)
@@ -152,8 +155,8 @@ def get_trainer(config, kernel: EngineKernel):
kernel.change_engine(config, trainer)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values()), epoch_bound=False).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values()), epoch_bound=False).attach(trainer, "loss_d")
to_save = dict(trainer=trainer)
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
@@ -188,7 +191,13 @@ def get_trainer(config, kernel: EngineKernel):
for i in range(random_start, random_start + 10):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
for k in batch:
batch[k] = batch[k].view(1, *batch[k].size())
if isinstance(batch[k], torch.Tensor):
batch[k] = batch[k].unsqueeze(0)
elif isinstance(batch[k], dict):
for kk in batch[k]:
if isinstance(batch[k][kk], torch.Tensor):
batch[k][kk] = batch[k][kk].unsqueeze(0)
generated = kernel.forward(batch)
images = kernel.intermediate_images(batch, generated)
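The explicit isinstance branching above handles exactly one level of nesting, which matches the new dataset output. A hypothetical recursive variant (not the repo's code) that would batch a single item of arbitrary nesting, leaving non-tensor metadata such as paths untouched:

import torch

def add_batch_dim(value):
    # Unsqueeze every tensor in a nested dict so one dataset item becomes a batch of 1.
    if isinstance(value, torch.Tensor):
        return value.unsqueeze(0)
    if isinstance(value, dict):
        return {k: add_batch_dim(v) for k, v in value.items()}
    return value  # paths and other metadata pass through unchanged

sample = {"a": {"img": torch.rand(3, 128, 128), "path": "0001.png"}}
print(add_batch_dim(sample)["a"]["img"].shape)  # torch.Size([1, 3, 128, 128])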

View File

@@ -92,6 +92,7 @@ class PerceptualLoss(nn.Module):
style_loss=False, norm_img=True, criterion='L1'):
super(PerceptualLoss, self).__init__()
self.norm_img = norm_img
assert perceptual_loss ^ style_loss, "Exactly one of perceptual_loss and style_loss must be True"
self.perceptual_loss = perceptual_loss
self.style_loss = style_loss
self.layer_weights = layer_weights
@@ -127,8 +128,7 @@ class PerceptualLoss(nn.Module):
percep_loss = 0
for k in x_features.keys():
percep_loss += self.percep_criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
else:
percep_loss = None
return percep_loss
# calculate style loss
if self.style_loss:
@@ -136,10 +136,7 @@ class PerceptualLoss(nn.Module):
for k in x_features.keys():
style_loss += self.style_criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * \
self.layer_weights[k]
else:
style_loss = None
return percep_loss, style_loss
return style_loss
def _gram_mat(self, x):
"""Calculate Gram matrix.

View File

@@ -4,16 +4,17 @@ from torchvision.models import vgg19
from model.normalization import select_norm_layer
from model.registry import MODEL
from .base import ResidualBlock
from .MUNIT import ContentEncoder, Fusion, Decoder
from .base import ResBlock
class VGG19StyleEncoder(nn.Module):
def __init__(self, in_channels, base_channels=64, style_dim=512, padding_mode='reflect', norm_type="NONE",
vgg19_layers=(0, 5, 10, 19)):
vgg19_layers=(0, 5, 10, 19), fix_vgg19=True):
super().__init__()
self.vgg19_layers = vgg19_layers
self.vgg19 = vgg19(pretrained=True).features[:vgg19_layers[-1] + 1]
self.vgg19.requires_grad_(False)
self.vgg19.requires_grad_(not fix_vgg19)
norm_layer = select_norm_layer(norm_type)
@@ -52,203 +53,57 @@ class VGG19StyleEncoder(nn.Module):
return x.view(x.size(0), -1)
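The new fix_vgg19 flag decides whether the pretrained VGG19 backbone stays frozen; in the Generator below, the "a" style encoder keeps the default (frozen) while "b" is built with fix_vgg19=False and gets fine-tuned. A small standalone illustration of that freeze/unfreeze pattern:

from torchvision.models import vgg19

features = vgg19(pretrained=True).features[:20]
fix_backbone = True
features.requires_grad_(not fix_backbone)   # frozen when fix_backbone is True
print(any(p.requires_grad for p in features.parameters()))  # False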
class ContentEncoder(nn.Module):
def __init__(self, in_channels, base_channels=64, num_blocks=8, padding_mode='reflect', norm_type="IN"):
super().__init__()
norm_layer = select_norm_layer(norm_type)
self.start_conv = nn.Sequential(
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
bias=True),
norm_layer(num_features=base_channels),
nn.ReLU(inplace=True)
)
# down sampling
submodules = []
num_down_sampling = 2
for i in range(num_down_sampling):
multiple = 2 ** i
submodules += [
nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
kernel_size=4, stride=2, padding=1, bias=True),
norm_layer(num_features=base_channels * multiple * 2),
nn.ReLU(inplace=True)
]
self.encoder = nn.Sequential(*submodules)
res_block_channels = num_down_sampling ** 2 * base_channels
self.resnet = nn.Sequential(
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
def forward(self, x):
x = self.start_conv(x)
x = self.encoder(x)
x = self.resnet(x)
return x
class Decoder(nn.Module):
def __init__(self, out_channels, base_channels=64, num_blocks=4, num_down_sampling=2, padding_mode='reflect',
norm_type="LN"):
super(Decoder, self).__init__()
norm_layer = select_norm_layer(norm_type)
use_bias = norm_type == "IN"
res_block_channels = (2 ** 2) * base_channels
self.resnet = nn.Sequential(
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
# up sampling
submodules = []
for i in range(num_down_sampling):
multiple = 2 ** (num_down_sampling - i)
submodules += [
nn.Upsample(scale_factor=2),
nn.Conv2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=5, stride=1,
padding=2, padding_mode=padding_mode, bias=use_bias),
norm_layer(num_features=base_channels * multiple // 2),
nn.ReLU(inplace=True),
]
self.decoder = nn.Sequential(*submodules)
self.end_conv = nn.Sequential(
nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
nn.Tanh()
)
def forward(self, x):
x = self.resnet(x)
x = self.decoder(x)
x = self.end_conv(x)
return x
class Fusion(nn.Module):
def __init__(self, in_features, out_features, base_features, n_blocks, norm_type="NONE"):
super().__init__()
norm_layer = select_norm_layer(norm_type)
self.start_fc = nn.Sequential(
nn.Linear(in_features, base_features),
norm_layer(base_features),
nn.ReLU(True),
)
self.fcs = nn.Sequential(*[
nn.Sequential(
nn.Linear(base_features, base_features),
norm_layer(base_features),
nn.ReLU(True),
) for _ in range(n_blocks - 2)
])
self.end_fc = nn.Sequential(
nn.Linear(base_features, out_features),
)
def forward(self, x):
x = self.start_fc(x)
x = self.fcs(x)
return self.end_fc(x)
class StyleGenerator(nn.Module):
def __init__(self, style_in_channels, style_dim=512, num_blocks=8, base_channels=64, padding_mode="reflect"):
super().__init__()
self.num_blocks = num_blocks
self.style_encoder = VGG19StyleEncoder(
style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE")
self.fc = nn.Sequential(
nn.Linear(style_dim, style_dim),
nn.ReLU(True),
)
res_block_channels = 2 ** 2 * base_channels
self.fusion = Fusion(style_dim, num_blocks * 2 * res_block_channels * 2, base_features=256, n_blocks=3,
norm_type="NONE")
def forward(self, x):
styles = self.fusion(self.fc(self.style_encoder(x)))
return styles
@MODEL.register_module("TAFG-Generator")
class Generator(nn.Module):
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512,
num_adain_blocks=8, num_res_blocks=4,
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,
style_dim=512, style_use_fc=True,
num_adain_blocks=8, num_res_blocks=8,
base_channels=64, padding_mode="reflect"):
super(Generator, self).__init__()
self.num_adain_blocks=num_adain_blocks
self.style_encoders = nn.ModuleDict({
"a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
base_channels=base_channels, padding_mode=padding_mode),
"b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
base_channels=base_channels, padding_mode=padding_mode),
})
self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=8,
padding_mode=padding_mode, norm_type="IN")
res_block_channels = 2 ** 2 * base_channels
self.resnet = nn.ModuleDict({
"a": nn.Sequential(*[
ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
]),
"b": nn.Sequential(*[
ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
])
})
self.adain_resnet = nn.ModuleDict({
"a": nn.ModuleList([
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
]),
"b": nn.ModuleList([
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
])
self.num_adain_blocks = num_adain_blocks
self.style_encoders = nn.ModuleDict(dict(
a=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
norm_type="NONE"),
b=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
norm_type="NONE", fix_vgg19=False)
))
resnet_channels = 2 ** 2 * base_channels
self.style_converters = nn.ModuleDict(dict(
a=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,
norm_type="NONE"),
b=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,
norm_type="NONE"),
))
self.content_encoders = nn.ModuleDict({
"a": ContentEncoder(content_in_channels, 2, num_res_blocks=0, use_spectral_norm=use_spectral_norm),
"b": ContentEncoder(1, 2, num_res_blocks=0, use_spectral_norm=use_spectral_norm)
})
self.decoders = nn.ModuleDict({
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode),
"b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode)
})
self.content_resnet = nn.Sequential(*[
ResBlock(resnet_channels, use_spectral_norm, padding_mode, "IN")
for _ in range(num_res_blocks)
])
self.decoders = nn.ModuleDict(dict(
a=Decoder(resnet_channels, out_channels, 2,
num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode),
b=Decoder(resnet_channels, out_channels, 2,
num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode),
))
def forward(self, content_img, style_img, which_decoder: str = "a"):
x = self.content_encoder(content_img)
x = self.resnet[which_decoder](x)
styles = self.style_encoders[which_decoder](style_img)
styles = torch.chunk(styles, self.num_adain_blocks * 2, dim=1)
for i, ar in enumerate(self.adain_resnet[which_decoder]):
ar.norm1.set_style(styles[2 * i])
ar.norm2.set_style(styles[2 * i + 1])
x = ar(x)
return self.decoders[which_decoder](x)
def encode(self, content_img, style_img, which_content, which_style):
content = self.content_resnet(self.content_encoders[which_content](content_img))
style = self.style_encoders[which_style](style_img)
return content, style
def decode(self, content, style, which):
decoder = self.decoders[which]
as_param_style = torch.chunk(self.style_converters[which](style), self.num_adain_blocks * 2, dim=1)
# set style for decoder
for i, blk in enumerate(decoder.res_blocks):
blk.conv1.normalization.set_style(as_param_style[2 * i])
blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
return decoder(content)
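decode() converts the style code with the per-domain Fusion, splits it into num_adain_blocks * 2 chunks, and pushes one chunk into each AdaIN normalization before running the decoder on the content features. A generic sketch of what such a set_style-driven AdaIN layer can look like (an illustration only; the repo's normalization class and its exact scale/shift convention are not shown in this diff):

import torch
import torch.nn as nn

class AdaIN2d(nn.Module):
    # Instance-normalize the content features, then scale and shift them per channel
    # with parameters taken from the style chunk passed to set_style().
    def __init__(self, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.weight = None
        self.bias = None

    def set_style(self, style):
        # style: (N, 2 * num_features) -> per-sample, per-channel scale and shift
        weight, bias = style.chunk(2, dim=1)
        self.weight = weight.unsqueeze(-1).unsqueeze(-1)
        self.bias = bias.unsqueeze(-1).unsqueeze(-1)

    def forward(self, x):
        # The "1 +" offset is one common convention; other implementations apply weight directly.
        return self.norm(x) * (1 + self.weight) + self.bias

adain = AdaIN2d(256)
adain.set_style(torch.rand(1, 512))
print(adain(torch.rand(1, 256, 32, 32)).shape)  # torch.Size([1, 256, 32, 32])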
@MODEL.register_module("TAFG-Discriminator")
class Discriminator(nn.Module):
def __init__(self, in_channels=3, base_channels=64, num_down_sampling=2, num_blocks=3, norm_type="IN",
padding_mode="reflect"):
super(Discriminator, self).__init__()
norm_layer = select_norm_layer(norm_type)
use_bias = norm_type == "IN"
sequence = [nn.Sequential(
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
bias=use_bias),
norm_layer(num_features=base_channels),
nn.ReLU(inplace=True)
)]
# stacked intermediate layers,
# gradually increasing the number of filters
multiple_now = 1
for n in range(1, num_down_sampling + 1):
multiple_prev = multiple_now
multiple_now = min(2 ** n, 4)
sequence += [
nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=3,
padding=1, stride=2, bias=use_bias),
norm_layer(base_channels * multiple_now),
nn.LeakyReLU(0.2, inplace=True)
]
for _ in range(num_blocks):
sequence.append(ResidualBlock(base_channels * multiple_now, padding_mode, norm_type))
self.model = nn.Sequential(*sequence)
def forward(self, x):
return self.model(x)
def forward(self, content_img, style_img, which_content, which_style):
content, style = self.encode(content_img, style_img, which_content, which_style)
return self.decode(content, style, which_style)

View File

@@ -185,7 +185,7 @@ class Conv2dBlock(nn.Module):
class ResBlock(nn.Module):
def __init__(self, num_channels, use_spectral_norm=False, padding_mode='reflect',
norm_type="IN", activation_type="relu", use_bias=None):
norm_type="IN", activation_type="ReLU", use_bias=None):
super().__init__()
self.norm_type = norm_type
if use_bias is None: