fix small bug in U-GAT-IT

This commit is contained in:
budui 2020-09-11 22:34:43 +08:00
parent 72d09aa483
commit 85b5c3f589
2 changed files with 3 additions and 2 deletions

View File

@ -92,6 +92,7 @@ data:
mean: [ 0.5, 0.5, 0.5 ] mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ] std: [ 0.5, 0.5, 0.5 ]
test: test:
which: video_dataset
dataloader: dataloader:
batch_size: 8 batch_size: 8
shuffle: False shuffle: False

View File

@ -135,8 +135,8 @@ class UGATITTestEngineKernel(TestEngineKernel):
def inference(self, batch): def inference(self, batch):
with torch.no_grad(): with torch.no_grad():
fake, _, _ = self.generators["a2b"](batch["a"]) fake, _, _ = self.generators["a2b"](batch[0])
return {"a": fake.detach()} return fake.detach()
def run(task, config, _): def run(task, config, _):