diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py
index 7fb1831eaf..bd1cd16319 100755
--- a/mindspore/nn/optim/adam.py
+++ b/mindspore/nn/optim/adam.py
@@ -350,8 +350,8 @@ class Adam(Optimizer):
         if not isinstance(value, str):
             raise TypeError("The value must be str type, but got value type is {}".format(type(value)))
 
-        if value not in ('CPU', 'Ascend'):
-            raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))
+        if value not in ('CPU', 'Ascend', 'GPU'):
+            raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value))
 
         self._is_device = (value != 'CPU')
         self._target = value
diff --git a/mindspore/nn/optim/ftrl.py b/mindspore/nn/optim/ftrl.py
index a95c08de16..013dcea198 100644
--- a/mindspore/nn/optim/ftrl.py
+++ b/mindspore/nn/optim/ftrl.py
@@ -192,8 +192,8 @@ class FTRL(Optimizer):
         if not isinstance(value, str):
             raise TypeError("The value must be str type, but got value type is {}".format(type(value)))
 
-        if value not in ('CPU', 'Ascend'):
-            raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))
+        if value not in ('CPU', 'Ascend', 'GPU'):
+            raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value))
 
         if value == 'CPU':
             self.sparse_opt = P.FusedSparseFtrl(self.lr, self.l1, self.l2, self.lr_power, self.use_locking)
diff --git a/mindspore/nn/optim/lazyadam.py b/mindspore/nn/optim/lazyadam.py
index d893d9a2c4..92dc0ecd53 100644
--- a/mindspore/nn/optim/lazyadam.py
+++ b/mindspore/nn/optim/lazyadam.py
@@ -257,8 +257,8 @@ class LazyAdam(Optimizer):
         if not isinstance(value, str):
             raise TypeError("The value must be str type, but got value type is {}".format(type(value)))
 
-        if value not in ('CPU', 'Ascend'):
-            raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))
+        if value not in ('CPU', 'Ascend', 'GPU'):
+            raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value))
 
         self._is_device = (value != 'CPU')
         self._target = value
diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index c57f647d1b..3835c30bd3 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -105,7 +105,7 @@ class Optimizer(Cell):
         weight_decay = self._preprocess_weight_decay(weight_decay)
 
         self._unique = True
-        self._target = 'Ascend'
+        self._target = context.get_context("device_target")
         self.dynamic_lr = False
         self.assignadd = None
         self.global_step = None
@@ -194,8 +194,7 @@ class Optimizer(Cell):
    @property
    def target(self):
        """The method is used to determine whether the parameter is updated on host or device. The input type is str
-          and can only be 'CPU' and 'Ascend'. In GPU environment, users can only configure value as 'CPU'.
-          The method is read-only."""
+          and can only be 'CPU', 'Ascend' or 'GPU'."""
        return self._target
 
    @target.setter
diff --git a/mindspore/nn/optim/proximal_ada_grad.py b/mindspore/nn/optim/proximal_ada_grad.py
index 9031ee58eb..923c6feb73 100644
--- a/mindspore/nn/optim/proximal_ada_grad.py
+++ b/mindspore/nn/optim/proximal_ada_grad.py
@@ -161,8 +161,8 @@ class ProximalAdagrad(Optimizer):
         if not isinstance(value, str):
             raise TypeError("The value must be str type, but got value type is {}".format(type(value)))
 
-        if value not in ('CPU', 'Ascend'):
-            raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))
+        if value not in ('CPU', 'Ascend', 'GPU'):
+            raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value))
 
         if value == 'CPU':
             self.sparse_opt = P.FusedSparseProximalAdagrad(self.use_locking).add_prim_attr("primitive_target", "CPU")
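
For reviewers, a minimal usage sketch of what this change enables, not part of the patch itself. It assumes a GPU build of MindSpore, and Net is a hypothetical user-defined Cell:

    # Sketch only: exercises the relaxed `target` setter on a GPU backend.
    import mindspore.nn as nn
    from mindspore import context

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    net = Net()  # hypothetical network with sparse embedding lookups
    optimizer = nn.LazyAdam(net.trainable_params(), learning_rate=0.01)

    # Before this patch the setter rejected 'GPU', and Optimizer.__init__
    # hard-coded self._target to 'Ascend'. The default now follows
    # context.get_context("device_target"), and 'GPU' passes validation.
    optimizer.target = 'GPU'  # update parameters on the device
    optimizer.target = 'CPU'  # or route the sparse update to the host

With the old code, constructing an optimizer in a GPU context left `target` as 'Ascend' and the first assignment above raised ValueError; after this patch both the default and the explicit setting are consistent with the configured device.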