TAHG 0.0.3

This commit is contained in:
budui 2020-09-01 09:02:04 +08:00
parent 89b54105c7
commit e71e8d95d0
8 changed files with 97 additions and 36 deletions

View File

@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="Always" serverName="15d" remoteFilesAllowedToDisappearOnAutoupload="false">
<component name="PublishConfigData" autoUpload="Always" serverName="22d" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="14d">
<serverdata>
@ -16,6 +16,13 @@
</mappings>
</serverdata>
</paths>
<paths name="22d">
<serverdata>
<mappings>
<mapping deploy="/raycv" local="$PROJECT_DIR$" />
</mappings>
</serverdata>
</paths>
</serverData>
<option name="myAutoUpload" value="ALWAYS" />
</component>

View File

@ -1,4 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="15d-python" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="22d-base" project-jdk-type="Python SDK" />
</project>

View File

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="15d-python" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="22d-base" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="TestRunnerService">

View File

@ -17,14 +17,15 @@ checkpoint:
interval:
print_per_iteration: 10 # print once per 10 iteration
tensorboard:
scalar: 10
image: 500
scalar: 100
image: 2
model:
generator:
_type: TAHG-Generator
style_in_channels: 3
content_in_channels: 1
num_blocks: 4
discriminator:
_type: TAHG-Discriminator
in_channels: 3
@ -37,22 +38,22 @@ loss:
weight: 1.0
edge:
criterion: 'L1'
hed_pretrained_model_path: "/root/network-bsds500.pytorch"
weight: 2
hed_pretrained_model_path: "./network-bsds500.pytorch"
weight: 1
perceptual:
layer_weights:
# "3": 1.0
"0": 1.0
"5": 1.0
"10": 1.0
"19": 1.0
"3": 1.0
# "0": 1.0
# "5": 1.0
# "10": 1.0
# "19": 1.0
criterion: 'L2'
style_loss: True
perceptual_loss: False
weight: 100.0
weight: 20
recon:
level: 1
weight: 2
weight: 1
optimizers:
generator:
@ -73,7 +74,7 @@ data:
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 48
batch_size: 160
shuffle: True
num_workers: 2
pin_memory: True

View File

@ -184,22 +184,23 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
self.edges_path = Path(edges_path)
assert self.edges_path.exists()
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
self.B = SingleFolderDataset(root_b, pipeline, with_path=False)
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
self.random_pair = random_pair
def get_edge(self, origin_path):
op = Path(origin_path)
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{self.edge_type}.png"
img = Image.open(edge_path).resize(self.size)
return {"edge": F.to_tensor(img)}
return F.to_tensor(img)
def __getitem__(self, idx):
a_idx = idx % len(self.A)
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
output = dict()
output["a"], path_a = self.A[a_idx]
output.update(self.get_edge(path_a))
output["b"] = self.B[b_idx]
output["b"], path_b = self.B[b_idx]
output["edge_a"] = self.get_edge(path_a)
output["edge_b"] = self.get_edge(path_b)
return output
def __len__(self):

View File

