Browse Source

Update graph_classification_full.py

tags/v0.3.1
lihy96 Frozenmad 4 years ago
parent
commit
6d7c3d8ca2
1 changed files with 0 additions and 5 deletions
  1. +0
    -5
      autogl/module/train/graph_classification_full.py

+ 0
- 5
autogl/module/train/graph_classification_full.py View File

@@ -510,7 +510,6 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer):


y_pred_prob = self._predict_proba(loader=loader) y_pred_prob = self._predict_proba(loader=loader)
y_pred = y_pred_prob.max(1)[1] y_pred = y_pred_prob.max(1)[1]
# print(y_pred_prob, y_pred)


y_true_tmp = [] y_true_tmp = []
for data in loader: for data in loader:
@@ -521,10 +520,6 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer):
y_true_tmp.append(labels) y_true_tmp.append(labels)
y_true = torch.cat(y_true_tmp, 0) y_true = torch.cat(y_true_tmp, 0)



print(y_pred, y_true)


if not isinstance(feval, list): if not isinstance(feval, list):
feval = [feval] feval = [feval]
return_signle = True return_signle = True


Loading…
Cancel
Save