Browse Source

add mask to predict_only

tags/v0.3.1
lihy96 Frozenmad 4 years ago
parent
commit
bcb258cbd6
1 changed files with 10 additions and 2 deletions
  1. +10
    -2
      autogl/module/train/link_prediction.py

+ 10
- 2
autogl/module/train/link_prediction.py View File

@@ -243,11 +243,19 @@ class LinkPredictionTrainer(BaseLinkPredictionTrainer):
res: The result of predicting on the given dataset.

"""
try:
mask = data.test_mask if test_mask is None else test_mask
except:
mask = None
data = data.to(self.device)
self.model.model.eval()
with torch.no_grad():
z = self.model.model.lp_encode(data)
return z
res = self.model.model.lp_encode(data)

if mask is None:
return res
else:
return res[mask]

def train(self, dataset, keep_valid_result=True):
"""


Loading…
Cancel
Save