TAFG update
parent 61e04de8a5
commit b01016edb5
@@ -1,6 +1,6 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="PublishConfigData" autoUpload="Always" serverName="15d" remoteFilesAllowedToDisappearOnAutoupload="false">
+  <component name="PublishConfigData" autoUpload="Always" serverName="14d" remoteFilesAllowedToDisappearOnAutoupload="false">
     <serverData>
       <paths name="14d">
         <serverdata>
@@ -1,7 +1,7 @@
 name: TAFG-vox2
 engine: TAFG
 result_dir: ./result
-max_pairs: 1500000
+max_pairs: 1000000

 handler:
   clear_cuda_cache: True
@@ -12,10 +12,13 @@ handler:
   tensorboard:
     scalar: 100 # log scalar `scalar` times per epoch
     image: 4 # log image `image` times per epoch
+  test:
+    random: True
+    images: 10


 misc:
-  random_seed: 123
+  random_seed: 1004

 model:
   generator:
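Note: the new handler.test block drives the periodic test-image dump in get_trainer (see the @@ -186,9 @@ hunk near the end of this commit). A minimal sketch of how those two fields are consumed; pick_test_window, epoch and dataset_len are illustrative stand-ins for the engine state, not names from this repo:

    import torch

    def pick_test_window(config, epoch, dataset_len):
        # Epoch-shifted seed when handler.test.random is True, so a different
        # window of test images is logged each epoch; fixed seed otherwise.
        seed = config.misc.random_seed + epoch if config.handler.test.random else config.misc.random_seed
        g = torch.Generator()
        g.manual_seed(seed)
        # First entry of a seeded permutation; subtracting the window size
        # keeps the whole window inside the dataset.
        start = torch.randperm(dataset_len - config.handler.test.images, generator=g).tolist()[0]
        return range(start, start + config.handler.test.images)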
@@ -23,10 +26,13 @@ model:
     _bn_to_sync_bn: False
     style_in_channels: 3
     content_in_channels: 24
-    num_adain_blocks: 8
-    num_res_blocks: 8
-    use_spectral_norm: True
-    style_use_fc: False
+    use_spectral_norm: False
+    style_encoder_type: StyleEncoder
+    num_style_conv: 4
+    style_dim: 8
+    num_adain_blocks: 4
+    num_res_blocks: 4

   discriminator:
     _type: MultiScaleDiscriminator
     num_scale: 2
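Note: the reworked generator keys line up with the updated TAFG-Generator signature (final hunk of this commit). A hedged sketch of the node-to-kwargs step, assuming the model is built by unpacking the config node; the actual builder is not part of this diff:

    from omegaconf import OmegaConf

    gen_cfg = OmegaConf.create({
        "style_in_channels": 3, "content_in_channels": 24,
        "use_spectral_norm": False, "style_encoder_type": "StyleEncoder",
        "num_style_conv": 4, "style_dim": 8,
        "num_adain_blocks": 4, "num_res_blocks": 4,
    })
    kwargs = OmegaConf.to_container(gen_cfg)  # plain dict, e.g. Generator(**kwargs)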
@@ -54,17 +60,24 @@ loss:
     style_loss: False
     perceptual_loss: True
     weight: 0
+  style:
+    layer_weights:
+      "3": 1
+    criterion: 'L1'
+    style_loss: True
+    perceptual_loss: False
+    weight: 10
   recon:
     level: 1
     weight: 10
   style_recon:
     level: 1
-    weight: 5
+    weight: 1
   content_recon:
     level: 1
-    weight: 10
+    weight: 1
   edge:
-    weight: 10
+    weight: 5
     hed_pretrained_model_path: ./network-bsds500.pytorch
   cycle:
     level: 1
@@ -89,7 +102,7 @@ data:
     target_lr: 0
     buffer_size: 50
   dataloader:
-    batch_size: 1
+    batch_size: 8
     shuffle: True
     num_workers: 1
     pin_memory: True
@@ -20,6 +20,10 @@ class TAFGEngineKernel(EngineKernel):
         perceptual_loss_cfg.pop("weight")
         self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())

+        style_loss_cfg = OmegaConf.to_container(config.loss.style)
+        style_loss_cfg.pop("weight")
+        self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
+
         gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
         gan_loss_cfg.pop("weight")
         self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
@@ -68,14 +72,14 @@ class TAFGEngineKernel(EngineKernel):
         contents = dict()
         images = dict()
         with torch.set_grad_enabled(not inference):
+            contents["a"], styles["a"] = generator.encode(batch["a"]["edge"], batch["a"]["img"], "a", "a")
+            contents["b"], styles["b"] = generator.encode(batch["b"]["edge"], batch["b"]["img"], "b", "b")
             for ph in "ab":
-                contents[ph], styles[ph] = generator.encode(batch[ph]["edge"], batch[ph]["img"], ph, ph)
-            for ph in ("a2b", "b2a"):
-                images[f"fake_{ph[-1]}"] = generator.decode(contents[ph[0]], styles[ph[-1]], ph[-1])
-            contents["recon_a"], styles["recon_b"] = generator.encode(
-                self.edge_loss.edge_extractor(images["fake_b"]), images["fake_b"], "b", "b")
-            images["a2a"] = generator.decode(contents["a"], styles["a"], "a")
-            images["b2b"] = generator.decode(contents["b"], styles["recon_b"], "b")
+                images[f"{ph}2{ph}"] = generator.decode(contents[ph], styles[ph], ph)
+            images["a2b"] = generator.decode(contents["a"], styles["b"], "b")
+            contents["recon_a"], styles["recon_b"] = generator.encode(self.edge_loss.edge_extractor(images["a2b"]),
+                                                                      images["a2b"], "b", "b")
+            images["cycle_b"] = generator.decode(contents["b"], styles["recon_b"], "b")
             images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
         return dict(styles=styles, contents=contents, images=images)
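Note: this rewrite replaces the old fake_a / fake_b keys with direction-explicit names, and b2a is no longer generated at all; every later hunk that read fake_* is updated to match. The resulting key scheme, for reference:

    # images produced by the new forward pass
    # "a2a", "b2b"  -> within-domain reconstructions (decode own content + own style)
    # "a2b"         -> the only cross-domain translation kept
    # "cycle_a"     -> decode(contents["recon_a"], styles["a"], "a")
    # "cycle_b"     -> decode(contents["b"], styles["recon_b"], "b")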
@@ -87,35 +91,38 @@ class TAFGEngineKernel(EngineKernel):
             loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
                 generated["images"][f"{ph}2{ph}"], batch[ph]["img"])

-            pred_fake = self.discriminators[ph](generated["images"][f"fake_{ph}"])
+            pred_fake = self.discriminators[ph](generated["images"][f"a2{ph}"])
             loss[f"gan_{ph}"] = 0
             for sub_pred_fake in pred_fake:
                 # last output is actual prediction
                 loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
-        loss[f"recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
+        loss["recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
             generated["contents"]["a"], generated["contents"]["recon_a"]
         )
-        loss[f"recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
+        loss["recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
             generated["styles"]["b"], generated["styles"]["recon_b"]
         )

-        for ph in ("a2b", "b2a"):
-            if self.config.loss.perceptual.weight > 0:
-                loss[f"perceptual_{ph}"] = self.config.loss.perceptual.weight * self.perceptual_loss(
-                    batch[ph[0]]["img"], generated["images"][f"fake_{ph[-1]}"]
-                )
-        if self.config.loss.edge.weight > 0:
-            loss[f"edge_a"] = self.config.loss.edge.weight * self.edge_loss(
-                generated["images"]["fake_b"], batch["a"]["edge"][:, 0:1, :, :]
-            )
-            loss[f"edge_b"] = self.config.loss.edge.weight * self.edge_loss(
-                generated["images"]["fake_a"], batch["b"]["edge"]
+        if self.config.loss.perceptual.weight > 0:
+            loss["perceptual_a"] = self.config.loss.perceptual.weight * self.perceptual_loss(
+                batch["a"]["img"], generated["images"]["a2b"]
             )

-        if self.config.loss.cycle.weight > 0:
-            loss[f"cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
-                batch["a"]["img"], generated["images"]["cycle_a"]
+        for ph in "ab":
+            if self.config.loss.cycle.weight > 0:
+                loss[f"cycle_{ph}"] = self.config.loss.cycle.weight * self.cycle_loss(
+                    batch[ph]["img"], generated["images"][f"cycle_{ph}"]
+                )
+            if self.config.loss.style.weight > 0:
+                loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss(
+                    batch[ph]["img"], generated["images"][f"a2{ph}"]
+                )
+
+        if self.config.loss.edge.weight > 0:
+            loss["edge_a"] = self.config.loss.edge.weight * self.edge_loss(
+                generated["images"]["a2b"], batch["a"]["edge"][:, 0:1, :, :]
             )

         return loss

     def criterion_discriminators(self, batch, generated) -> dict:
@@ -123,7 +130,7 @@ class TAFGEngineKernel(EngineKernel):
         # batch = self._process_batch(batch)
         for phase in self.discriminators.keys():
             pred_real = self.discriminators[phase](batch[phase]["img"])
-            pred_fake = self.discriminators[phase](generated["images"][f"fake_{phase}"].detach())
+            pred_fake = self.discriminators[phase](generated["images"][f"a2{phase}"].detach())
             loss[f"gan_{phase}"] = 0
             for i in range(len(pred_fake)):
                 loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
@@ -142,13 +149,13 @@ class TAFGEngineKernel(EngineKernel):
             a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
                batch["a"]["img"].detach(),
                generated["images"]["a2a"].detach(),
-               generated["images"]["fake_b"].detach(),
+               generated["images"]["a2b"].detach(),
                generated["images"]["cycle_a"].detach(),
                ],
             b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
                batch["b"]["img"].detach(),
                generated["images"]["b2b"].detach(),
-               generated["images"]["fake_a"].detach()]
+               generated["images"]["cycle_b"].detach()]
         )

     def change_engine(self, config, trainer):
@@ -58,6 +58,10 @@ class EngineKernel(object):
         self.logger = logging.getLogger(config.name)
         self.generators, self.discriminators = self.build_models()
         self.train_generator_first = True
+        self.engine = None
+
+    def bind_engine(self, engine):
+        self.engine = engine

     def build_models(self) -> (dict, dict):
         raise NotImplemented
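Note: the unchanged context line raise NotImplemented is a latent bug: NotImplemented is a constant, not an exception, so raising it fails with a TypeError rather than signaling "override me". The intended spelling would be:

    def build_models(self) -> (dict, dict):
        raise NotImplementedError  # NotImplemented is not an exception class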
@@ -154,6 +158,7 @@ def get_trainer(config, kernel: EngineKernel):
     trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)

     kernel.change_engine(config, trainer)
+    kernel.bind_engine(trainer)

     RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values()), epoch_bound=False).attach(trainer, "loss_g")
     RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values()), epoch_bound=False).attach(trainer, "loss_d")
@@ -186,9 +191,11 @@ def get_trainer(config, kernel: EngineKernel):

         with torch.no_grad():
             g = torch.Generator()
-            g.manual_seed(config.misc.random_seed)
-            random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
-            for i in range(random_start, random_start + 10):
+            g.manual_seed(config.misc.random_seed + engine.state.epoch
+                          if config.handler.test.random else config.misc.random_seed)
+            random_start = \
+                torch.randperm(len(engine.state.test_dataset) - config.handler.test.images, generator=g).tolist()[0]
+            for i in range(random_start, random_start + config.handler.test.images):
                 batch = convert_tensor(engine.state.test_dataset[i], idist.device())
                 for k in batch:
                     if isinstance(batch[k], torch.Tensor):
@@ -8,7 +8,7 @@ from model.normalization import select_norm_layer

 class StyleEncoder(nn.Module):
     def __init__(self, in_channels, out_dim, num_conv, base_channels=64, use_spectral_norm=False,
-                 padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
+                 max_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
         super(StyleEncoder, self).__init__()

         sequence = [Conv2dBlock(
@@ -19,7 +19,7 @@ class StyleEncoder(nn.Module):
         multiple_now = 1
         for i in range(1, num_conv + 1):
             multiple_prev = multiple_now
-            multiple_now = min(2 ** i, 2 ** 2)
+            multiple_now = min(2 ** i, 2 ** max_multiple)
             sequence.append(Conv2dBlock(
                 multiple_prev * base_channels, multiple_now * base_channels,
                 kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
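Note: max_multiple caps StyleEncoder's channel growth at base_channels * 2 ** max_multiple. With base_channels=64 the old hard-coded min(2 ** i, 2 ** 2) topped out at 256 channels; the Generator hunk below passes max_multiple=4, raising the cap to 1024:

    base_channels = 64
    old_cap = base_channels * 2 ** 2  # 256, previous hard-coded limit
    new_cap = base_channels * 2 ** 4  # 1024, with max_multiple=4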
@@ -50,12 +50,8 @@ class ContentEncoder(nn.Module):
             use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
         ))

-        for _ in range(num_res_blocks):
-            sequence.append(
-                ResBlock(base_channels * (2 ** num_down_sampling), use_spectral_norm, padding_mode, norm_type,
-                         activation_type)
-            )
+        sequence += [ResBlock(base_channels * (2 ** num_down_sampling), use_spectral_norm, padding_mode, norm_type,
+                              activation_type) for _ in range(num_res_blocks)]

         self.sequence = nn.Sequential(*sequence)

     def forward(self, x):
@@ -4,7 +4,7 @@ from torchvision.models import vgg19

 from model.normalization import select_norm_layer
 from model.registry import MODEL
-from .MUNIT import ContentEncoder, Fusion, Decoder
+from .MUNIT import ContentEncoder, Fusion, Decoder, StyleEncoder
 from .base import ResBlock


@@ -56,17 +56,26 @@ class VGG19StyleEncoder(nn.Module):
 @MODEL.register_module("TAFG-Generator")
 class Generator(nn.Module):
     def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,
-                 style_dim=512, style_use_fc=True,
-                 num_adain_blocks=8, num_res_blocks=8,
-                 base_channels=64, padding_mode="reflect"):
+                 style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
+                 num_res_blocks=8, base_channels=64, padding_mode="reflect"):
         super(Generator, self).__init__()
         self.num_adain_blocks = num_adain_blocks
-        self.style_encoders = nn.ModuleDict(dict(
-            a=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
-                                norm_type="NONE"),
-            b=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
-                                norm_type="NONE", fix_vgg19=False)
-        ))
+        if style_encoder_type == "StyleEncoder":
+            self.style_encoders = nn.ModuleDict(dict(
+                a=StyleEncoder(style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
+                               max_multiple=4, padding_mode=padding_mode, norm_type="NONE"),
+                b=StyleEncoder(style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
+                               max_multiple=4, padding_mode=padding_mode, norm_type="NONE"),
+            ))
+        elif style_encoder_type == "VGG19StyleEncoder":
+            self.style_encoders = nn.ModuleDict(dict(
+                a=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
+                                    norm_type="NONE"),
+                b=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
+                                    norm_type="NONE", fix_vgg19=False)
+            ))
+        else:
+            raise NotImplemented(f"do not support {style_encoder_type}")
         resnet_channels = 2 ** 2 * base_channels
         self.style_converters = nn.ModuleDict(dict(
             a=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,
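Note: as in the engine base above, raise NotImplemented(f"do not support {style_encoder_type}") tries to call a non-callable constant, so it dies with a TypeError before the message is ever shown. A corrected fallback branch:

    else:
        raise NotImplementedError(f"do not support {style_encoder_type}")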