@ -44,7 +44,7 @@ def build_lr_schedulers(optimizers, config):
)
def get_trainer(config, logger):
def get_trainer(config, logger, train_data_loader):
generator = build_model(config.model.generator, config.distributed.model)
discriminators = dict(
a=build_model(config.model.discriminator, config.distributed.model),
@ -85,11 +85,12 @@ def get_trainer(config, logger):
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
real = dict(a=batch["a"], b=batch["b"])
content_img = batch["edge"]
fake = dict(
a=generator(content_img=content_img, style_img=real["a"], which_decoder="a"),
b=generator(content_img=content_img, style_img=real["b"], which_decoder="b"),
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
)
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
optimizers["g"].zero_grad()
loss_g = dict()
@ -99,8 +100,10 @@ def get_trainer(config, logger):
loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True)
_, t = perceptual_loss(fake[d], real[d])
loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t
loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], content_img)
loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], batch["edge_a"])
loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"])
loss_g["recon_b"] = config.loss.recon.weight * recon_loss(rec_b, real["b"])
loss_g["recon_bb"] = config.loss.recon.weight * recon_loss(rec_bb, real["b"])
sum(loss_g.values()).backward()
optimizers["g"].step()
@ -118,7 +121,10 @@ def get_trainer(config, logger):
optimizers["d"].step()
generated_img = {f"real_{k}": real[k].detach() for k in real}
generated_img["rec_b"] = rec_b.detach()
generated_img["rec_bb"] = rec_b.detach()
generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
generated_img.update({f"edge_{k}": batch[f"edge_{k}"].expand(-1, 3, -1, -1).detach() for k in "ab"})
return {
"loss": {
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
@ -153,20 +159,21 @@ def get_trainer(config, logger):
loss[tl] = output["loss"][tl]
return loss
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform)
iter_per_epoch = len(train_data_loader)
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
if tensorboard_handler is not None:
tensorboard_handler.attach(
trainer,
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
)
@trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
@trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
def show_images(engine):
output = engine.state.output
image_order = dict(
a=["real_a", "fake_a"],
b=["real_b", "fake_b"]
a=["edge_a", "real_a", "fake_a", "fake_b"],
b=["edge_b", "real_b", "rec_b", "rec_bb"]
)
for k in "ab":
tensorboard_handler.writer.add_image(
@ -175,6 +182,42 @@ def get_trainer(config, logger):
engine.state.iteration
)
with torch.no_grad():
g = torch.Generator()
g.manual_seed(config.misc.random_seed)
random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
test_images = dict(
a=[[], [], [], []],
b=[[], [], [], []]
)
for i in range(random_start, random_start + 10):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
for k in batch:
batch[k] = batch[k].view(1, *batch[k].size())
real = dict(a=batch["a"], b=batch["b"])
fake = dict(
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
)
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
test_images["a"][0].append(batch["edge_a"])
test_images["a"][1].append(batch["a"])
test_images["a"][2].append(fake["a"])
test_images["a"][3].append(fake["b"])
test_images["b"][0].append(batch["edge_b"])
test_images["b"][1].append(batch["b"])
test_images["b"][2].append(rec_b)
test_images["b"][3].append(rec_bb)
for n in "ab":
tensorboard_handler.writer.add_image(
f"test/{n}",
make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
engine.state.iteration
)
return trainer
@ -189,7 +232,7 @@ def run(task, config, logger):
train_dataset = data.DATASET.build_with(config.data.train.dataset)
logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
trainer = get_trainer(config, logger)
trainer = get_trainer(config, logger, train_data_loader)
if idist.get_rank() == 0:
test_dataset = data.DATASET.build_with(config.data.test.dataset)
trainer.state.test_dataset = test_dataset

View File

@ -87,10 +87,17 @@ class ContentEncoder(nn.Module):
class Decoder(nn.Module):
def __init__(self, out_channels, base_channels=64, num_down_sampling=2, padding_mode='reflect', norm_type="LN"):
def __init__(self, out_channels, base_channels=64, num_blocks=4, num_down_sampling=2, padding_mode='reflect',
norm_type="LN"):
super(Decoder, self).__init__()
norm_layer = select_norm_layer(norm_type)
use_bias = norm_type == "IN"
res_block_channels = (2 ** 2) * base_channels
self.resnet = nn.Sequential(
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
# up sampling
submodules = []
for i in range(num_down_sampling):
@ -109,6 +116,7 @@ class Decoder(nn.Module):
)
def forward(self, x):
x = self.resnet(x)
x = self.decoder(x)
x = self.end_conv(x)
return x
@ -159,8 +167,8 @@ class Generator(nn.Module):
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
])
self.decoders = nn.ModuleDict({
"a": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode),
"b": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode)
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode),
"b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode)
})
self.fc = nn.Sequential(

View File

@ -88,16 +88,17 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
engine.terminate()
def setup_tensorboard_handler(trainer: Engine, config, output_transform):
def setup_tensorboard_handler(trainer: Engine, config, output_transform, iter_per_epoch):
if config.interval.tensorboard is None:
return None
if idist.get_rank() == 0:
# Create a logger
tb_logger = TensorboardLogger(log_dir=config.output_dir)
basic_event = Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"),
event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
event_name=basic_event)
tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform),
event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
event_name=basic_event)
@trainer.on(Events.COMPLETED)
@idist.one_rank_only()