From e71e8d95d09b85fc70eb7dd049ed9bd428119c3a Mon Sep 17 00:00:00 2001 From: budui Date: Tue, 1 Sep 2020 09:02:04 +0800 Subject: [PATCH] TAHG 0.0.3 --- .idea/deployment.xml | 9 ++++- .idea/misc.xml | 2 +- .idea/raycv.iml | 2 +- configs/synthesizers/TAHG.yml | 25 +++++++------- data/dataset.py | 9 ++--- engine/TAHG.py | 65 +++++++++++++++++++++++++++++------ model/GAN/TAHG.py | 14 ++++++-- util/handler.py | 7 ++-- 8 files changed, 97 insertions(+), 36 deletions(-) diff --git a/.idea/deployment.xml b/.idea/deployment.xml index cebeb10..f335efb 100644 --- a/.idea/deployment.xml +++ b/.idea/deployment.xml @@ -1,6 +1,6 @@ - + @@ -16,6 +16,13 @@ + + + + + + + diff --git a/.idea/misc.xml b/.idea/misc.xml index 1b9173d..1eef74e 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -1,4 +1,4 @@ - + \ No newline at end of file diff --git a/.idea/raycv.iml b/.idea/raycv.iml index 9781a97..a25e5bf 100644 --- a/.idea/raycv.iml +++ b/.idea/raycv.iml @@ -2,7 +2,7 @@ - + diff --git a/configs/synthesizers/TAHG.yml b/configs/synthesizers/TAHG.yml index 1962084..797bbf2 100644 --- a/configs/synthesizers/TAHG.yml +++ b/configs/synthesizers/TAHG.yml @@ -17,14 +17,15 @@ checkpoint: interval: print_per_iteration: 10 # print once per 10 iteration tensorboard: - scalar: 10 - image: 500 + scalar: 100 + image: 2 model: generator: _type: TAHG-Generator style_in_channels: 3 content_in_channels: 1 + num_blocks: 4 discriminator: _type: TAHG-Discriminator in_channels: 3 @@ -37,22 +38,22 @@ loss: weight: 1.0 edge: criterion: 'L1' - hed_pretrained_model_path: "/root/network-bsds500.pytorch" - weight: 2 + hed_pretrained_model_path: "./network-bsds500.pytorch" + weight: 1 perceptual: layer_weights: -# "3": 1.0 - "0": 1.0 - "5": 1.0 - "10": 1.0 - "19": 1.0 + "3": 1.0 +# "0": 1.0 +# "5": 1.0 +# "10": 1.0 +# "19": 1.0 criterion: 'L2' style_loss: True perceptual_loss: False - weight: 100.0 + weight: 20 recon: level: 1 - weight: 2 + weight: 1 optimizers: generator: @@ -73,7 +74,7 @@ data: target_lr: 0 buffer_size: 50 dataloader: - batch_size: 48 + batch_size: 160 shuffle: True num_workers: 2 pin_memory: True diff --git a/data/dataset.py b/data/dataset.py index 69932a8..fa7989b 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -184,22 +184,23 @@ class GenerationUnpairedDatasetWithEdge(Dataset): self.edges_path = Path(edges_path) assert self.edges_path.exists() self.A = SingleFolderDataset(root_a, pipeline, with_path=True) - self.B = SingleFolderDataset(root_b, pipeline, with_path=False) + self.B = SingleFolderDataset(root_b, pipeline, with_path=True) self.random_pair = random_pair def get_edge(self, origin_path): op = Path(origin_path) edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{self.edge_type}.png" img = Image.open(edge_path).resize(self.size) - return {"edge": F.to_tensor(img)} + return F.to_tensor(img) def __getitem__(self, idx): a_idx = idx % len(self.A) b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item() output = dict() output["a"], path_a = self.A[a_idx] - output.update(self.get_edge(path_a)) - output["b"] = self.B[b_idx] + output["b"], path_b = self.B[b_idx] + output["edge_a"] = self.get_edge(path_a) + output["edge_b"] = self.get_edge(path_b) return output def __len__(self): diff --git a/engine/TAHG.py b/engine/TAHG.py index c25cfcb..71b705d 100644 --- a/engine/TAHG.py +++ b/engine/TAHG.py @@ -44,7 +44,7 @@ def build_lr_schedulers(optimizers, config): ) -def get_trainer(config, logger): +def get_trainer(config, logger, train_data_loader): generator = build_model(config.model.generator, config.distributed.model) discriminators = dict( a=build_model(config.model.discriminator, config.distributed.model), @@ -85,11 +85,12 @@ def get_trainer(config, logger): def _step(engine, batch): batch = convert_tensor(batch, idist.device()) real = dict(a=batch["a"], b=batch["b"]) - content_img = batch["edge"] fake = dict( - a=generator(content_img=content_img, style_img=real["a"], which_decoder="a"), - b=generator(content_img=content_img, style_img=real["b"], which_decoder="b"), + a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"), + b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"), ) + rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b") + rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b") optimizers["g"].zero_grad() loss_g = dict() @@ -99,8 +100,10 @@ def get_trainer(config, logger): loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True) _, t = perceptual_loss(fake[d], real[d]) loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t - loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], content_img) + loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], batch["edge_a"]) loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"]) + loss_g["recon_b"] = config.loss.recon.weight * recon_loss(rec_b, real["b"]) + loss_g["recon_bb"] = config.loss.recon.weight * recon_loss(rec_bb, real["b"]) sum(loss_g.values()).backward() optimizers["g"].step() @@ -118,7 +121,10 @@ def get_trainer(config, logger): optimizers["d"].step() generated_img = {f"real_{k}": real[k].detach() for k in real} + generated_img["rec_b"] = rec_b.detach() + generated_img["rec_bb"] = rec_b.detach() generated_img.update({f"fake_{k}": fake[k].detach() for k in fake}) + generated_img.update({f"edge_{k}": batch[f"edge_{k}"].expand(-1, 3, -1, -1).detach() for k in "ab"}) return { "loss": { "g": {ln: loss_g[ln].mean().item() for ln in loss_g}, @@ -153,20 +159,21 @@ def get_trainer(config, logger): loss[tl] = output["loss"][tl] return loss - tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform) + iter_per_epoch = len(train_data_loader) + tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch) if tensorboard_handler is not None: tensorboard_handler.attach( trainer, log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"), - event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar) + event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1)) ) - @trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image)) + @trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1))) def show_images(engine): output = engine.state.output image_order = dict( - a=["real_a", "fake_a"], - b=["real_b", "fake_b"] + a=["edge_a", "real_a", "fake_a", "fake_b"], + b=["edge_b", "real_b", "rec_b", "rec_bb"] ) for k in "ab": tensorboard_handler.writer.add_image( @@ -175,6 +182,42 @@ def get_trainer(config, logger): engine.state.iteration ) + with torch.no_grad(): + g = torch.Generator() + g.manual_seed(config.misc.random_seed) + random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0] + test_images = dict( + a=[[], [], [], []], + b=[[], [], [], []] + ) + for i in range(random_start, random_start + 10): + batch = convert_tensor(engine.state.test_dataset[i], idist.device()) + for k in batch: + batch[k] = batch[k].view(1, *batch[k].size()) + + real = dict(a=batch["a"], b=batch["b"]) + fake = dict( + a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"), + b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"), + ) + rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b") + rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b") + + test_images["a"][0].append(batch["edge_a"]) + test_images["a"][1].append(batch["a"]) + test_images["a"][2].append(fake["a"]) + test_images["a"][3].append(fake["b"]) + test_images["b"][0].append(batch["edge_b"]) + test_images["b"][1].append(batch["b"]) + test_images["b"][2].append(rec_b) + test_images["b"][3].append(rec_bb) + for n in "ab": + tensorboard_handler.writer.add_image( + f"test/{n}", + make_2d_grid([torch.cat(ti) for ti in test_images[n]]), + engine.state.iteration + ) + return trainer @@ -189,7 +232,7 @@ def run(task, config, logger): train_dataset = data.DATASET.build_with(config.data.train.dataset) logger.info(f"train with dataset:\n{train_dataset}") train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader) - trainer = get_trainer(config, logger) + trainer = get_trainer(config, logger, train_data_loader) if idist.get_rank() == 0: test_dataset = data.DATASET.build_with(config.data.test.dataset) trainer.state.test_dataset = test_dataset diff --git a/model/GAN/TAHG.py b/model/GAN/TAHG.py index 4c38d18..afb619a 100644 --- a/model/GAN/TAHG.py +++ b/model/GAN/TAHG.py @@ -87,10 +87,17 @@ class ContentEncoder(nn.Module): class Decoder(nn.Module): - def __init__(self, out_channels, base_channels=64, num_down_sampling=2, padding_mode='reflect', norm_type="LN"): + def __init__(self, out_channels, base_channels=64, num_blocks=4, num_down_sampling=2, padding_mode='reflect', + norm_type="LN"): super(Decoder, self).__init__() norm_layer = select_norm_layer(norm_type) use_bias = norm_type == "IN" + + res_block_channels = (2 ** 2) * base_channels + + self.resnet = nn.Sequential( + *[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)]) + # up sampling submodules = [] for i in range(num_down_sampling): @@ -109,6 +116,7 @@ class Decoder(nn.Module): ) def forward(self, x): + x = self.resnet(x) x = self.decoder(x) x = self.end_conv(x) return x @@ -159,8 +167,8 @@ class Generator(nn.Module): ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks) ]) self.decoders = nn.ModuleDict({ - "a": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode), - "b": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode) + "a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode), + "b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode) }) self.fc = nn.Sequential( diff --git a/util/handler.py b/util/handler.py index 510390f..d7cf1c2 100644 --- a/util/handler.py +++ b/util/handler.py @@ -88,16 +88,17 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_ engine.terminate() -def setup_tensorboard_handler(trainer: Engine, config, output_transform): +def setup_tensorboard_handler(trainer: Engine, config, output_transform, iter_per_epoch): if config.interval.tensorboard is None: return None if idist.get_rank() == 0: # Create a logger tb_logger = TensorboardLogger(log_dir=config.output_dir) + basic_event = Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1)) tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"), - event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar)) + event_name=basic_event) tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform), - event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar)) + event_name=basic_event) @trainer.on(Events.COMPLETED) @idist.one_rank_only()