Browse Source

fix warning

develop/0.4/predevelop
defineZYP 3 years ago
parent
commit
e936f0fd39
2 changed files with 3 additions and 3 deletions
  1. +1
    -1
      autogl/datasets/_ogb.py
  2. +2
    -2
      examples/nodeclf_ogb.py

+ 1
- 1
autogl/datasets/_ogb.py View File

@@ -39,7 +39,7 @@ class _OGBNDatasetUtil(_OGBDatasetUtil):
edge_index = SparseTensor(row=torch.tensor(edge_index[0]), col=torch.tensor(edge_index[1]), value=edge_feat, sparse_sizes=(num_nodes, num_nodes))
edge_index = edge_index.to_symmetric()
row, col, _ = edge_index.coo()
edge_index = [row.cpu().detach().numpy(), col.cpu().detach().numpy()]
edge_index = np.array([row.cpu().detach().numpy(), col.cpu().detach().numpy()])
homogeneous_static_graph: GeneralStaticGraph = (
GeneralStaticGraphGenerator.create_homogeneous_static_graph(
dict([


+ 2
- 2
examples/nodeclf_ogb.py View File

@@ -174,8 +174,8 @@ def main():
num_classes = len(np.unique(labels.numpy()))

if args.use_sage:
model = SAGE(data.num_features, args.hidden_channels,
dataset.num_classes, args.num_layers,
model = SAGE(dataset[0].nodes.data[feat].size(1), args.hidden_channels,
num_classes, args.num_layers,
args.dropout).to(device)
else:
model = GCN(dataset[0].nodes.data[feat].size(1), args.hidden_channels,


Loading…
Cancel
Save