base
This commit is contained in:
parent
e71e8d95d0
commit
14d4247112
@@ -1,4 +1,4 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="ProjectRootManager" version="2" project-jdk-name="22d-base" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="15d-python" project-jdk-type="Python SDK" />
 </project>
@@ -2,7 +2,7 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$" />
-    <orderEntry type="jdk" jdkName="22d-base" jdkType="Python SDK" />
+    <orderEntry type="jdk" jdkName="15d-python" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
   <component name="TestRunnerService">
145 configs/synthesizers/TAFG.yml Normal file
@@ -0,0 +1,145 @@
name: TAFG
engine: TAFG
result_dir: ./result
max_pairs: 1000000

misc:
  random_seed: 324

checkpoint:
  epoch_interval: 1  # one checkpoint every 1 epoch
  n_saved: 2

interval:
  print_per_iteration: 10  # print once per 10 iteration
  tensorboard:
    scalar: 100
    image: 2

model:
  generator:
    _type: TAHG-Generator
    _bn_to_sync_bn: False
    style_in_channels: 3
    content_in_channels: 1
    num_blocks: 4
  discriminator:
    _type: MultiScaleDiscriminator
    num_scale: 2
    discriminator_cfg:
      _type: base-PatchDiscriminator
      in_channels: 3
      base_channels: 64
      use_spectral: True
      need_intermediate_feature: True

loss:
  gan:
    loss_type: hinge
    real_label_val: 1.0
    fake_label_val: 0.0
    weight: 1.0
  perceptual:
    layer_weights:
      "1": 0.03125
      "6": 0.0625
      "11": 0.125
      "20": 0.25
      "29": 1
    criterion: 'L1'
    style_loss: False
    perceptual_loss: True
    weight: 1
  style:
    layer_weights:
      "1": 0.03125
      "6": 0.0625
      "11": 0.125
      "20": 0.25
      "29": 1
    criterion: 'L2'
    style_loss: True
    perceptual_loss: False
    weight: 0
  fm:
    level: 1
    weight: 1
  recon:
    level: 1
    weight: 1

optimizers:
  generator:
    _type: Adam
    lr: 0.0001
    betas: [ 0, 0.9 ]
    weight_decay: 0.0001
  discriminator:
    _type: Adam
    lr: 4e-4
    betas: [ 0, 0.9 ]
    weight_decay: 0.0001

data:
  train:
    scheduler:
      start_proportion: 0.5
      target_lr: 0
    buffer_size: 50
    dataloader:
      batch_size: 256
      shuffle: True
      num_workers: 2
      pin_memory: True
      drop_last: True
    dataset:
      _type: GenerationUnpairedDatasetWithEdge
      root_a: "/data/i2i/VoxCeleb2Anime/trainA"
      root_b: "/data/i2i/VoxCeleb2Anime/trainB"
      edges_path: "/data/i2i/VoxCeleb2Anime/edges"
      edge_type: "hed"
      size: [128, 128]
      random_pair: True
      pipeline:
        - Load
        - Resize:
            size: [128, 128]
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
  test:
    dataloader:
      batch_size: 8
      shuffle: False
      num_workers: 1
      pin_memory: False
      drop_last: False
    dataset:
      _type: GenerationUnpairedDatasetWithEdge
      root_a: "/data/i2i/VoxCeleb2Anime/testA"
      root_b: "/data/i2i/VoxCeleb2Anime/testB"
      edges_path: "/data/i2i/VoxCeleb2Anime/edges"
      edge_type: "hed"
      random_pair: False
      size: [128, 128]
      pipeline:
        - Load
        - Resize:
            size: [128, 128]
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
  video_dataset:
    _type: SingleFolderDataset
    root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
    with_path: True
    pipeline:
      - Load
      - Resize:
          size: [ 256, 256 ]
      - ToTensor
      - Normalize:
          mean: [ 0.5, 0.5, 0.5 ]
          std: [ 0.5, 0.5, 0.5 ]
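Note on the `_`-prefixed keys above (`_type`, `_bn_to_sync_bn`): they are meta-options consumed before the constructor of the registered class ever sees the config. A minimal sketch of that convention, assuming only OmegaConf and a plain dict standing in for the repo's MODEL registry (illustrative names, not the real API):

from omegaconf import OmegaConf

# Hypothetical stand-in for the repo's registry; the real lookup is MODEL.build_with.
REGISTRY = {"MultiScaleDiscriminator": dict}

cfg = OmegaConf.create({
    "_type": "MultiScaleDiscriminator",
    "num_scale": 2,
})

args = OmegaConf.to_container(cfg)   # DictConfig -> plain dict
cls = REGISTRY[args.pop("_type")]    # "_type" selects the class to build
obj = cls(**{k: v for k, v in args.items() if not k.startswith("_")})  # remaining keys are kwargs
print(obj)                           # {'num_scale': 2} built via the stand-in class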
@@ -3,10 +3,6 @@ engine: TAHG
 result_dir: ./result
 max_pairs: 1000000

-distributed:
-  model:
-    # broadcast_buffers: False
-
 misc:
   random_seed: 324

@@ -23,6 +19,7 @@ interval:
 model:
   generator:
     _type: TAHG-Generator
+    _bn_to_sync_bn: False
     style_in_channels: 3
     content_in_channels: 1
     num_blocks: 4
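`_bn_to_sync_bn` is read by the new `build_model` helper in `engine/base/i2i.py` below; when it is True, BatchNorm layers are converted to SyncBatchNorm for multi-process training. A minimal sketch of that conversion using only stock PyTorch (the registry and idist plumbing are omitted):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)  # BatchNorm2d -> SyncBatchNorm
print(model)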
133 engine/TAFG.py Normal file
@@ -0,0 +1,133 @@
from itertools import chain
from math import ceil

from omegaconf import read_write, OmegaConf

import torch
import torch.nn as nn
import torch.nn.functional as F
import ignite.distributed as idist

import data
from engine.base.i2i import get_trainer, EngineKernel, build_model
from model.weight_init import generation_init_weights

from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss


class TAFGEngineKernel(EngineKernel):
    def __init__(self, config, logger):
        super().__init__(config, logger)
        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
        perceptual_loss_cfg.pop("weight")
        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())

        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())

        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
        self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()

    def build_models(self) -> (dict, dict):
        generators = dict(
            main=build_model(self.config.model.generator)
        )
        discriminators = dict(
            a=build_model(self.config.model.discriminator),
            b=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["a"])
        self.logger.debug(generators["main"])

        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)

        return generators, discriminators

    def setup_before_d(self):
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        generator = self.generators["main"]
        with torch.set_grad_enabled(not inference):
            fake = dict(
                a=generator(content_img=batch["edge_a"], style_img=batch["a"], which_decoder="a"),
                b=generator(content_img=batch["edge_a"], style_img=batch["b"], which_decoder="b"),
            )
        return fake

    def criterion_generators(self, batch, generated) -> dict:
        loss = dict()
        loss["perceptual"] = self.perceptual_loss(generated["b"], batch["b"])[0] * self.config.loss.perceptual.weight
        for phase in "ab":
            pred_fake = self.discriminators[phase](generated[phase])
            for i, sub_pred_fake in enumerate(pred_fake):
                # last output is actual prediction
                loss[f"gan_{phase}_sub_{i}"] = self.gan_loss(sub_pred_fake[-1], True)

            if self.config.loss.fm.weight > 0 and phase == "b":
                pred_real = self.discriminators[phase](batch[phase])
                loss_fm = 0
                num_scale_discriminator = len(pred_fake)
                for i in range(num_scale_discriminator):
                    # last output is the final prediction, so we exclude it
                    num_intermediate_outputs = len(pred_fake[i]) - 1
                    for j in range(num_intermediate_outputs):
                        loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
                loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
        loss["recon"] = self.recon_loss(generated["a"], batch["a"]) * self.config.loss.recon.weight
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        loss = dict()
        for phase in self.discriminators.keys():
            pred_real = self.discriminators[phase](batch[phase])
            pred_fake = self.discriminators[phase](generated[phase].detach())
            loss[f"gan_{phase}"] = 0
            for i in range(len(pred_fake)):
                loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
                                         + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        return dict(
            a=[batch["edge_a"].expand(-1, 3, -1, -1).detach(), batch["a"].detach(), generated["a"].detach()],
            b=[batch["b"].detach(), generated["b"].detach()]
        )


def run(task, config, logger):
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    with read_write(config):
        config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)

    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
        trainer = get_trainer(config, TAFGEngineKernel(config, logger), len(train_data_loader))
        if idist.get_rank() == 0:
            test_dataset = data.DATASET.build_with(config.data.test.dataset)
            trainer.state.test_dataset = test_dataset
        try:
            trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
        except Exception:
            import traceback
            print(traceback.format_exc())
    else:
        raise NotImplementedError(f"invalid task: {task}")
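`GANLoss` with `loss_type: hinge` comes from `loss.gan`, which is not part of this diff. A self-contained sketch of what a hinge GAN loss with this calling convention typically computes (an assumption about the implementation, shown only to make the `criterion_*` methods above easier to read):

import torch
import torch.nn.functional as F

def hinge_gan_loss(pred, target_is_real, is_discriminator=False):
    # Discriminator side: hinge on a +/-1 margin; generator side: -E[D(fake)].
    if is_discriminator:
        if target_is_real:
            return F.relu(1.0 - pred).mean()
        return F.relu(1.0 + pred).mean()
    return -pred.mean()

pred_fake = torch.randn(4, 1, 19, 19)  # a patch map of scores
print(hinge_gan_loss(pred_fake, True))                           # generator term
print(hinge_gan_loss(pred_fake, False, is_discriminator=True))   # discriminator fake term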
0 engine/base/__init__.py Normal file

187 engine/base/i2i.py Normal file
@@ -0,0 +1,187 @@
from itertools import chain
from math import ceil
from pathlib import Path
import logging

import torch

import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear

from model import MODEL
from util.image import make_2d_grid
from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_optimizer

from omegaconf import OmegaConf


def build_model(cfg):
    cfg = OmegaConf.to_container(cfg)
    bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
    model = MODEL.build_with(cfg)
    if bn_to_sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return idist.auto_model(model)


def build_lr_schedulers(optimizers, config):
    # TODO: support more scheduler types
    g_milestones_values = [
        (0, config.optimizers.generator.lr),
        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    d_milestones_values = [
        (0, config.optimizers.discriminator.lr),
        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    return dict(
        g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
        d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
    )


class EngineKernel(object):
    def __init__(self, config, logger):
        self.config = config
        self.logger = logger
        self.generators, self.discriminators = self.build_models()

    def build_models(self) -> (dict, dict):
        raise NotImplementedError

    def to_save(self):
        to_save = {}
        to_save.update({f"generator_{k}": self.generators[k] for k in self.generators})
        to_save.update({f"discriminator_{k}": self.discriminators[k] for k in self.discriminators})
        return to_save

    def setup_before_d(self):
        raise NotImplementedError

    def setup_before_g(self):
        raise NotImplementedError

    def forward(self, batch, inference=False) -> dict:
        raise NotImplementedError

    def criterion_generators(self, batch, generated) -> dict:
        raise NotImplementedError

    def criterion_discriminators(self, batch, generated) -> dict:
        raise NotImplementedError

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        raise NotImplementedError


def get_trainer(config, ek: EngineKernel, iter_per_epoch):
    logger = logging.getLogger(config.name)
    generators, discriminators = ek.generators, ek.discriminators
    optimizers = dict(
        g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
        d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
    )
    logger.info(f"build optimizers:\n{optimizers}")

    lr_schedulers = build_lr_schedulers(optimizers, config)
    logger.info(f"build lr_schedulers:\n{lr_schedulers}")

    def _step(engine, batch):
        batch = convert_tensor(batch, idist.device())

        generated = ek.forward(batch)

        ek.setup_before_g()
        optimizers["g"].zero_grad()
        loss_g = ek.criterion_generators(batch, generated)
        sum(loss_g.values()).backward()
        optimizers["g"].step()

        ek.setup_before_d()
        optimizers["d"].zero_grad()
        loss_d = ek.criterion_discriminators(batch, generated)
        sum(loss_d.values()).backward()
        optimizers["d"].step()

        return {
            "loss": dict(g=loss_g, d=loss_d),
            "img": ek.intermediate_images(batch, generated)
        }

    trainer = Engine(_step)
    trainer.logger = logger
    for lr_shd in lr_schedulers.values():
        trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)

    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
    to_save = dict(trainer=trainer)
    to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
    to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
    to_save.update(ek.to_save())
    setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
                          end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))

    def output_transform(output):
        loss = dict()
        for tl in output["loss"]:
            if isinstance(output["loss"][tl], dict):
                for l in output["loss"][tl]:
                    loss[f"{tl}_{l}"] = output["loss"][tl][l]
            else:
                loss[tl] = output["loss"][tl]
        return loss

    tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
    if tensorboard_handler is not None:
        tensorboard_handler.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
        )

        @trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
        def show_images(engine):
            output = engine.state.output
            test_images = {}
            for k in output["img"]:
                image_list = output["img"][k]
                tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list), engine.state.iteration)
                test_images[k] = []
                for i in range(len(image_list)):
                    test_images[k].append([])

            with torch.no_grad():
                g = torch.Generator()
                g.manual_seed(config.misc.random_seed)
                random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
                for i in range(random_start, random_start + 10):
                    batch = convert_tensor(engine.state.test_dataset[i], idist.device())
                    for k in batch:
                        batch[k] = batch[k].view(1, *batch[k].size())
                    generated = ek.forward(batch)
                    images = ek.intermediate_images(batch, generated)

                    for k in test_images:
                        for j in range(len(images[k])):
                            test_images[k][j].append(images[k][j])
                for k in test_images:
                    tensorboard_handler.writer.add_image(
                        f"test/{k}",
                        make_2d_grid([torch.cat(ti) for ti in test_images[k]]),
                        engine.state.iteration
                    )
    return trainer
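With the values from configs/synthesizers/TAFG.yml, `build_lr_schedulers` keeps the learning rate flat for the first half of training and then decays it linearly to `target_lr`. A quick check of the milestone arithmetic (numbers taken from the config above):

from math import ceil

max_pairs, batch_size = 1_000_000, 256
lr, start_proportion, target_lr = 1e-4, 0.5, 0

max_iteration = ceil(max_pairs / batch_size)                    # 3907 iterations in total
milestones_values = [
    (0, lr),
    (int(start_proportion * max_iteration), lr),                # flat until iteration 1953
    (max_iteration, target_lr),                                 # then linear decay to 0
]
print(max_iteration, milestones_values)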
61 model/GAN/base.py Normal file
@@ -0,0 +1,61 @@
import math

import torch.nn as nn

from model.normalization import select_norm_layer
from model import MODEL


# based on the SPADE / pix2pixHD discriminator
@MODEL.register_module("base-PatchDiscriminator")
class PatchDiscriminator(nn.Module):
    def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN",
                 need_intermediate_feature=False):
        super().__init__()
        self.need_intermediate_feature = need_intermediate_feature

        kernel_size = 4
        padding = math.ceil((kernel_size - 1.0) / 2)
        norm_layer = select_norm_layer(norm_type)
        use_bias = norm_type == "IN"
        padding_mode = "zeros"

        sequence = [nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size, stride=2, padding=padding),
            nn.LeakyReLU(0.2, False)
        )]
        multiple_now = 1
        for i in range(1, num_conv):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 3)
            stride = 1 if i == num_conv - 1 else 2
            sequence.append(nn.Sequential(
                self.build_conv2d(use_spectral, base_channels * multiple_prev, base_channels * multiple_now,
                                  kernel_size, stride, padding, bias=use_bias, padding_mode=padding_mode),
                norm_layer(base_channels * multiple_now),
                nn.LeakyReLU(0.2, inplace=False),
            ))
        multiple_now = min(2 ** num_conv, 8)
        sequence.append(nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding,
                                  padding_mode=padding_mode))
        self.conv_blocks = nn.ModuleList(sequence)

    @staticmethod
    def build_conv2d(use_spectral, in_channels: int, out_channels: int, kernel_size, stride, padding,
                     bias=True, padding_mode: str = 'zeros'):
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias, padding_mode=padding_mode)
        if not use_spectral:
            return conv
        return nn.utils.spectral_norm(conv)

    def forward(self, x):
        if self.need_intermediate_feature:
            intermediate_feature = []
            for layer in self.conv_blocks:
                x = layer(x)
                intermediate_feature.append(x)
            return tuple(intermediate_feature)
        else:
            for layer in self.conv_blocks:
                x = layer(x)
            return x
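For the config above (in_channels: 3, base_channels: 64, default num_conv=4), this builds a stack of five 4x4 convolutions with strides 2, 2, 2, 1, 1 and padding 2. A standalone mirror of those shapes, assuming select_norm_layer("IN") resolves to InstanceNorm2d (an assumption; the real helper lives in model.normalization):

import torch
import torch.nn as nn

patch_d = nn.Sequential(
    nn.Conv2d(3, 64, 4, stride=2, padding=2), nn.LeakyReLU(0.2),
    nn.Conv2d(64, 128, 4, stride=2, padding=2), nn.InstanceNorm2d(128), nn.LeakyReLU(0.2),
    nn.Conv2d(128, 256, 4, stride=2, padding=2), nn.InstanceNorm2d(256), nn.LeakyReLU(0.2),
    nn.Conv2d(256, 512, 4, stride=1, padding=2), nn.InstanceNorm2d(512), nn.LeakyReLU(0.2),
    nn.Conv2d(512, 1, 4, stride=1, padding=2),
)
print(patch_d(torch.randn(1, 3, 128, 128)).shape)  # torch.Size([1, 1, 19, 19]): per-patch real/fake scores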
25 model/GAN/wrapper.py Normal file
@@ -0,0 +1,25 @@
import torch.nn as nn
import torch.nn.functional as F

from model import MODEL


@MODEL.register_module()
class MultiScaleDiscriminator(nn.Module):
    def __init__(self, num_scale, discriminator_cfg):
        super().__init__()

        self.discriminator_list = nn.ModuleList([
            MODEL.build_with(discriminator_cfg) for _ in range(num_scale)
        ])

    @staticmethod
    def down_sample(x):
        return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)

    def forward(self, x):
        results = []
        for discriminator in self.discriminator_list:
            results.append(discriminator(x))
            x = self.down_sample(x)
        return results
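Each discriminator in the list sees a progressively downsampled copy of the input, so `num_scale: 2` means one pass at full resolution and one at half resolution. The pooling arithmetic in `down_sample`, checked in isolation:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 128, 128)  # scale 0: full-resolution input
x_half = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1, count_include_pad=False)
print(x.shape, x_half.shape)     # (1, 3, 128, 128) -> (1, 3, 64, 64): input to scale 1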
@@ -3,3 +3,5 @@ import model.GAN.residual_generator
 import model.GAN.TAHG
 import model.GAN.UGATIT
 import model.fewshot
+import model.GAN.wrapper
+import model.GAN.base
@@ -2,7 +2,7 @@ import inspect
 from omegaconf.dictconfig import DictConfig
 from omegaconf import OmegaConf
 from types import ModuleType
-
+import warnings

 class _Registry:
     def __init__(self, name):
@@ -51,6 +51,12 @@ class _Registry:
         else:
             raise TypeError(f'cfg must be a dict or a str, but got {type(cfg)}')

+        for k in list(args):
+            assert isinstance(k, str)
+            if k.startswith("_"):
+                warnings.warn(f"got param starting with `_`: {k}, will remove it")
+                args.pop(k)
+
         if not (isinstance(default_args, dict) or default_args is None):
             raise TypeError('default_args must be a dict or None, '
                             f'but got {type(default_args)}')
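The effect of the new guard: any key that still starts with `_` when it reaches the registry is dropped with a warning instead of being forwarded to the constructor. A small standalone sketch of the same filtering (hypothetical helper name, not the repo's _Registry):

import warnings

def strip_private_keys(args: dict) -> dict:
    # Keys starting with "_" are meta-options, not constructor arguments.
    for k in list(args):
        if k.startswith("_"):
            warnings.warn(f"got param starting with `_`: {k}, will remove it")
            args.pop(k)
    return args

print(strip_private_keys({"_bn_to_sync_bn": False, "num_blocks": 4}))  # {'num_blocks': 4}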