| @@ -510,7 +510,6 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): | |||||
| y_pred_prob = self._predict_proba(loader=loader) | y_pred_prob = self._predict_proba(loader=loader) | ||||
| y_pred = y_pred_prob.max(1)[1] | y_pred = y_pred_prob.max(1)[1] | ||||
| # print(y_pred_prob, y_pred) | |||||
| y_true_tmp = [] | y_true_tmp = [] | ||||
| for data in loader: | for data in loader: | ||||
| @@ -521,10 +520,6 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): | |||||
| y_true_tmp.append(labels) | y_true_tmp.append(labels) | ||||
| y_true = torch.cat(y_true_tmp, 0) | y_true = torch.cat(y_true_tmp, 0) | ||||
| print(y_pred, y_true) | |||||
| if not isinstance(feval, list): | if not isinstance(feval, list): | ||||
| feval = [feval] | feval = [feval] | ||||
| return_signle = True | return_signle = True | ||||