diff --git a/.idea/deployment.xml b/.idea/deployment.xml
index 0198fe4..8ccfb5e 100644
--- a/.idea/deployment.xml
+++ b/.idea/deployment.xml
@@ -1,6 +1,6 @@
-
+
diff --git a/configs/synthesizers/TAFG.yml b/configs/synthesizers/TAFG.yml
index 939681f..1e2d133 100644
--- a/configs/synthesizers/TAFG.yml
+++ b/configs/synthesizers/TAFG.yml
@@ -1,7 +1,7 @@
 name: TAFG-vox2
 engine: TAFG
 result_dir: ./result
-max_pairs: 1500000
+max_pairs: 1000000
 
 handler:
   clear_cuda_cache: True
@@ -12,10 +12,13 @@ handler:
   tensorboard:
     scalar: 100 # log scalar `scalar` times per epoch
     image: 4 # log image `image` times per epoch
+  test:
+    random: True
+    images: 10
 
 misc:
-  random_seed: 123
+  random_seed: 1004
 
 model:
   generator:
@@ -23,10 +26,13 @@ model:
     _bn_to_sync_bn: False
     style_in_channels: 3
     content_in_channels: 24
-    num_adain_blocks: 8
-    num_res_blocks: 8
-    use_spectral_norm: True
-    style_use_fc: False
+    use_spectral_norm: False
+    style_encoder_type: StyleEncoder
+    num_style_conv: 4
+    style_dim: 8
+    num_adain_blocks: 4
+    num_res_blocks: 4
+
   discriminator:
     _type: MultiScaleDiscriminator
     num_scale: 2
@@ -54,17 +60,24 @@ loss:
     style_loss: False
     perceptual_loss: True
     weight: 0
+  style:
+    layer_weights:
+      "3": 1
+    criterion: 'L1'
+    style_loss: True
+    perceptual_loss: False
+    weight: 10
   recon:
     level: 1
     weight: 10
   style_recon:
     level: 1
-    weight: 5
+    weight: 1
   content_recon:
     level: 1
-    weight: 10
+    weight: 1
   edge:
-    weight: 10
+    weight: 5
     hed_pretrained_model_path: ./network-bsds500.pytorch
   cycle:
     level: 1
@@ -89,7 +102,7 @@ data:
     target_lr: 0
     buffer_size: 50
   dataloader:
-    batch_size: 1
+    batch_size: 8
     shuffle: True
     num_workers: 1
     pin_memory: True
diff --git a/engine/TAFG.py b/engine/TAFG.py
index 5871d28..f1d7379 100644
--- a/engine/TAFG.py
+++ b/engine/TAFG.py
@@ -20,6 +20,10 @@ class TAFGEngineKernel(EngineKernel):
         perceptual_loss_cfg.pop("weight")
         self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
 
+        style_loss_cfg = OmegaConf.to_container(config.loss.style)
+        style_loss_cfg.pop("weight")
+        self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
+
         gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
         gan_loss_cfg.pop("weight")
         self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
@@ -68,14 +72,14 @@ class TAFGEngineKernel(EngineKernel):
         contents = dict()
         images = dict()
         with torch.set_grad_enabled(not inference):
+            contents["a"], styles["a"] = generator.encode(batch["a"]["edge"], batch["a"]["img"], "a", "a")
+            contents["b"], styles["b"] = generator.encode(batch["b"]["edge"], batch["b"]["img"], "b", "b")
             for ph in "ab":
-                contents[ph], styles[ph] = generator.encode(batch[ph]["edge"], batch[ph]["img"], ph, ph)
-            for ph in ("a2b", "b2a"):
-                images[f"fake_{ph[-1]}"] = generator.decode(contents[ph[0]], styles[ph[-1]], ph[-1])
-            contents["recon_a"], styles["recon_b"] = generator.encode(
-                self.edge_loss.edge_extractor(images["fake_b"]), images["fake_b"], "b", "b")
-            images["a2a"] = generator.decode(contents["a"], styles["a"], "a")
-            images["b2b"] = generator.decode(contents["b"], styles["recon_b"], "b")
+                images[f"{ph}2{ph}"] = generator.decode(contents[ph], styles[ph], ph)
+            images["a2b"] = generator.decode(contents["a"], styles["b"], "b")
+            contents["recon_a"], styles["recon_b"] = generator.encode(self.edge_loss.edge_extractor(images["a2b"]),
+                                                                      images["a2b"], "b", "b")
+            images["cycle_b"] = generator.decode(contents["b"], styles["recon_b"], "b")
             images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
         return dict(styles=styles, contents=contents,
                     images=images)
