Ray Wong 2020-09-17 09:34:53 +08:00
parent 2ff4a91057
commit 61e04de8a5
9 changed files with 168 additions and 288 deletions

View File

@@ -1,6 +1,6 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="PublishConfigData" autoUpload="Always" serverName="14d" remoteFilesAllowedToDisappearOnAutoupload="false">
+  <component name="PublishConfigData" autoUpload="Always" serverName="15d" remoteFilesAllowedToDisappearOnAutoupload="false">
     <serverData>
       <paths name="14d">
         <serverdata>

View File

@@ -1,4 +1,4 @@
-name: TAFG
+name: TAFG-vox2
 engine: TAFG
 result_dir: ./result
 max_pairs: 1500000
@@ -11,11 +11,11 @@ handler:
     n_saved: 2
   tensorboard:
     scalar: 100 # log scalar `scalar` times per epoch
-    image: 2 # log image `image` times per epoch
+    image: 4 # log image `image` times per epoch
 misc:
-  random_seed: 324
+  random_seed: 123
 model:
   generator:
@@ -24,7 +24,9 @@ model:
     style_in_channels: 3
     content_in_channels: 24
     num_adain_blocks: 8
-    num_res_blocks: 0
+    num_res_blocks: 8
+    use_spectral_norm: True
+    style_use_fc: False
   discriminator:
     _type: MultiScaleDiscriminator
     num_scale: 2
@@ -51,26 +53,22 @@ loss:
     criterion: 'L1'
     style_loss: False
     perceptual_loss: True
-    weight: 10
-  style:
-    layer_weights:
-      "3": 1
-    criterion: 'L1'
-    style_loss: True
-    perceptual_loss: False
-    weight: 10
-  fm:
-    level: 1
-    weight: 10
+    weight: 0
   recon:
     level: 1
     weight: 10
   style_recon:
     level: 1
-    weight: 0
+    weight: 5
+  content_recon:
+    level: 1
+    weight: 10
   edge:
     weight: 10
     hed_pretrained_model_path: ./network-bsds500.pytorch
+  cycle:
+    level: 1
+    weight: 10
 optimizers:
   generator:
@@ -91,9 +89,9 @@ data:
       target_lr: 0
     buffer_size: 50
     dataloader:
-      batch_size: 8
+      batch_size: 1
       shuffle: True
-      num_workers: 2
+      num_workers: 1
       pin_memory: True
       drop_last: True
     dataset:
@@ -116,7 +114,7 @@ data:
   test:
     which: video_dataset
     dataloader:
-      batch_size: 8
+      batch_size: 1
       shuffle: False
       num_workers: 1
       pin_memory: False
@@ -145,7 +143,7 @@ data:
     pipeline:
       - Load
       - Resize:
          size: [ 128, 128 ]
-          size: [ 256, 256 ]
+          size: [ 128, 128 ]
      - ToTensor
      - Normalize:
          mean: [ 0.5, 0.5, 0.5 ]
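Note: the engine code below reads this file through OmegaConf attribute access and converts sub-blocks to plain dicts before building loss modules. A minimal sketch of that pattern (the file path is hypothetical; the keys mirror the config above and the engine code in this commit):

    from omegaconf import OmegaConf

    config = OmegaConf.load("config/TAFG-vox2.yaml")  # hypothetical path, not taken from the commit

    # nested attribute access, as used in the engine below
    print(config.loss.recon.weight)          # 10
    print(config.loss.style_recon.weight)    # 5
    print(config.loss.content_recon.weight)  # 10
    print(config.loss.cycle.weight)          # 10

    # sub-configs are turned into plain dicts before instantiating loss modules
    gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
    gan_loss_cfg.pop("weight")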

View File

@@ -203,7 +203,7 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
         op = Path(origin_path)
         if self.edge_type.startswith("landmark_"):
             edge_type = self.edge_type.lstrip("landmark_")
-            use_landmark = True
+            use_landmark = op.parent.name.endswith("A")
         else:
             edge_type = self.edge_type
             use_landmark = False
@@ -225,14 +225,11 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
     def __getitem__(self, idx):
         a_idx = idx % len(self.A)
         b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
-        if self.with_path:
-            output = {"a": self.A[a_idx], "b": self.B[b_idx]}
-            output["edge_a"] = output["a"][1]
-            return output
-        output = dict()
-        output["a"], path_a = self.A[a_idx]
-        output["b"], path_b = self.B[b_idx]
-        output["edge_a"] = self.get_edge(path_a)
+        output = dict(a={}, b={})
+        output["a"]["img"], output["a"]["path"] = self.A[a_idx]
+        output["b"]["img"], output["b"]["path"] = self.B[b_idx]
+        for p in "ab":
+            output[p]["edge"] = self.get_edge(output[p]["path"])
         return output

     def __len__(self):
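After this change each sample is a nested dict per domain instead of flat img/edge_a keys. A small self-contained sketch (hypothetical shapes and paths, not repo code) of how such samples behave under the default DataLoader collate:

    import torch
    from torch.utils.data import DataLoader, Dataset

    class FakeEdgeDataset(Dataset):
        """Stand-in dataset mimicking the nested sample layout returned by __getitem__ above."""
        def __len__(self):
            return 4

        def __getitem__(self, idx):
            return {
                "a": {"img": torch.rand(3, 128, 128), "path": f"trainA/{idx}.png",
                      "edge": torch.rand(3, 128, 128)},
                "b": {"img": torch.rand(3, 128, 128), "path": f"trainB/{idx}.png",
                      "edge": torch.rand(1, 128, 128)},
            }

    loader = DataLoader(FakeEdgeDataset(), batch_size=2)
    batch = next(iter(loader))
    print(batch["a"]["img"].shape)  # torch.Size([2, 3, 128, 128]) -> tensors are stacked
    print(batch["b"]["path"])       # list of 2 path strings -> strings stay as lists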

View File

@@ -58,20 +58,20 @@ class MUNITEngineKernel(EngineKernel):
         styles = dict()
         contents = dict()
         images = dict()
-        for phase in "ab":
-            contents[phase], styles[phase] = self.generators[phase].encode(batch[phase])
-            images["{0}2{0}".format(phase)] = self.generators[phase].decode(contents[phase], styles[phase])
-            styles[f"random_{phase}"] = torch.randn_like(styles[phase]).to(idist.device())
-        for phase in ("a2b", "b2a"):
-            # images["a2b"] = Gb.decode(content_a, random_style_b)
-            images[phase] = self.generators[phase[-1]].decode(contents[phase[0]], styles[f"random_{phase[-1]}"])
-            # contents["a2b"], styles["a2b"] = Gb.encode(images["a2b"])
-            contents[phase], styles[phase] = self.generators[phase[-1]].encode(images[phase])
-            if self.config.loss.recon.cycle.weight > 0:
-                images[f"{phase}2{phase[0]}"] = self.generators[phase[0]].decode(contents[phase], styles[phase[0]])
-        return dict(styles=styles, contents=contents, images=images)
+        with torch.set_grad_enabled(not inference):
+            for phase in "ab":
+                contents[phase], styles[phase] = self.generators[phase].encode(batch[phase])
+                images["{0}2{0}".format(phase)] = self.generators[phase].decode(contents[phase], styles[phase])
+                styles[f"random_{phase}"] = torch.randn_like(styles[phase]).to(idist.device())
+            for phase in ("a2b", "b2a"):
+                # images["a2b"] = Gb.decode(content_a, random_style_b)
+                images[phase] = self.generators[phase[-1]].decode(contents[phase[0]], styles[f"random_{phase[-1]}"])
+                # contents["a2b"], styles["a2b"] = Gb.encode(images["a2b"])
+                contents[phase], styles[phase] = self.generators[phase[-1]].encode(images[phase])
+                if self.config.loss.recon.cycle.weight > 0:
+                    images[f"{phase}2{phase[0]}"] = self.generators[phase[0]].decode(contents[phase], styles[phase[0]])
+            return dict(styles=styles, contents=contents, images=images)

     def criterion_generators(self, batch, generated) -> dict:
         loss = dict()
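The phase[0] / phase[-1] indexing above is just string slicing on the translation tag; a plain illustration (no repo code assumed):

    phase = "a2b"
    print(phase[0])                # "a"  -> domain supplying the content code
    print(phase[-1])               # "b"  -> domain whose decoder / random style is used
    print(f"{phase}2{phase[0]}")   # "a2b2a" -> key under which the cycle reconstruction is stored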

View File

@@ -3,8 +3,7 @@ from itertools import chain
 import ignite.distributed as idist
 import torch
 import torch.nn as nn
-from ignite.engine import Events
-from omegaconf import read_write, OmegaConf
+from omegaconf import OmegaConf

 from engine.base.i2i import EngineKernel, run_kernel
 from engine.util.build import build_model
@@ -21,17 +20,14 @@ class TAFGEngineKernel(EngineKernel):
         perceptual_loss_cfg.pop("weight")
         self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())

-        style_loss_cfg = OmegaConf.to_container(config.loss.style)
-        style_loss_cfg.pop("weight")
-        self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
-
         gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
         gan_loss_cfg.pop("weight")
         self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())

-        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
         self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
         self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
+        self.content_recon_loss = nn.L1Loss() if config.loss.content_recon.level == 1 else nn.MSELoss()
+        self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
         self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
             idist.device())
@@ -67,47 +63,67 @@ class TAFGEngineKernel(EngineKernel):
     def forward(self, batch, inference=False) -> dict:
         generator = self.generators["main"]
         batch = self._process_batch(batch, inference)
+        styles = dict()
+        contents = dict()
+        images = dict()
         with torch.set_grad_enabled(not inference):
-            fake = dict(
-                a=generator(content_img=batch["edge_a"], style_img=batch["a"], which_decoder="a"),
-                b=generator(content_img=batch["edge_a"], style_img=batch["b"], which_decoder="b"),
-            )
-        return fake
+            for ph in "ab":
+                contents[ph], styles[ph] = generator.encode(batch[ph]["edge"], batch[ph]["img"], ph, ph)
+            for ph in ("a2b", "b2a"):
+                images[f"fake_{ph[-1]}"] = generator.decode(contents[ph[0]], styles[ph[-1]], ph[-1])
+            contents["recon_a"], styles["recon_b"] = generator.encode(
+                self.edge_loss.edge_extractor(images["fake_b"]), images["fake_b"], "b", "b")
+            images["a2a"] = generator.decode(contents["a"], styles["a"], "a")
+            images["b2b"] = generator.decode(contents["b"], styles["recon_b"], "b")
+            images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
+        return dict(styles=styles, contents=contents, images=images)

     def criterion_generators(self, batch, generated) -> dict:
         batch = self._process_batch(batch)
         loss = dict()
-        loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
-        _, loss_style = self.style_loss(generated["a"], batch["a"])
-        loss["style"] = self.config.loss.style.weight * loss_style
-        loss["perceptual"] = self.config.loss.perceptual.weight * loss_perceptual
-        for phase in "ab":
-            pred_fake = self.discriminators[phase](generated[phase])
-            loss[f"gan_{phase}"] = 0
-            for sub_pred_fake in pred_fake:
-                # last output is actual prediction
-                loss[f"gan_{phase}"] += self.gan_loss(sub_pred_fake[-1], True)
-            if self.config.loss.fm.weight > 0 and phase == "b":
-                pred_real = self.discriminators[phase](batch[phase])
-                loss_fm = 0
-                num_scale_discriminator = len(pred_fake)
-                for i in range(num_scale_discriminator):
-                    # last output is the final prediction, so we exclude it
-                    num_intermediate_outputs = len(pred_fake[i]) - 1
-                    for j in range(num_intermediate_outputs):
-                        loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
-                loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
-        loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"])
-        loss["edge"] = self.config.loss.edge.weight * self.edge_loss(generated["b"], batch["edge_a"][:, 0:1, :, :])
+        for ph in "ab":
+            loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
+                generated["images"][f"{ph}2{ph}"], batch[ph]["img"])
+            pred_fake = self.discriminators[ph](generated["images"][f"fake_{ph}"])
+            loss[f"gan_{ph}"] = 0
+            for sub_pred_fake in pred_fake:
+                # last output is actual prediction
+                loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
+        loss[f"recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
+            generated["contents"]["a"], generated["contents"]["recon_a"]
+        )
+        loss[f"recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
+            generated["styles"]["b"], generated["styles"]["recon_b"]
+        )
+        for ph in ("a2b", "b2a"):
+            if self.config.loss.perceptual.weight > 0:
+                loss[f"perceptual_{ph}"] = self.config.loss.perceptual.weight * self.perceptual_loss(
+                    batch[ph[0]]["img"], generated["images"][f"fake_{ph[-1]}"]
+                )
+        if self.config.loss.edge.weight > 0:
+            loss[f"edge_a"] = self.config.loss.edge.weight * self.edge_loss(
+                generated["images"]["fake_b"], batch["a"]["edge"][:, 0:1, :, :]
+            )
+            loss[f"edge_b"] = self.config.loss.edge.weight * self.edge_loss(
+                generated["images"]["fake_a"], batch["b"]["edge"]
+            )
+        if self.config.loss.cycle.weight > 0:
+            loss[f"cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
+                batch["a"]["img"], generated["images"]["cycle_a"]
+            )
         return loss

     def criterion_discriminators(self, batch, generated) -> dict:
         loss = dict()
         # batch = self._process_batch(batch)
         for phase in self.discriminators.keys():
-            pred_real = self.discriminators[phase](batch[phase])
-            pred_fake = self.discriminators[phase](generated[phase].detach())
+            pred_real = self.discriminators[phase](batch[phase]["img"])
+            pred_fake = self.discriminators[phase](generated["images"][f"fake_{phase}"].detach())
             loss[f"gan_{phase}"] = 0
             for i in range(len(pred_fake)):
                 loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
@@ -122,17 +138,25 @@ class TAFGEngineKernel(EngineKernel):
         :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
         """
         batch = self._process_batch(batch)
-        edge = batch["edge_a"][:, 0:1, :, :]
         return dict(
-            a=[edge.expand(-1, 3, -1, -1).detach(), batch["a"].detach(), batch["b"].detach(), generated["a"].detach(),
-               generated["b"].detach()]
+            a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
+               batch["a"]["img"].detach(),
+               generated["images"]["a2a"].detach(),
+               generated["images"]["fake_b"].detach(),
+               generated["images"]["cycle_a"].detach(),
+               ],
+            b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
+               batch["b"]["img"].detach(),
+               generated["images"]["b2b"].detach(),
+               generated["images"]["fake_a"].detach()]
         )

     def change_engine(self, config, trainer):
-        @trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
-        def change_config(engine):
-            with read_write(config):
-                config.loss.perceptual.weight = 5
+        pass
+        # @trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
+        # def change_config(engine):
+        #     with read_write(config):
+        #         config.loss.perceptual.weight = 5


 def run(task, config, _):

View File

@@ -132,9 +132,12 @@ def get_trainer(config, kernel: EngineKernel):
         generated = kernel.forward(batch)
         if kernel.train_generator_first:
+            # simultaneous, train G with simultaneous D
             loss_g = train_generators(batch, generated)
             loss_d = train_discriminators(batch, generated)
         else:
+            # update discriminators first, not simultaneous.
+            # train G with updated discriminators
             loss_d = train_discriminators(batch, generated)
             loss_g = train_generators(batch, generated)
@@ -152,8 +155,8 @@ def get_trainer(config, kernel: EngineKernel):
     kernel.change_engine(config, trainer)

-    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
-    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
+    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values()), epoch_bound=False).attach(trainer, "loss_g")
+    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values()), epoch_bound=False).attach(trainer, "loss_d")

     to_save = dict(trainer=trainer)
     to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
     to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
@@ -188,7 +191,13 @@ def get_trainer(config, kernel: EngineKernel):
         for i in range(random_start, random_start + 10):
             batch = convert_tensor(engine.state.test_dataset[i], idist.device())
             for k in batch:
-                batch[k] = batch[k].view(1, *batch[k].size())
+                if isinstance(batch[k], torch.Tensor):
+                    batch[k] = batch[k].unsqueeze(0)
+                elif isinstance(batch[k], dict):
+                    for kk in batch[k]:
+                        if isinstance(batch[k][kk], torch.Tensor):
+                            batch[k][kk] = batch[k][kk].unsqueeze(0)
             generated = kernel.forward(batch)
             images = kernel.intermediate_images(batch, generated)
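The branch added above only descends one dict level, which is enough for the new nested batch layout; a generic recursive helper doing the same thing (a sketch only, not code from this repo):

    import torch

    def add_batch_dim(sample):
        """Recursively prepend a batch dimension of 1 to every tensor in a nested dict."""
        if isinstance(sample, torch.Tensor):
            return sample.unsqueeze(0)
        if isinstance(sample, dict):
            return {k: add_batch_dim(v) for k, v in sample.items()}
        return sample  # paths and other non-tensor fields pass through unchanged

    single = {"a": {"img": torch.rand(3, 128, 128), "path": "trainA/0.png"}}
    batched = add_batch_dim(single)
    print(batched["a"]["img"].shape)  # torch.Size([1, 3, 128, 128])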

View File

@@ -92,6 +92,7 @@ class PerceptualLoss(nn.Module):
                  style_loss=False, norm_img=True, criterion='L1'):
         super(PerceptualLoss, self).__init__()
         self.norm_img = norm_img
+        assert perceptual_loss ^ style_loss, "There must be one and only one true in style or perceptual"
         self.perceptual_loss = perceptual_loss
         self.style_loss = style_loss
         self.layer_weights = layer_weights
@@ -127,8 +128,7 @@ class PerceptualLoss(nn.Module):
             percep_loss = 0
             for k in x_features.keys():
                 percep_loss += self.percep_criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
-        else:
-            percep_loss = None
+            return percep_loss

         # calculate style loss
         if self.style_loss:
@@ -136,10 +136,7 @@ class PerceptualLoss(nn.Module):
             for k in x_features.keys():
                 style_loss += self.style_criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * \
                               self.layer_weights[k]
-        else:
-            style_loss = None
-
-        return percep_loss, style_loss
+            return style_loss

     def _gram_mat(self, x):
         """Calculate Gram matrix.

View File

@@ -4,16 +4,17 @@ from torchvision.models import vgg19
 from model.normalization import select_norm_layer
 from model.registry import MODEL
-from .base import ResidualBlock
+from .MUNIT import ContentEncoder, Fusion, Decoder
+from .base import ResBlock


 class VGG19StyleEncoder(nn.Module):
     def __init__(self, in_channels, base_channels=64, style_dim=512, padding_mode='reflect', norm_type="NONE",
-                 vgg19_layers=(0, 5, 10, 19)):
+                 vgg19_layers=(0, 5, 10, 19), fix_vgg19=True):
         super().__init__()
         self.vgg19_layers = vgg19_layers
         self.vgg19 = vgg19(pretrained=True).features[:vgg19_layers[-1] + 1]
-        self.vgg19.requires_grad_(False)
+        self.vgg19.requires_grad_(not fix_vgg19)

         norm_layer = select_norm_layer(norm_type)
@@ -52,203 +53,57 @@ class VGG19StyleEncoder(nn.Module):
         return x.view(x.size(0), -1)


-class ContentEncoder(nn.Module):
-    def __init__(self, in_channels, base_channels=64, num_blocks=8, padding_mode='reflect', norm_type="IN"):
-        super().__init__()
-        norm_layer = select_norm_layer(norm_type)
-        self.start_conv = nn.Sequential(
-            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
-                      bias=True),
-            norm_layer(num_features=base_channels),
-            nn.ReLU(inplace=True)
-        )
-        # down sampling
-        submodules = []
-        num_down_sampling = 2
-        for i in range(num_down_sampling):
-            multiple = 2 ** i
-            submodules += [
-                nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
-                          kernel_size=4, stride=2, padding=1, bias=True),
-                norm_layer(num_features=base_channels * multiple * 2),
-                nn.ReLU(inplace=True)
-            ]
-        self.encoder = nn.Sequential(*submodules)
-        res_block_channels = num_down_sampling ** 2 * base_channels
-        self.resnet = nn.Sequential(
-            *[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
-
-    def forward(self, x):
-        x = self.start_conv(x)
-        x = self.encoder(x)
-        x = self.resnet(x)
-        return x
-
-
-class Decoder(nn.Module):
-    def __init__(self, out_channels, base_channels=64, num_blocks=4, num_down_sampling=2, padding_mode='reflect',
-                 norm_type="LN"):
-        super(Decoder, self).__init__()
-        norm_layer = select_norm_layer(norm_type)
-        use_bias = norm_type == "IN"
-        res_block_channels = (2 ** 2) * base_channels
-        self.resnet = nn.Sequential(
-            *[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
-        # up sampling
-        submodules = []
-        for i in range(num_down_sampling):
-            multiple = 2 ** (num_down_sampling - i)
-            submodules += [
-                nn.Upsample(scale_factor=2),
-                nn.Conv2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=5, stride=1,
-                          padding=2, padding_mode=padding_mode, bias=use_bias),
-                norm_layer(num_features=base_channels * multiple // 2),
-                nn.ReLU(inplace=True),
-            ]
-        self.decoder = nn.Sequential(*submodules)
-        self.end_conv = nn.Sequential(
-            nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
-            nn.Tanh()
-        )
-
-    def forward(self, x):
-        x = self.resnet(x)
-        x = self.decoder(x)
-        x = self.end_conv(x)
-        return x
-
-
-class Fusion(nn.Module):
-    def __init__(self, in_features, out_features, base_features, n_blocks, norm_type="NONE"):
-        super().__init__()
-        norm_layer = select_norm_layer(norm_type)
-        self.start_fc = nn.Sequential(
-            nn.Linear(in_features, base_features),
-            norm_layer(base_features),
-            nn.ReLU(True),
-        )
-        self.fcs = nn.Sequential(*[
-            nn.Sequential(
-                nn.Linear(base_features, base_features),
-                norm_layer(base_features),
-                nn.ReLU(True),
-            ) for _ in range(n_blocks - 2)
-        ])
-        self.end_fc = nn.Sequential(
-            nn.Linear(base_features, out_features),
-        )
-
-    def forward(self, x):
-        x = self.start_fc(x)
-        x = self.fcs(x)
-        return self.end_fc(x)
-
-
-class StyleGenerator(nn.Module):
-    def __init__(self, style_in_channels, style_dim=512, num_blocks=8, base_channels=64, padding_mode="reflect"):
-        super().__init__()
-        self.num_blocks = num_blocks
-        self.style_encoder = VGG19StyleEncoder(
-            style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE")
-        self.fc = nn.Sequential(
-            nn.Linear(style_dim, style_dim),
-            nn.ReLU(True),
-        )
-        res_block_channels = 2 ** 2 * base_channels
-        self.fusion = Fusion(style_dim, num_blocks * 2 * res_block_channels * 2, base_features=256, n_blocks=3,
-                             norm_type="NONE")
-
-    def forward(self, x):
-        styles = self.fusion(self.fc(self.style_encoder(x)))
-        return styles
-
-
 @MODEL.register_module("TAFG-Generator")
 class Generator(nn.Module):
-    def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512,
-                 num_adain_blocks=8, num_res_blocks=4,
+    def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,
+                 style_dim=512, style_use_fc=True,
+                 num_adain_blocks=8, num_res_blocks=8,
                  base_channels=64, padding_mode="reflect"):
         super(Generator, self).__init__()
-        self.num_adain_blocks=num_adain_blocks
-        self.style_encoders = nn.ModuleDict({
-            "a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
-                                base_channels=base_channels, padding_mode=padding_mode),
-            "b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
-                                base_channels=base_channels, padding_mode=padding_mode),
-        })
-        self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=8,
-                                              padding_mode=padding_mode, norm_type="IN")
-        res_block_channels = 2 ** 2 * base_channels
-
-        self.resnet = nn.ModuleDict({
-            "a": nn.Sequential(*[
-                ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
-            ]),
-            "b": nn.Sequential(*[
-                ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
-            ])
-        })
-        self.adain_resnet = nn.ModuleDict({
-            "a": nn.ModuleList([
-                ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
-            ]),
-            "b": nn.ModuleList([
-                ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
-            ])
-        })
-        self.decoders = nn.ModuleDict({
-            "a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode),
-            "b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode)
-        })
+        self.num_adain_blocks = num_adain_blocks
+        self.style_encoders = nn.ModuleDict(dict(
+            a=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
+                                norm_type="NONE"),
+            b=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
+                                norm_type="NONE", fix_vgg19=False)
+        ))
+        resnet_channels = 2 ** 2 * base_channels
+        self.style_converters = nn.ModuleDict(dict(
+            a=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,
+                     norm_type="NONE"),
+            b=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,
+                     norm_type="NONE"),
+        ))
+        self.content_encoders = nn.ModuleDict({
+            "a": ContentEncoder(content_in_channels, 2, num_res_blocks=0, use_spectral_norm=use_spectral_norm),
+            "b": ContentEncoder(1, 2, num_res_blocks=0, use_spectral_norm=use_spectral_norm)
+        })
+        self.content_resnet = nn.Sequential(*[
+            ResBlock(resnet_channels, use_spectral_norm, padding_mode, "IN")
+            for _ in range(num_res_blocks)
+        ])
+        self.decoders = nn.ModuleDict(dict(
+            a=Decoder(resnet_channels, out_channels, 2,
+                      num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode),
+            b=Decoder(resnet_channels, out_channels, 2,
+                      num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode),
+        ))

-    def forward(self, content_img, style_img, which_decoder: str = "a"):
-        x = self.content_encoder(content_img)
-        x = self.resnet[which_decoder](x)
-        styles = self.style_encoders[which_decoder](style_img)
-        styles = torch.chunk(styles, self.num_adain_blocks * 2, dim=1)
-        for i, ar in enumerate(self.adain_resnet[which_decoder]):
-            ar.norm1.set_style(styles[2 * i])
-            ar.norm2.set_style(styles[2 * i + 1])
-            x = ar(x)
-        return self.decoders[which_decoder](x)
+    def encode(self, content_img, style_img, which_content, which_style):
+        content = self.content_resnet(self.content_encoders[which_content](content_img))
+        style = self.style_encoders[which_style](style_img)
+        return content, style
+
+    def decode(self, content, style, which):
+        decoder = self.decoders[which]
+        as_param_style = torch.chunk(self.style_converters[which](style), self.num_adain_blocks * 2, dim=1)
+        # set style for decoder
+        for i, blk in enumerate(decoder.res_blocks):
+            blk.conv1.normalization.set_style(as_param_style[2 * i])
+            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
+        return decoder(content)
+
+    def forward(self, content_img, style_img, which_content, which_style):
+        content, style = self.encode(content_img, style_img, which_content, which_style)
+        return self.decode(content, style, which_style)

-
-@MODEL.register_module("TAFG-Discriminator")
-class Discriminator(nn.Module):
-    def __init__(self, in_channels=3, base_channels=64, num_down_sampling=2, num_blocks=3, norm_type="IN",
-                 padding_mode="reflect"):
-        super(Discriminator, self).__init__()
-        norm_layer = select_norm_layer(norm_type)
-        use_bias = norm_type == "IN"
-        sequence = [nn.Sequential(
-            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
-                      bias=use_bias),
-            norm_layer(num_features=base_channels),
-            nn.ReLU(inplace=True)
-        )]
-        # stacked intermediate layers,
-        # gradually increasing the number of filters
-        multiple_now = 1
-        for n in range(1, num_down_sampling + 1):
-            multiple_prev = multiple_now
-            multiple_now = min(2 ** n, 4)
-            sequence += [
-                nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=3,
-                          padding=1, stride=2, bias=use_bias),
-                norm_layer(base_channels * multiple_now),
-                nn.LeakyReLU(0.2, inplace=True)
-            ]
-        for _ in range(num_blocks):
-            sequence.append(ResidualBlock(base_channels * multiple_now, padding_mode, norm_type))
-        self.model = nn.Sequential(*sequence)
-
-    def forward(self, x):
-        return self.model(x)
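The new decode() path injects style by chunking the Fusion output and calling set_style on each AdaIN normalization before running the decoder. A self-contained sketch of that mechanism (a generic AdaIN layer for illustration, not this repo's normalization module):

    import torch
    import torch.nn as nn

    class AdaIN(nn.Module):
        """Minimal adaptive instance norm: per-channel scale/shift supplied externally."""
        def __init__(self, num_features):
            super().__init__()
            self.norm = nn.InstanceNorm2d(num_features, affine=False)
            self.weight = None
            self.bias = None

        def set_style(self, style):  # style: (N, 2 * num_features), one chunk of the Fusion output
            self.weight, self.bias = style.chunk(2, dim=1)

        def forward(self, x):
            x = self.norm(x)
            w = self.weight.unsqueeze(-1).unsqueeze(-1)
            b = self.bias.unsqueeze(-1).unsqueeze(-1)
            return x * (1 + w) + b

    adain = AdaIN(256)
    feat = torch.rand(2, 256, 32, 32)        # decoder feature map
    style_params = torch.rand(2, 512)        # 2 * 256 parameters per block
    adain.set_style(style_params)
    print(adain(feat).shape)                 # torch.Size([2, 256, 32, 32])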

View File

@@ -185,7 +185,7 @@ class Conv2dBlock(nn.Module):
 class ResBlock(nn.Module):
     def __init__(self, num_channels, use_spectral_norm=False, padding_mode='reflect',
-                 norm_type="IN", activation_type="relu", use_bias=None):
+                 norm_type="IN", activation_type="ReLU", use_bias=None):
         super().__init__()
         self.norm_type = norm_type
         if use_bias is None: