diff --git a/autogl/module/nas/estimator/one_shot.py b/autogl/module/nas/estimator/one_shot.py index 206f539..21d296c 100644 --- a/autogl/module/nas/estimator/one_shot.py +++ b/autogl/module/nas/estimator/one_shot.py @@ -20,7 +20,7 @@ class OneShotEstimator(BaseEstimator): y = dset.y[getattr(dset, f'{mask}_mask')] loss = self.loss_f(pred, y) #acc=sum(pred.max(1)[1]==y).item()/y.size(0) - probs = F.softmax(pred, dim = 1).cpu().numpy() + probs = F.softmax(pred, dim = 1).detach().cpu().numpy() y = y.cpu() metrics = [eva.evaluate(probs, y) for eva in self.evaluation] return metrics, loss