|
|
|
@@ -243,11 +243,19 @@ class LinkPredictionTrainer(BaseLinkPredictionTrainer): |
|
|
|
res: The result of predicting on the given dataset. |
|
|
|
|
|
|
|
""" |
|
|
|
try: |
|
|
|
mask = data.test_mask if test_mask is None else test_mask |
|
|
|
except: |
|
|
|
mask = None |
|
|
|
data = data.to(self.device) |
|
|
|
self.model.model.eval() |
|
|
|
with torch.no_grad(): |
|
|
|
z = self.model.model.lp_encode(data) |
|
|
|
return z |
|
|
|
res = self.model.model.lp_encode(data) |
|
|
|
|
|
|
|
if mask is None: |
|
|
|
return res |
|
|
|
else: |
|
|
|
return res[mask] |
|
|
|
|
|
|
|
def train(self, dataset, keep_valid_result=True): |
|
|
|
""" |
|
|
|
|