Browse Source

fix bugs of pyg backend

tags/v0.3.1
Frozenmad 4 years ago
parent
commit
bebccfa19e
1 changed files with 1 additions and 0 deletions
  1. +1
    -0
      autogl/module/train/graph_classification_full.py

+ 1
- 0
autogl/module/train/graph_classification_full.py View File

@@ -314,6 +314,7 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer):
if self.pyg_dgl == 'pyg':
data = data.to(self.device)
pred.append(self.model.model(data))
label.append(data.y)
elif self.pyg_dgl == 'dgl':
data = [data[i].to(self.device) for i in range(len(data))]
_, labels = data


Loading…
Cancel
Save