|
|
|
@@ -510,7 +510,6 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): |
|
|
|
|
|
|
|
y_pred_prob = self._predict_proba(loader=loader) |
|
|
|
y_pred = y_pred_prob.max(1)[1] |
|
|
|
# print(y_pred_prob, y_pred) |
|
|
|
|
|
|
|
y_true_tmp = [] |
|
|
|
for data in loader: |
|
|
|
@@ -521,10 +520,6 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): |
|
|
|
y_true_tmp.append(labels) |
|
|
|
y_true = torch.cat(y_true_tmp, 0) |
|
|
|
|
|
|
|
|
|
|
|
print(y_pred, y_true) |
|
|
|
|
|
|
|
|
|
|
|
if not isinstance(feval, list): |
|
|
|
feval = [feval] |
|
|
|
return_signle = True |
|
|
|
|