@@ -87,35 +91,38 @@ class TAFGEngineKernel(EngineKernel):
             loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
                 generated["images"][f"{ph}2{ph}"], batch[ph]["img"])
-            pred_fake = self.discriminators[ph](generated["images"][f"fake_{ph}"])
+            pred_fake = self.discriminators[ph](generated["images"][f"a2{ph}"])
             loss[f"gan_{ph}"] = 0
             for sub_pred_fake in pred_fake: # last output is actual prediction
                 loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
 
-        loss[f"recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
+        loss["recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
             generated["contents"]["a"], generated["contents"]["recon_a"]
         )
-        loss[f"recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
+        loss["recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
             generated["styles"]["b"], generated["styles"]["recon_b"]
         )
-        for ph in ("a2b", "b2a"):
-            if self.config.loss.perceptual.weight > 0:
-                loss[f"perceptual_{ph}"] = self.config.loss.perceptual.weight * self.perceptual_loss(
-                    batch[ph[0]]["img"], generated["images"][f"fake_{ph[-1]}"]
-                )
-        if self.config.loss.edge.weight > 0:
-            loss[f"edge_a"] = self.config.loss.edge.weight * self.edge_loss(
-                generated["images"]["fake_b"], batch["a"]["edge"][:, 0:1, :, :]
-            )
-            loss[f"edge_b"] = self.config.loss.edge.weight * self.edge_loss(
-                generated["images"]["fake_a"], batch["b"]["edge"]
+        if self.config.loss.perceptual.weight > 0:
+            loss["perceptual_a"] = self.config.loss.perceptual.weight * self.perceptual_loss(
+                batch["a"]["img"], generated["images"]["a2b"]
             )
-        if self.config.loss.cycle.weight > 0:
-            loss[f"cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
-                batch["a"]["img"], generated["images"]["cycle_a"]
+        for ph in "ab":
+            if self.config.loss.cycle.weight > 0:
+                loss[f"cycle_{ph}"] = self.config.loss.cycle.weight * self.cycle_loss(
+                    batch[ph]["img"], generated["images"][f"cycle_{ph}"]
+                )
+            if self.config.loss.style.weight > 0:
+                loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss(
+                    batch[ph]["img"], generated["images"][f"a2{ph}"]
+                )
+
+        if self.config.loss.edge.weight > 0:
+            loss["edge_a"] = self.config.loss.edge.weight * self.edge_loss(
+                generated["images"]["a2b"], batch["a"]["edge"][:, 0:1, :, :]
            )
+
         return loss
 
     def criterion_discriminators(self, batch, generated) -> dict:
@@ -123,7 +130,7 @@
         # batch = self._process_batch(batch)
         for phase in self.discriminators.keys():
             pred_real = self.discriminators[phase](batch[phase]["img"])
-            pred_fake = self.discriminators[phase](generated["images"][f"fake_{phase}"].detach())
+            pred_fake = self.discriminators[phase](generated["images"][f"a2{phase}"].detach())
             loss[f"gan_{phase}"] = 0
             for i in range(len(pred_fake)):
                 loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
@@ -142,13 +149,13 @@
             a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
                batch["a"]["img"].detach(),
                generated["images"]["a2a"].detach(),
-               generated["images"]["fake_b"].detach(),
+               generated["images"]["a2b"].detach(),
                generated["images"]["cycle_a"].detach(),
                ],
             b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
                batch["b"]["img"].detach(),
                generated["images"]["b2b"].detach(),
-               generated["images"]["fake_a"].detach()]
+               generated["images"]["cycle_b"].detach()]
         )
 
     def change_engine(self, config, trainer):
diff --git a/engine/base/i2i.py b/engine/base/i2i.py
index b27df75..3454e00 100644
--- a/engine/base/i2i.py
+++ b/engine/base/i2i.py
@@ -58,6 +58,10 @@ class EngineKernel(object):
         self.logger = logging.getLogger(config.name)
         self.generators, self.discriminators = self.build_models()
         self.train_generator_first = True
+        self.engine = None
+
+    def bind_engine(self, engine):
+        self.engine = engine
 
     def build_models(self) -> (dict, dict):
         raise NotImplemented
@@ -154,6 +158,7 @@ def get_trainer(config, kernel: EngineKernel):
     trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
 
     kernel.change_engine(config, trainer)
+    kernel.bind_engine(trainer)
 
     RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values()), epoch_bound=False).attach(trainer, "loss_g")
     RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values()), epoch_bound=False).attach(trainer, "loss_d")
