|
|
|
@@ -198,7 +198,7 @@ class LinkPredictionTrainer(BaseLinkPredictionTrainer): |
|
|
|
optimizer.zero_grad() |
|
|
|
# res = self.model.model.forward(data) |
|
|
|
z = self.model.model.encode(data) |
|
|
|
link_logits = self.model.model.decode( |
|
|
|
link_logits = self.model.model.lp_decode( |
|
|
|
z, data.train_pos_edge_index, neg_edge_index |
|
|
|
) |
|
|
|
link_labels = self.get_link_labels( |
|
|
|
@@ -320,7 +320,7 @@ class LinkPredictionTrainer(BaseLinkPredictionTrainer): |
|
|
|
self.model.model.eval() |
|
|
|
with torch.no_grad(): |
|
|
|
z = self.predict_only(data) |
|
|
|
link_logits = self.model.model.decode(z, pos_edge_index, neg_edge_index) |
|
|
|
link_logits = self.model.model.lp_decode(z, pos_edge_index, neg_edge_index) |
|
|
|
link_probs = link_logits.sigmoid() |
|
|
|
|
|
|
|
return link_probs |
|
|
|
|