From 1159328106bbe499ffbbbb55c34f50f7230b3b35 Mon Sep 17 00:00:00 2001 From: lihy96 Date: Tue, 22 Jun 2021 16:42:43 +0800 Subject: [PATCH] change model function name decode to lp_decode --- autogl/module/model/gat.py | 4 ++-- autogl/module/model/gcn.py | 4 ++-- autogl/module/model/graphsage.py | 4 ++-- autogl/module/train/link_prediction.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/autogl/module/model/gat.py b/autogl/module/model/gat.py index f0963fc..c0b9b4b 100644 --- a/autogl/module/model/gat.py +++ b/autogl/module/model/gat.py @@ -104,12 +104,12 @@ class GAT(torch.nn.Module): # x = F.dropout(x, p=self.args["dropout"], training=self.training) return x - def decode(self, z, pos_edge_index, neg_edge_index): + def lp_decode(self, z, pos_edge_index, neg_edge_index): edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) return logits - def decode_all(self, z): + def lp_decode_all(self, z): prob_adj = z @ z.t() return (prob_adj > 0).nonzero(as_tuple=False).t() diff --git a/autogl/module/model/gcn.py b/autogl/module/model/gcn.py index e28bd80..02507e7 100644 --- a/autogl/module/model/gcn.py +++ b/autogl/module/model/gcn.py @@ -93,12 +93,12 @@ class GCN(torch.nn.Module): # x = F.dropout(x, p=self.args["dropout"], training=self.training) return x - def decode(self, z, pos_edge_index, neg_edge_index): + def lp_decode(self, z, pos_edge_index, neg_edge_index): edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) return logits - def decode_all(self, z): + def lp_decode_all(self, z): prob_adj = z @ z.t() return (prob_adj > 0).nonzero(as_tuple=False).t() diff --git a/autogl/module/model/graphsage.py b/autogl/module/model/graphsage.py index 5b09817..56463e1 100644 --- a/autogl/module/model/graphsage.py +++ b/autogl/module/model/graphsage.py @@ -180,12 +180,12 @@ class GraphSAGE(torch.nn.Module): # x = F.dropout(x, p=self.args["dropout"], training=self.training) return x - def decode(self, z, pos_edge_index, neg_edge_index): + def lp_decode(self, z, pos_edge_index, neg_edge_index): edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) return logits - def decode_all(self, z): + def lp_decode_all(self, z): prob_adj = z @ z.t() return (prob_adj > 0).nonzero(as_tuple=False).t() diff --git a/autogl/module/train/link_prediction.py b/autogl/module/train/link_prediction.py index 56cf3fe..056060f 100644 --- a/autogl/module/train/link_prediction.py +++ b/autogl/module/train/link_prediction.py @@ -198,7 +198,7 @@ class LinkPredictionTrainer(BaseLinkPredictionTrainer): optimizer.zero_grad() # res = self.model.model.forward(data) z = self.model.model.encode(data) - link_logits = self.model.model.decode( + link_logits = self.model.model.lp_decode( z, data.train_pos_edge_index, neg_edge_index ) link_labels = self.get_link_labels( @@ -320,7 +320,7 @@ class LinkPredictionTrainer(BaseLinkPredictionTrainer): self.model.model.eval() with torch.no_grad(): z = self.predict_only(data) - link_logits = self.model.model.decode(z, pos_edge_index, neg_edge_index) + link_logits = self.model.model.lp_decode(z, pos_edge_index, neg_edge_index) link_probs = link_logits.sigmoid() return link_probs