@@ -186,9 +191,11 @@ def get_trainer(config, kernel: EngineKernel):
         with torch.no_grad():
             g = torch.Generator()
-            g.manual_seed(config.misc.random_seed)
-            random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
-            for i in range(random_start, random_start + 10):
+            g.manual_seed(config.misc.random_seed + engine.state.epoch
+                          if config.handler.test.random else config.misc.random_seed)
+            random_start = \
+                torch.randperm(len(engine.state.test_dataset) - config.handler.test.images, generator=g).tolist()[0]
+            for i in range(random_start, random_start + config.handler.test.images):
                 batch = convert_tensor(engine.state.test_dataset[i], idist.device())
                 for k in batch:
                     if isinstance(batch[k], torch.Tensor):
diff --git a/model/GAN/MUNIT.py b/model/GAN/MUNIT.py
index 0c2e221..c113cf0 100644
--- a/model/GAN/MUNIT.py
+++ b/model/GAN/MUNIT.py
@@ -8,7 +8,7 @@ from model.normalization import select_norm_layer
 
 class StyleEncoder(nn.Module):
     def __init__(self, in_channels, out_dim, num_conv, base_channels=64, use_spectral_norm=False,
-                 padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
+                 max_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
         super(StyleEncoder, self).__init__()
 
         sequence = [Conv2dBlock(
@@ -19,7 +19,7 @@ class StyleEncoder(nn.Module):
         multiple_now = 1
         for i in range(1, num_conv + 1):
             multiple_prev = multiple_now
-            multiple_now = min(2 ** i, 2 ** 2)
+            multiple_now = min(2 ** i, 2 ** max_multiple)
             sequence.append(Conv2dBlock(
                 multiple_prev * base_channels, multiple_now * base_channels, kernel_size=4, stride=2,
                 padding=1, padding_mode=padding_mode,
@@ -50,12 +50,8 @@ class ContentEncoder(nn.Module):
                 use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
             ))
-        for _ in range(num_res_blocks):
-            sequence.append(
-                ResBlock(base_channels * (2 ** num_down_sampling), use_spectral_norm, padding_mode, norm_type,
-                         activation_type)
-            )
-
+        sequence += [ResBlock(base_channels * (2 ** num_down_sampling), use_spectral_norm, padding_mode, norm_type,
+                              activation_type) for _ in range(num_res_blocks)]
         self.sequence = nn.Sequential(*sequence)
 
     def forward(self, x):
diff --git a/model/GAN/TAFG.py b/model/GAN/TAFG.py
index e8976c0..5b44b74 100644
--- a/model/GAN/TAFG.py
+++ b/model/GAN/TAFG.py
@@ -4,7 +4,7 @@ from torchvision.models import vgg19
 
 from model.normalization import select_norm_layer
 from model.registry import MODEL
-from .MUNIT import ContentEncoder, Fusion, Decoder
+from .MUNIT import ContentEncoder, Fusion, Decoder, StyleEncoder
 from .base import ResBlock
 
 
@@ -56,17 +56,26 @@ class VGG19StyleEncoder(nn.Module):
 @MODEL.register_module("TAFG-Generator")
 class Generator(nn.Module):
     def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,
-                 style_dim=512, style_use_fc=True,
-                 num_adain_blocks=8, num_res_blocks=8,
-                 base_channels=64, padding_mode="reflect"):
+                 style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
+                 num_res_blocks=8, base_channels=64, padding_mode="reflect"):
         super(Generator, self).__init__()
         self.num_adain_blocks = num_adain_blocks
-        self.style_encoders = nn.ModuleDict(dict(
-            a=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
-                                norm_type="NONE"),
-            b=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
-                                norm_type="NONE", fix_vgg19=False)
-        ))
+        if style_encoder_type == "StyleEncoder":
+            self.style_encoders = nn.ModuleDict(dict(
+                a=StyleEncoder(style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
+                               max_multiple=4, padding_mode=padding_mode, norm_type="NONE"),
+                b=StyleEncoder(style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
+                               max_multiple=4, padding_mode=padding_mode, norm_type="NONE"),
+            ))
+        elif style_encoder_type == "VGG19StyleEncoder":
+            self.style_encoders = nn.ModuleDict(dict(
+                a=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
+                                    norm_type="NONE"),
+                b=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
+                                    norm_type="NONE", fix_vgg19=False)
+            ))
+        else:
+            raise NotImplementedError(f"do not support {style_encoder_type}")
         resnet_channels = 2 ** 2 * base_channels
         self.style_converters = nn.ModuleDict(dict(
             a=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,