change model function name decode to lp_decode in response to PR #31tags/v0.3.1
| @@ -104,12 +104,12 @@ class GAT(torch.nn.Module): | |||||
| # x = F.dropout(x, p=self.args["dropout"], training=self.training) | # x = F.dropout(x, p=self.args["dropout"], training=self.training) | ||||
| return x | return x | ||||
| def decode(self, z, pos_edge_index, neg_edge_index): | |||||
| def lp_decode(self, z, pos_edge_index, neg_edge_index): | |||||
| edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) | edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) | ||||
| logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) | logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) | ||||
| return logits | return logits | ||||
| def decode_all(self, z): | |||||
| def lp_decode_all(self, z): | |||||
| prob_adj = z @ z.t() | prob_adj = z @ z.t() | ||||
| return (prob_adj > 0).nonzero(as_tuple=False).t() | return (prob_adj > 0).nonzero(as_tuple=False).t() | ||||
| @@ -93,12 +93,12 @@ class GCN(torch.nn.Module): | |||||
| # x = F.dropout(x, p=self.args["dropout"], training=self.training) | # x = F.dropout(x, p=self.args["dropout"], training=self.training) | ||||
| return x | return x | ||||
| def decode(self, z, pos_edge_index, neg_edge_index): | |||||
| def lp_decode(self, z, pos_edge_index, neg_edge_index): | |||||
| edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) | edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) | ||||
| logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) | logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) | ||||
| return logits | return logits | ||||
| def decode_all(self, z): | |||||
| def lp_decode_all(self, z): | |||||
| prob_adj = z @ z.t() | prob_adj = z @ z.t() | ||||
| return (prob_adj > 0).nonzero(as_tuple=False).t() | return (prob_adj > 0).nonzero(as_tuple=False).t() | ||||
| @@ -180,12 +180,12 @@ class GraphSAGE(torch.nn.Module): | |||||
| # x = F.dropout(x, p=self.args["dropout"], training=self.training) | # x = F.dropout(x, p=self.args["dropout"], training=self.training) | ||||
| return x | return x | ||||
| def decode(self, z, pos_edge_index, neg_edge_index): | |||||
| def lp_decode(self, z, pos_edge_index, neg_edge_index): | |||||
| edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) | edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) | ||||
| logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) | logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) | ||||
| return logits | return logits | ||||
| def decode_all(self, z): | |||||
| def lp_decode_all(self, z): | |||||
| prob_adj = z @ z.t() | prob_adj = z @ z.t() | ||||
| return (prob_adj > 0).nonzero(as_tuple=False).t() | return (prob_adj > 0).nonzero(as_tuple=False).t() | ||||
| @@ -198,7 +198,7 @@ class LinkPredictionTrainer(BaseLinkPredictionTrainer): | |||||
| optimizer.zero_grad() | optimizer.zero_grad() | ||||
| # res = self.model.model.forward(data) | # res = self.model.model.forward(data) | ||||
| z = self.model.model.encode(data) | z = self.model.model.encode(data) | ||||
| link_logits = self.model.model.decode( | |||||
| link_logits = self.model.model.lp_decode( | |||||
| z, data.train_pos_edge_index, neg_edge_index | z, data.train_pos_edge_index, neg_edge_index | ||||
| ) | ) | ||||
| link_labels = self.get_link_labels( | link_labels = self.get_link_labels( | ||||
| @@ -320,7 +320,7 @@ class LinkPredictionTrainer(BaseLinkPredictionTrainer): | |||||
| self.model.model.eval() | self.model.model.eval() | ||||
| with torch.no_grad(): | with torch.no_grad(): | ||||
| z = self.predict_only(data) | z = self.predict_only(data) | ||||
| link_logits = self.model.model.decode(z, pos_edge_index, neg_edge_index) | |||||
| link_logits = self.model.model.lp_decode(z, pos_edge_index, neg_edge_index) | |||||
| link_probs = link_logits.sigmoid() | link_probs = link_logits.sigmoid() | ||||
| return link_probs | return link_probs | ||||