TAHG 0.0.3
This commit is contained in:
parent
89b54105c7
commit
e71e8d95d0
@ -1,6 +1,6 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="PublishConfigData" autoUpload="Always" serverName="15d" remoteFilesAllowedToDisappearOnAutoupload="false">
|
<component name="PublishConfigData" autoUpload="Always" serverName="22d" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||||
<serverData>
|
<serverData>
|
||||||
<paths name="14d">
|
<paths name="14d">
|
||||||
<serverdata>
|
<serverdata>
|
||||||
@ -16,6 +16,13 @@
|
|||||||
</mappings>
|
</mappings>
|
||||||
</serverdata>
|
</serverdata>
|
||||||
</paths>
|
</paths>
|
||||||
|
<paths name="22d">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping deploy="/raycv" local="$PROJECT_DIR$" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
</serverData>
|
</serverData>
|
||||||
<option name="myAutoUpload" value="ALWAYS" />
|
<option name="myAutoUpload" value="ALWAYS" />
|
||||||
</component>
|
</component>
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="15d-python" project-jdk-type="Python SDK" />
|
<component name="ProjectRootManager" version="2" project-jdk-name="22d-base" project-jdk-type="Python SDK" />
|
||||||
</project>
|
</project>
|
||||||
@ -2,7 +2,7 @@
|
|||||||
<module type="PYTHON_MODULE" version="4">
|
<module type="PYTHON_MODULE" version="4">
|
||||||
<component name="NewModuleRootManager">
|
<component name="NewModuleRootManager">
|
||||||
<content url="file://$MODULE_DIR$" />
|
<content url="file://$MODULE_DIR$" />
|
||||||
<orderEntry type="jdk" jdkName="15d-python" jdkType="Python SDK" />
|
<orderEntry type="jdk" jdkName="22d-base" jdkType="Python SDK" />
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
</component>
|
</component>
|
||||||
<component name="TestRunnerService">
|
<component name="TestRunnerService">
|
||||||
|
|||||||
@ -17,14 +17,15 @@ checkpoint:
|
|||||||
interval:
|
interval:
|
||||||
print_per_iteration: 10 # print once per 10 iteration
|
print_per_iteration: 10 # print once per 10 iteration
|
||||||
tensorboard:
|
tensorboard:
|
||||||
scalar: 10
|
scalar: 100
|
||||||
image: 500
|
image: 2
|
||||||
|
|
||||||
model:
|
model:
|
||||||
generator:
|
generator:
|
||||||
_type: TAHG-Generator
|
_type: TAHG-Generator
|
||||||
style_in_channels: 3
|
style_in_channels: 3
|
||||||
content_in_channels: 1
|
content_in_channels: 1
|
||||||
|
num_blocks: 4
|
||||||
discriminator:
|
discriminator:
|
||||||
_type: TAHG-Discriminator
|
_type: TAHG-Discriminator
|
||||||
in_channels: 3
|
in_channels: 3
|
||||||
@ -37,22 +38,22 @@ loss:
|
|||||||
weight: 1.0
|
weight: 1.0
|
||||||
edge:
|
edge:
|
||||||
criterion: 'L1'
|
criterion: 'L1'
|
||||||
hed_pretrained_model_path: "/root/network-bsds500.pytorch"
|
hed_pretrained_model_path: "./network-bsds500.pytorch"
|
||||||
weight: 2
|
weight: 1
|
||||||
perceptual:
|
perceptual:
|
||||||
layer_weights:
|
layer_weights:
|
||||||
# "3": 1.0
|
"3": 1.0
|
||||||
"0": 1.0
|
# "0": 1.0
|
||||||
"5": 1.0
|
# "5": 1.0
|
||||||
"10": 1.0
|
# "10": 1.0
|
||||||
"19": 1.0
|
# "19": 1.0
|
||||||
criterion: 'L2'
|
criterion: 'L2'
|
||||||
style_loss: True
|
style_loss: True
|
||||||
perceptual_loss: False
|
perceptual_loss: False
|
||||||
weight: 100.0
|
weight: 20
|
||||||
recon:
|
recon:
|
||||||
level: 1
|
level: 1
|
||||||
weight: 2
|
weight: 1
|
||||||
|
|
||||||
optimizers:
|
optimizers:
|
||||||
generator:
|
generator:
|
||||||
@ -73,7 +74,7 @@ data:
|
|||||||
target_lr: 0
|
target_lr: 0
|
||||||
buffer_size: 50
|
buffer_size: 50
|
||||||
dataloader:
|
dataloader:
|
||||||
batch_size: 48
|
batch_size: 160
|
||||||
shuffle: True
|
shuffle: True
|
||||||
num_workers: 2
|
num_workers: 2
|
||||||
pin_memory: True
|
pin_memory: True
|
||||||
|
|||||||
@ -184,22 +184,23 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
|
|||||||
self.edges_path = Path(edges_path)
|
self.edges_path = Path(edges_path)
|
||||||
assert self.edges_path.exists()
|
assert self.edges_path.exists()
|
||||||
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
|
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
|
||||||
self.B = SingleFolderDataset(root_b, pipeline, with_path=False)
|
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
|
||||||
self.random_pair = random_pair
|
self.random_pair = random_pair
|
||||||
|
|
||||||
def get_edge(self, origin_path):
|
def get_edge(self, origin_path):
|
||||||
op = Path(origin_path)
|
op = Path(origin_path)
|
||||||
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{self.edge_type}.png"
|
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{self.edge_type}.png"
|
||||||
img = Image.open(edge_path).resize(self.size)
|
img = Image.open(edge_path).resize(self.size)
|
||||||
return {"edge": F.to_tensor(img)}
|
return F.to_tensor(img)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
a_idx = idx % len(self.A)
|
a_idx = idx % len(self.A)
|
||||||
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
||||||
output = dict()
|
output = dict()
|
||||||
output["a"], path_a = self.A[a_idx]
|
output["a"], path_a = self.A[a_idx]
|
||||||
output.update(self.get_edge(path_a))
|
output["b"], path_b = self.B[b_idx]
|
||||||
output["b"] = self.B[b_idx]
|
output["edge_a"] = self.get_edge(path_a)
|
||||||
|
output["edge_b"] = self.get_edge(path_b)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
|||||||
@ -44,7 +44,7 @@ def build_lr_schedulers(optimizers, config):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_trainer(config, logger):
|
def get_trainer(config, logger, train_data_loader):
|
||||||
generator = build_model(config.model.generator, config.distributed.model)
|
generator = build_model(config.model.generator, config.distributed.model)
|
||||||
discriminators = dict(
|
discriminators = dict(
|
||||||
a=build_model(config.model.discriminator, config.distributed.model),
|
a=build_model(config.model.discriminator, config.distributed.model),
|
||||||
@ -85,11 +85,12 @@ def get_trainer(config, logger):
|
|||||||
def _step(engine, batch):
|
def _step(engine, batch):
|
||||||
batch = convert_tensor(batch, idist.device())
|
batch = convert_tensor(batch, idist.device())
|
||||||
real = dict(a=batch["a"], b=batch["b"])
|
real = dict(a=batch["a"], b=batch["b"])
|
||||||
content_img = batch["edge"]
|
|
||||||
fake = dict(
|
fake = dict(
|
||||||
a=generator(content_img=content_img, style_img=real["a"], which_decoder="a"),
|
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
|
||||||
b=generator(content_img=content_img, style_img=real["b"], which_decoder="b"),
|
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
|
||||||
)
|
)
|
||||||
|
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
|
||||||
|
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
|
||||||
|
|
||||||
optimizers["g"].zero_grad()
|
optimizers["g"].zero_grad()
|
||||||
loss_g = dict()
|
loss_g = dict()
|
||||||
@ -99,8 +100,10 @@ def get_trainer(config, logger):
|
|||||||
loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True)
|
loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True)
|
||||||
_, t = perceptual_loss(fake[d], real[d])
|
_, t = perceptual_loss(fake[d], real[d])
|
||||||
loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t
|
loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t
|
||||||
loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], content_img)
|
loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], batch["edge_a"])
|
||||||
loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"])
|
loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"])
|
||||||
|
loss_g["recon_b"] = config.loss.recon.weight * recon_loss(rec_b, real["b"])
|
||||||
|
loss_g["recon_bb"] = config.loss.recon.weight * recon_loss(rec_bb, real["b"])
|
||||||
sum(loss_g.values()).backward()
|
sum(loss_g.values()).backward()
|
||||||
optimizers["g"].step()
|
optimizers["g"].step()
|
||||||
|
|
||||||
@ -118,7 +121,10 @@ def get_trainer(config, logger):
|
|||||||
optimizers["d"].step()
|
optimizers["d"].step()
|
||||||
|
|
||||||
generated_img = {f"real_{k}": real[k].detach() for k in real}
|
generated_img = {f"real_{k}": real[k].detach() for k in real}
|
||||||
|
generated_img["rec_b"] = rec_b.detach()
|
||||||
|
generated_img["rec_bb"] = rec_b.detach()
|
||||||
generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
|
generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
|
||||||
|
generated_img.update({f"edge_{k}": batch[f"edge_{k}"].expand(-1, 3, -1, -1).detach() for k in "ab"})
|
||||||
return {
|
return {
|
||||||
"loss": {
|
"loss": {
|
||||||
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
|
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
|
||||||
@ -153,20 +159,21 @@ def get_trainer(config, logger):
|
|||||||
loss[tl] = output["loss"][tl]
|
loss[tl] = output["loss"][tl]
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform)
|
iter_per_epoch = len(train_data_loader)
|
||||||
|
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
|
||||||
if tensorboard_handler is not None:
|
if tensorboard_handler is not None:
|
||||||
tensorboard_handler.attach(
|
tensorboard_handler.attach(
|
||||||
trainer,
|
trainer,
|
||||||
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
|
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
|
||||||
event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
|
event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
|
||||||
)
|
)
|
||||||
|
|
||||||
@trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
|
@trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
|
||||||
def show_images(engine):
|
def show_images(engine):
|
||||||
output = engine.state.output
|
output = engine.state.output
|
||||||
image_order = dict(
|
image_order = dict(
|
||||||
a=["real_a", "fake_a"],
|
a=["edge_a", "real_a", "fake_a", "fake_b"],
|
||||||
b=["real_b", "fake_b"]
|
b=["edge_b", "real_b", "rec_b", "rec_bb"]
|
||||||
)
|
)
|
||||||
for k in "ab":
|
for k in "ab":
|
||||||
tensorboard_handler.writer.add_image(
|
tensorboard_handler.writer.add_image(
|
||||||
@ -175,6 +182,42 @@ def get_trainer(config, logger):
|
|||||||
engine.state.iteration
|
engine.state.iteration
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
g = torch.Generator()
|
||||||
|
g.manual_seed(config.misc.random_seed)
|
||||||
|
random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
|
||||||
|
test_images = dict(
|
||||||
|
a=[[], [], [], []],
|
||||||
|
b=[[], [], [], []]
|
||||||
|
)
|
||||||
|
for i in range(random_start, random_start + 10):
|
||||||
|
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||||
|
for k in batch:
|
||||||
|
batch[k] = batch[k].view(1, *batch[k].size())
|
||||||
|
|
||||||
|
real = dict(a=batch["a"], b=batch["b"])
|
||||||
|
fake = dict(
|
||||||
|
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
|
||||||
|
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
|
||||||
|
)
|
||||||
|
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
|
||||||
|
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
|
||||||
|
|
||||||
|
test_images["a"][0].append(batch["edge_a"])
|
||||||
|
test_images["a"][1].append(batch["a"])
|
||||||
|
test_images["a"][2].append(fake["a"])
|
||||||
|
test_images["a"][3].append(fake["b"])
|
||||||
|
test_images["b"][0].append(batch["edge_b"])
|
||||||
|
test_images["b"][1].append(batch["b"])
|
||||||
|
test_images["b"][2].append(rec_b)
|
||||||
|
test_images["b"][3].append(rec_bb)
|
||||||
|
for n in "ab":
|
||||||
|
tensorboard_handler.writer.add_image(
|
||||||
|
f"test/{n}",
|
||||||
|
make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
|
||||||
|
engine.state.iteration
|
||||||
|
)
|
||||||
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
|
|
||||||
@ -189,7 +232,7 @@ def run(task, config, logger):
|
|||||||
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
||||||
logger.info(f"train with dataset:\n{train_dataset}")
|
logger.info(f"train with dataset:\n{train_dataset}")
|
||||||
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
||||||
trainer = get_trainer(config, logger)
|
trainer = get_trainer(config, logger, train_data_loader)
|
||||||
if idist.get_rank() == 0:
|
if idist.get_rank() == 0:
|
||||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||||
trainer.state.test_dataset = test_dataset
|
trainer.state.test_dataset = test_dataset
|
||||||
|
|||||||
@ -87,10 +87,17 @@ class ContentEncoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
def __init__(self, out_channels, base_channels=64, num_down_sampling=2, padding_mode='reflect', norm_type="LN"):
|
def __init__(self, out_channels, base_channels=64, num_blocks=4, num_down_sampling=2, padding_mode='reflect',
|
||||||
|
norm_type="LN"):
|
||||||
super(Decoder, self).__init__()
|
super(Decoder, self).__init__()
|
||||||
norm_layer = select_norm_layer(norm_type)
|
norm_layer = select_norm_layer(norm_type)
|
||||||
use_bias = norm_type == "IN"
|
use_bias = norm_type == "IN"
|
||||||
|
|
||||||
|
res_block_channels = (2 ** 2) * base_channels
|
||||||
|
|
||||||
|
self.resnet = nn.Sequential(
|
||||||
|
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
|
||||||
|
|
||||||
# up sampling
|
# up sampling
|
||||||
submodules = []
|
submodules = []
|
||||||
for i in range(num_down_sampling):
|
for i in range(num_down_sampling):
|
||||||
@ -109,6 +116,7 @@ class Decoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
x = self.resnet(x)
|
||||||
x = self.decoder(x)
|
x = self.decoder(x)
|
||||||
x = self.end_conv(x)
|
x = self.end_conv(x)
|
||||||
return x
|
return x
|
||||||
@ -159,8 +167,8 @@ class Generator(nn.Module):
|
|||||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
|
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
|
||||||
])
|
])
|
||||||
self.decoders = nn.ModuleDict({
|
self.decoders = nn.ModuleDict({
|
||||||
"a": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode),
|
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode),
|
||||||
"b": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode)
|
"b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode)
|
||||||
})
|
})
|
||||||
|
|
||||||
self.fc = nn.Sequential(
|
self.fc = nn.Sequential(
|
||||||
|
|||||||
@ -88,16 +88,17 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
|||||||
engine.terminate()
|
engine.terminate()
|
||||||
|
|
||||||
|
|
||||||
def setup_tensorboard_handler(trainer: Engine, config, output_transform):
|
def setup_tensorboard_handler(trainer: Engine, config, output_transform, iter_per_epoch):
|
||||||
if config.interval.tensorboard is None:
|
if config.interval.tensorboard is None:
|
||||||
return None
|
return None
|
||||||
if idist.get_rank() == 0:
|
if idist.get_rank() == 0:
|
||||||
# Create a logger
|
# Create a logger
|
||||||
tb_logger = TensorboardLogger(log_dir=config.output_dir)
|
tb_logger = TensorboardLogger(log_dir=config.output_dir)
|
||||||
|
basic_event = Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
|
||||||
tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"),
|
tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"),
|
||||||
event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
|
event_name=basic_event)
|
||||||
tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform),
|
tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform),
|
||||||
event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
|
event_name=basic_event)
|
||||||
|
|
||||||
@trainer.on(Events.COMPLETED)
|
@trainer.on(Events.COMPLETED)
|
||||||
@idist.one_rank_only()
|
@idist.one_rank_only()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user