TAHG 0.0.3
This commit is contained in:
parent
89b54105c7
commit
e71e8d95d0
@ -1,6 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" autoUpload="Always" serverName="15d" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||
<component name="PublishConfigData" autoUpload="Always" serverName="22d" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||
<serverData>
|
||||
<paths name="14d">
|
||||
<serverdata>
|
||||
@ -16,6 +16,13 @@
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="22d">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping deploy="/raycv" local="$PROJECT_DIR$" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
</serverData>
|
||||
<option name="myAutoUpload" value="ALWAYS" />
|
||||
</component>
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="15d-python" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="22d-base" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
@ -2,7 +2,7 @@
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="15d-python" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="22d-base" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="TestRunnerService">
|
||||
|
||||
@ -17,14 +17,15 @@ checkpoint:
|
||||
interval:
|
||||
print_per_iteration: 10 # print once per 10 iteration
|
||||
tensorboard:
|
||||
scalar: 10
|
||||
image: 500
|
||||
scalar: 100
|
||||
image: 2
|
||||
|
||||
model:
|
||||
generator:
|
||||
_type: TAHG-Generator
|
||||
style_in_channels: 3
|
||||
content_in_channels: 1
|
||||
num_blocks: 4
|
||||
discriminator:
|
||||
_type: TAHG-Discriminator
|
||||
in_channels: 3
|
||||
@ -37,22 +38,22 @@ loss:
|
||||
weight: 1.0
|
||||
edge:
|
||||
criterion: 'L1'
|
||||
hed_pretrained_model_path: "/root/network-bsds500.pytorch"
|
||||
weight: 2
|
||||
hed_pretrained_model_path: "./network-bsds500.pytorch"
|
||||
weight: 1
|
||||
perceptual:
|
||||
layer_weights:
|
||||
# "3": 1.0
|
||||
"0": 1.0
|
||||
"5": 1.0
|
||||
"10": 1.0
|
||||
"19": 1.0
|
||||
"3": 1.0
|
||||
# "0": 1.0
|
||||
# "5": 1.0
|
||||
# "10": 1.0
|
||||
# "19": 1.0
|
||||
criterion: 'L2'
|
||||
style_loss: True
|
||||
perceptual_loss: False
|
||||
weight: 100.0
|
||||
weight: 20
|
||||
recon:
|
||||
level: 1
|
||||
weight: 2
|
||||
weight: 1
|
||||
|
||||
optimizers:
|
||||
generator:
|
||||
@ -73,7 +74,7 @@ data:
|
||||
target_lr: 0
|
||||
buffer_size: 50
|
||||
dataloader:
|
||||
batch_size: 48
|
||||
batch_size: 160
|
||||
shuffle: True
|
||||
num_workers: 2
|
||||
pin_memory: True
|
||||
|
||||
@ -184,22 +184,23 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||
self.edges_path = Path(edges_path)
|
||||
assert self.edges_path.exists()
|
||||
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
|
||||
self.B = SingleFolderDataset(root_b, pipeline, with_path=False)
|
||||
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
|
||||
self.random_pair = random_pair
|
||||
|
||||
def get_edge(self, origin_path):
|
||||
op = Path(origin_path)
|
||||
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{self.edge_type}.png"
|
||||
img = Image.open(edge_path).resize(self.size)
|
||||
return {"edge": F.to_tensor(img)}
|
||||
return F.to_tensor(img)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
a_idx = idx % len(self.A)
|
||||
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
||||
output = dict()
|
||||
output["a"], path_a = self.A[a_idx]
|
||||
output.update(self.get_edge(path_a))
|
||||
output["b"] = self.B[b_idx]
|
||||
output["b"], path_b = self.B[b_idx]
|
||||
output["edge_a"] = self.get_edge(path_a)
|
||||
output["edge_b"] = self.get_edge(path_b)
|
||||
return output
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@ -44,7 +44,7 @@ def build_lr_schedulers(optimizers, config):
|
||||
)
|
||||
|
||||
|
||||
def get_trainer(config, logger):
|
||||
def get_trainer(config, logger, train_data_loader):
|
||||
generator = build_model(config.model.generator, config.distributed.model)
|
||||
discriminators = dict(
|
||||
a=build_model(config.model.discriminator, config.distributed.model),
|
||||
@ -85,11 +85,12 @@ def get_trainer(config, logger):
|
||||
def _step(engine, batch):
|
||||
batch = convert_tensor(batch, idist.device())
|
||||
real = dict(a=batch["a"], b=batch["b"])
|
||||
content_img = batch["edge"]
|
||||
fake = dict(
|
||||
a=generator(content_img=content_img, style_img=real["a"], which_decoder="a"),
|
||||
b=generator(content_img=content_img, style_img=real["b"], which_decoder="b"),
|
||||
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
|
||||
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
|
||||
)
|
||||
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
|
||||
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
|
||||
|
||||
optimizers["g"].zero_grad()
|
||||
loss_g = dict()
|
||||
@ -99,8 +100,10 @@ def get_trainer(config, logger):
|
||||
loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True)
|
||||
_, t = perceptual_loss(fake[d], real[d])
|
||||
loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t
|
||||
loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], content_img)
|
||||
loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], batch["edge_a"])
|
||||
loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"])
|
||||
loss_g["recon_b"] = config.loss.recon.weight * recon_loss(rec_b, real["b"])
|
||||
loss_g["recon_bb"] = config.loss.recon.weight * recon_loss(rec_bb, real["b"])
|
||||
sum(loss_g.values()).backward()
|
||||
optimizers["g"].step()
|
||||
|
||||
@ -118,7 +121,10 @@ def get_trainer(config, logger):
|
||||
optimizers["d"].step()
|
||||
|
||||
generated_img = {f"real_{k}": real[k].detach() for k in real}
|
||||
generated_img["rec_b"] = rec_b.detach()
|
||||
generated_img["rec_bb"] = rec_b.detach()
|
||||
generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
|
||||
generated_img.update({f"edge_{k}": batch[f"edge_{k}"].expand(-1, 3, -1, -1).detach() for k in "ab"})
|
||||
return {
|
||||
"loss": {
|
||||
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
|
||||
@ -153,20 +159,21 @@ def get_trainer(config, logger):
|
||||
loss[tl] = output["loss"][tl]
|
||||
return loss
|
||||
|
||||
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform)
|
||||
iter_per_epoch = len(train_data_loader)
|
||||
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
|
||||
if tensorboard_handler is not None:
|
||||
tensorboard_handler.attach(
|
||||
trainer,
|
||||
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
|
||||
event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
|
||||
event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
|
||||
)
|
||||
|
||||
@trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
|
||||
@trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
|
||||
def show_images(engine):
|
||||
output = engine.state.output
|
||||
image_order = dict(
|
||||
a=["real_a", "fake_a"],
|
||||
b=["real_b", "fake_b"]
|
||||
a=["edge_a", "real_a", "fake_a", "fake_b"],
|
||||
b=["edge_b", "real_b", "rec_b", "rec_bb"]
|
||||
)
|
||||
for k in "ab":
|
||||
tensorboard_handler.writer.add_image(
|
||||
@ -175,6 +182,42 @@ def get_trainer(config, logger):
|
||||
engine.state.iteration
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
g = torch.Generator()
|
||||
g.manual_seed(config.misc.random_seed)
|
||||
random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
|
||||
test_images = dict(
|
||||
a=[[], [], [], []],
|
||||
b=[[], [], [], []]
|
||||
)
|
||||
for i in range(random_start, random_start + 10):
|
||||
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||
for k in batch:
|
||||
batch[k] = batch[k].view(1, *batch[k].size())
|
||||
|
||||
real = dict(a=batch["a"], b=batch["b"])
|
||||
fake = dict(
|
||||
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
|
||||
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
|
||||
)
|
||||
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
|
||||
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
|
||||
|
||||
test_images["a"][0].append(batch["edge_a"])
|
||||
test_images["a"][1].append(batch["a"])
|
||||
test_images["a"][2].append(fake["a"])
|
||||
test_images["a"][3].append(fake["b"])
|
||||
test_images["b"][0].append(batch["edge_b"])
|
||||
test_images["b"][1].append(batch["b"])
|
||||
test_images["b"][2].append(rec_b)
|
||||
test_images["b"][3].append(rec_bb)
|
||||
for n in "ab":
|
||||
tensorboard_handler.writer.add_image(
|
||||
f"test/{n}",
|
||||
make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
|
||||
engine.state.iteration
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
@ -189,7 +232,7 @@ def run(task, config, logger):
|
||||
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
||||
logger.info(f"train with dataset:\n{train_dataset}")
|
||||
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
||||
trainer = get_trainer(config, logger)
|
||||
trainer = get_trainer(config, logger, train_data_loader)
|
||||
if idist.get_rank() == 0:
|
||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||
trainer.state.test_dataset = test_dataset
|
||||
|
||||
@ -87,10 +87,17 @@ class ContentEncoder(nn.Module):
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, out_channels, base_channels=64, num_down_sampling=2, padding_mode='reflect', norm_type="LN"):
|
||||
def __init__(self, out_channels, base_channels=64, num_blocks=4, num_down_sampling=2, padding_mode='reflect',
|
||||
norm_type="LN"):
|
||||
super(Decoder, self).__init__()
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
use_bias = norm_type == "IN"
|
||||
|
||||
res_block_channels = (2 ** 2) * base_channels
|
||||
|
||||
self.resnet = nn.Sequential(
|
||||
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
|
||||
|
||||
# up sampling
|
||||
submodules = []
|
||||
for i in range(num_down_sampling):
|
||||
@ -109,6 +116,7 @@ class Decoder(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.resnet(x)
|
||||
x = self.decoder(x)
|
||||
x = self.end_conv(x)
|
||||
return x
|
||||
@ -159,8 +167,8 @@ class Generator(nn.Module):
|
||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
|
||||
])
|
||||
self.decoders = nn.ModuleDict({
|
||||
"a": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode),
|
||||
"b": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode)
|
||||
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode),
|
||||
"b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode)
|
||||
})
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
|
||||
@ -88,16 +88,17 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
engine.terminate()
|
||||
|
||||
|
||||
def setup_tensorboard_handler(trainer: Engine, config, output_transform):
|
||||
def setup_tensorboard_handler(trainer: Engine, config, output_transform, iter_per_epoch):
|
||||
if config.interval.tensorboard is None:
|
||||
return None
|
||||
if idist.get_rank() == 0:
|
||||
# Create a logger
|
||||
tb_logger = TensorboardLogger(log_dir=config.output_dir)
|
||||
basic_event = Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
|
||||
tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"),
|
||||
event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
|
||||
event_name=basic_event)
|
||||
tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform),
|
||||
event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
|
||||
event_name=basic_event)
|
||||
|
||||
@trainer.on(Events.COMPLETED)
|
||||
@idist.one_rank_only()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user