From: @wilfchen Reviewed-by: @limingqi107,@cristoval Signed-off-by: @cristovaltags/v1.1.0
| @@ -350,8 +350,8 @@ class Adam(Optimizer): | |||||
| if not isinstance(value, str): | if not isinstance(value, str): | ||||
| raise TypeError("The value must be str type, but got value type is {}".format(type(value))) | raise TypeError("The value must be str type, but got value type is {}".format(type(value))) | ||||
| if value not in ('CPU', 'Ascend'): | |||||
| raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value)) | |||||
| if value not in ('CPU', 'Ascend', 'GPU'): | |||||
| raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value)) | |||||
| self._is_device = (value != 'CPU') | self._is_device = (value != 'CPU') | ||||
| self._target = value | self._target = value | ||||
| @@ -192,8 +192,8 @@ class FTRL(Optimizer): | |||||
| if not isinstance(value, str): | if not isinstance(value, str): | ||||
| raise TypeError("The value must be str type, but got value type is {}".format(type(value))) | raise TypeError("The value must be str type, but got value type is {}".format(type(value))) | ||||
| if value not in ('CPU', 'Ascend'): | |||||
| raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value)) | |||||
| if value not in ('CPU', 'Ascend', 'GPU'): | |||||
| raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value)) | |||||
| if value == 'CPU': | if value == 'CPU': | ||||
| self.sparse_opt = P.FusedSparseFtrl(self.lr, self.l1, self.l2, self.lr_power, self.use_locking) | self.sparse_opt = P.FusedSparseFtrl(self.lr, self.l1, self.l2, self.lr_power, self.use_locking) | ||||
| @@ -257,8 +257,8 @@ class LazyAdam(Optimizer): | |||||
| if not isinstance(value, str): | if not isinstance(value, str): | ||||
| raise TypeError("The value must be str type, but got value type is {}".format(type(value))) | raise TypeError("The value must be str type, but got value type is {}".format(type(value))) | ||||
| if value not in ('CPU', 'Ascend'): | |||||
| raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value)) | |||||
| if value not in ('CPU', 'Ascend', 'GPU'): | |||||
| raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value)) | |||||
| self._is_device = (value != 'CPU') | self._is_device = (value != 'CPU') | ||||
| self._target = value | self._target = value | ||||
| @@ -105,7 +105,7 @@ class Optimizer(Cell): | |||||
| weight_decay = self._preprocess_weight_decay(weight_decay) | weight_decay = self._preprocess_weight_decay(weight_decay) | ||||
| self._unique = True | self._unique = True | ||||
| self._target = 'Ascend' | |||||
| self._target = context.get_context("device_target") | |||||
| self.dynamic_lr = False | self.dynamic_lr = False | ||||
| self.assignadd = None | self.assignadd = None | ||||
| self.global_step = None | self.global_step = None | ||||
| @@ -194,8 +194,7 @@ class Optimizer(Cell): | |||||
| @property | @property | ||||
| def target(self): | def target(self): | ||||
| """The method is used to determine whether the parameter is updated on host or device. The input type is str | """The method is used to determine whether the parameter is updated on host or device. The input type is str | ||||
| and can only be 'CPU' and 'Ascend'. In GPU environment, users can only configure value as 'CPU'. | |||||
| The method is read-only.""" | |||||
| and can only be 'CPU', 'Ascend' or 'GPU'.""" | |||||
| return self._target | return self._target | ||||
| @target.setter | @target.setter | ||||
| @@ -161,8 +161,8 @@ class ProximalAdagrad(Optimizer): | |||||
| if not isinstance(value, str): | if not isinstance(value, str): | ||||
| raise TypeError("The value must be str type, but got value type is {}".format(type(value))) | raise TypeError("The value must be str type, but got value type is {}".format(type(value))) | ||||
| if value not in ('CPU', 'Ascend'): | |||||
| raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value)) | |||||
| if value not in ('CPU', 'Ascend', 'GPU'): | |||||
| raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value)) | |||||
| if value == 'CPU': | if value == 'CPU': | ||||
| self.sparse_opt = P.FusedSparseProximalAdagrad(self.use_locking).add_prim_attr("primitive_target", "CPU") | self.sparse_opt = P.FusedSparseProximalAdagrad(self.use_locking).add_prim_attr("primitive_target", "CPU") | ||||