From bcb258cbd6db2b1092bae9fb7c919ca15c8e42a2 Mon Sep 17 00:00:00 2001 From: lihy96 Date: Wed, 25 Aug 2021 15:35:23 +0800 Subject: [PATCH] add mask to predict_only --- autogl/module/train/link_prediction.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/autogl/module/train/link_prediction.py b/autogl/module/train/link_prediction.py index 4020481..eb427a6 100644 --- a/autogl/module/train/link_prediction.py +++ b/autogl/module/train/link_prediction.py @@ -243,11 +243,19 @@ class LinkPredictionTrainer(BaseLinkPredictionTrainer): res: The result of predicting on the given dataset. """ + try: + mask = data.test_mask if test_mask is None else test_mask + except: + mask = None data = data.to(self.device) self.model.model.eval() with torch.no_grad(): - z = self.model.model.lp_encode(data) - return z + res = self.model.model.lp_encode(data) + + if mask is None: + return res + else: + return res[mask] def train(self, dataset, keep_valid_result=True): """