diff --git a/configs/synthesizers/UGATIT.yml b/configs/synthesizers/UGATIT.yml index 81c68b5..4fe5ebf 100644 --- a/configs/synthesizers/UGATIT.yml +++ b/configs/synthesizers/UGATIT.yml @@ -92,6 +92,7 @@ data: mean: [ 0.5, 0.5, 0.5 ] std: [ 0.5, 0.5, 0.5 ] test: + which: video_dataset dataloader: batch_size: 8 shuffle: False diff --git a/engine/U-GAT-IT.py b/engine/U-GAT-IT.py index 9b7178c..ed9eebc 100644 --- a/engine/U-GAT-IT.py +++ b/engine/U-GAT-IT.py @@ -135,8 +135,8 @@ class UGATITTestEngineKernel(TestEngineKernel): def inference(self, batch): with torch.no_grad(): - fake, _, _ = self.generators["a2b"](batch["a"]) - return {"a": fake.detach()} + fake, _, _ = self.generators["a2b"](batch[0]) + return fake.detach() def run(task, config, _):