|
|
|
@@ -314,6 +314,7 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): |
|
|
|
if self.pyg_dgl == 'pyg': |
|
|
|
data = data.to(self.device) |
|
|
|
pred.append(self.model.model(data)) |
|
|
|
label.append(data.y) |
|
|
|
elif self.pyg_dgl == 'dgl': |
|
|
|
data = [data[i].to(self.device) for i in range(len(data))] |
|
|
|
_, labels = data |
|
|
|
|