diff --git a/autogl/module/train/link_prediction.py b/autogl/module/train/link_prediction.py index 4020481..eb427a6 100644 --- a/autogl/module/train/link_prediction.py +++ b/autogl/module/train/link_prediction.py @@ -243,11 +243,19 @@ class LinkPredictionTrainer(BaseLinkPredictionTrainer): res: The result of predicting on the given dataset. """ + try: + mask = data.test_mask if test_mask is None else test_mask + except: + mask = None data = data.to(self.device) self.model.model.eval() with torch.no_grad(): - z = self.model.model.lp_encode(data) - return z + res = self.model.model.lp_encode(data) + + if mask is None: + return res + else: + return res[mask] def train(self, dataset, keep_valid_result=True): """