diff --git a/autogl/module/train/graph_classification_full.py b/autogl/module/train/graph_classification_full.py index 6ebcb8b..fa045a0 100644 --- a/autogl/module/train/graph_classification_full.py +++ b/autogl/module/train/graph_classification_full.py @@ -314,6 +314,7 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): if self.pyg_dgl == 'pyg': data = data.to(self.device) pred.append(self.model.model(data)) + label.append(data.y) elif self.pyg_dgl == 'dgl': data = [data[i].to(self.device) for i in range(len(data))] _, labels = data