Browse Source

PR [#32] change_model_func_name -> dev

change model function name decode to lp_decode in response to PR #31
tags/v0.3.1
Frozenmad GitHub 4 years ago
parent
commit
fbd420e8e5
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 8 additions and 8 deletions
  1. +2
    -2
      autogl/module/model/gat.py
  2. +2
    -2
      autogl/module/model/gcn.py
  3. +2
    -2
      autogl/module/model/graphsage.py
  4. +2
    -2
      autogl/module/train/link_prediction.py

+ 2
- 2
autogl/module/model/gat.py View File

@@ -104,12 +104,12 @@ class GAT(torch.nn.Module):
# x = F.dropout(x, p=self.args["dropout"], training=self.training)
return x

def decode(self, z, pos_edge_index, neg_edge_index):
def lp_decode(self, z, pos_edge_index, neg_edge_index):
edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
return logits

def decode_all(self, z):
def lp_decode_all(self, z):
prob_adj = z @ z.t()
return (prob_adj > 0).nonzero(as_tuple=False).t()



+ 2
- 2
autogl/module/model/gcn.py View File

@@ -93,12 +93,12 @@ class GCN(torch.nn.Module):
# x = F.dropout(x, p=self.args["dropout"], training=self.training)
return x

def decode(self, z, pos_edge_index, neg_edge_index):
def lp_decode(self, z, pos_edge_index, neg_edge_index):
edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
return logits

def decode_all(self, z):
def lp_decode_all(self, z):
prob_adj = z @ z.t()
return (prob_adj > 0).nonzero(as_tuple=False).t()



+ 2
- 2
autogl/module/model/graphsage.py View File

@@ -180,12 +180,12 @@ class GraphSAGE(torch.nn.Module):
# x = F.dropout(x, p=self.args["dropout"], training=self.training)
return x

def decode(self, z, pos_edge_index, neg_edge_index):
def lp_decode(self, z, pos_edge_index, neg_edge_index):
edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
return logits

def decode_all(self, z):
def lp_decode_all(self, z):
prob_adj = z @ z.t()
return (prob_adj > 0).nonzero(as_tuple=False).t()



+ 2
- 2
autogl/module/train/link_prediction.py View File

@@ -198,7 +198,7 @@ class LinkPredictionTrainer(BaseLinkPredictionTrainer):
optimizer.zero_grad()
# res = self.model.model.forward(data)
z = self.model.model.encode(data)
link_logits = self.model.model.decode(
link_logits = self.model.model.lp_decode(
z, data.train_pos_edge_index, neg_edge_index
)
link_labels = self.get_link_labels(
@@ -320,7 +320,7 @@ class LinkPredictionTrainer(BaseLinkPredictionTrainer):
self.model.model.eval()
with torch.no_grad():
z = self.predict_only(data)
link_logits = self.model.model.decode(z, pos_edge_index, neg_edge_index)
link_logits = self.model.model.lp_decode(z, pos_edge_index, neg_edge_index)
link_probs = link_logits.sigmoid()

return link_probs


Loading…
Cancel
Save