From 85b5c3f589b6665529e9b9c475846a4ff2a3468d Mon Sep 17 00:00:00 2001 From: budui Date: Fri, 11 Sep 2020 22:34:43 +0800 Subject: [PATCH] fix small bug in U-GAT-IT --- configs/synthesizers/UGATIT.yml | 1 + engine/U-GAT-IT.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/configs/synthesizers/UGATIT.yml b/configs/synthesizers/UGATIT.yml index 81c68b5..4fe5ebf 100644 --- a/configs/synthesizers/UGATIT.yml +++ b/configs/synthesizers/UGATIT.yml @@ -92,6 +92,7 @@ data: mean: [ 0.5, 0.5, 0.5 ] std: [ 0.5, 0.5, 0.5 ] test: + which: video_dataset dataloader: batch_size: 8 shuffle: False diff --git a/engine/U-GAT-IT.py b/engine/U-GAT-IT.py index 9b7178c..ed9eebc 100644 --- a/engine/U-GAT-IT.py +++ b/engine/U-GAT-IT.py @@ -135,8 +135,8 @@ class UGATITTestEngineKernel(TestEngineKernel): def inference(self, batch): with torch.no_grad(): - fake, _, _ = self.generators["a2b"](batch["a"]) - return {"a": fake.detach()} + fake, _, _ = self.generators["a2b"](batch[0]) + return fake.detach() def run(task, config, _):