From eb6afdfc17c50fc8687c8fa62da77fc78b201771 Mon Sep 17 00:00:00 2001 From: Frozenmad Date: Sat, 26 Jun 2021 08:24:04 +0000 Subject: [PATCH] fix device error --- autogl/module/nas/estimator/one_shot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autogl/module/nas/estimator/one_shot.py b/autogl/module/nas/estimator/one_shot.py index 206f539..21d296c 100644 --- a/autogl/module/nas/estimator/one_shot.py +++ b/autogl/module/nas/estimator/one_shot.py @@ -20,7 +20,7 @@ class OneShotEstimator(BaseEstimator): y = dset.y[getattr(dset, f'{mask}_mask')] loss = self.loss_f(pred, y) #acc=sum(pred.max(1)[1]==y).item()/y.size(0) - probs = F.softmax(pred, dim = 1).cpu().numpy() + probs = F.softmax(pred, dim = 1).detach().cpu().numpy() y = y.cpu() metrics = [eva.evaluate(probs, y) for eva in self.evaluation] return metrics, loss