Merge pull request !6760 from lijiaqi/sparse_optimizertags/v1.1.0
| @@ -27,6 +27,8 @@ from mindspore._checkparam import Rel | |||
| from .optimizer import Optimizer | |||
| _adam_opt = C.MultitypeFuncGraph("adam_opt") | |||
| _scaler_one = Tensor(1, mstype.int32) | |||
| _scaler_ten = Tensor(10, mstype.float32) | |||
| @_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", | |||
| @@ -85,31 +87,80 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d | |||
| return gradient | |||
| @_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | |||
| "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool") | |||
| def _run_opt_with_sparse(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, | |||
| gradient, params, moment1, moment2, ps_parameter): | |||
| @_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", | |||
| "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool") | |||
| def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, | |||
| beta2_power, beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter): | |||
| """Apply sparse adam optimizer to the weight parameter when the gradient is sparse.""" | |||
| success = True | |||
| indices = gradient.indices | |||
| values = gradient.values | |||
| if ps_parameter: | |||
| op_shape = P.Shape() | |||
| shapes = (op_shape(params), op_shape(moment1), op_shape(moment2), | |||
| shapes = (op_shape(params), op_shape(m), op_shape(v), | |||
| op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1), | |||
| op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices)) | |||
| success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, | |||
| eps, values, indices), shapes), params)) | |||
| else: | |||
| success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, | |||
| return success | |||
| if not target: | |||
| success = F.depend(success, sparse_opt(params, m, v, beta1_power, beta2_power, lr, beta1, beta2, | |||
| eps, values, indices)) | |||
| else: | |||
| op_mul = P.Mul() | |||
| op_square = P.Square() | |||
| op_sqrt = P.Sqrt() | |||
| scatter_add = P.ScatterAdd(use_locking) | |||
| assign_m = F.assign(m, op_mul(beta1, m)) | |||
| assign_v = F.assign(v, op_mul(beta2, v)) | |||
| grad_indices = gradient.indices | |||
| grad_value = gradient.values | |||
| next_m = scatter_add(m, | |||
| grad_indices, | |||
| op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value)) | |||
| next_v = scatter_add(v, | |||
| grad_indices, | |||
| op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value))) | |||
| if use_nesterov: | |||
| m_temp = next_m * _scaler_ten | |||
| assign_m_nesterov = F.assign(m, op_mul(beta1, next_m)) | |||
| div_value = scatter_add(m, | |||
| op_mul(grad_indices, _scaler_one), | |||
| op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value)) | |||
| param_update = div_value / (op_sqrt(next_v) + eps) | |||
| m_recover = F.assign(m, m_temp / _scaler_ten) | |||
| F.control_depend(m_temp, assign_m_nesterov) | |||
| F.control_depend(assign_m_nesterov, div_value) | |||
| F.control_depend(param_update, m_recover) | |||
| else: | |||
| param_update = next_m / (op_sqrt(next_v) + eps) | |||
| lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power) | |||
| next_param = params - lr_t * param_update | |||
| F.control_depend(assign_m, next_m) | |||
| F.control_depend(assign_v, next_v) | |||
| success = F.depend(success, F.assign(params, next_param)) | |||
| success = F.depend(success, F.assign(m, next_m)) | |||
| success = F.depend(success, F.assign(v, next_v)) | |||
| return success | |||
| @_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | |||
| "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") | |||
| def _run_opt_with_one_number(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, | |||
| params, moment1, moment2, ps_parameter): | |||
| @_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", | |||
| "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") | |||
| def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, | |||
| beta2_power, beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter): | |||
| """Apply adam optimizer to the weight parameter using Tensor.""" | |||
| success = True | |||
| if ps_parameter: | |||
| @@ -161,8 +212,8 @@ class Adam(Optimizer): | |||
| To improve parameter groups performance, the customized order of parameters is supported. | |||
| The sparse strategy is applied while the SparseGatherV2 operator is used for forward network. | |||
| The sparse feature is under continuous development. The sparse | |||
| behavior is currently performed on the CPU. | |||
| The sparse feature is under continuous development. If the sparse strategy wants to be executed on the host, | |||
| set the target to the CPU. | |||
| Args: | |||
| params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, | |||
| @@ -242,14 +293,16 @@ class Adam(Optimizer): | |||
| self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power") | |||
| self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power") | |||
| self.eps = Tensor(eps, mstype.float32) | |||
| self.use_nesterov = use_nesterov | |||
| self.use_locking = use_locking | |||
| self.moment1 = self.parameters.clone(prefix="moment1", init='zeros') | |||
| self.moment2 = self.parameters.clone(prefix="moment2", init='zeros') | |||
| self._is_device = True | |||
| self.hyper_map = C.HyperMap() | |||
| self.opt = P.Adam(use_locking, use_nesterov) | |||
| self.sparse_opt = P.FusedSparseAdam(use_locking, use_nesterov) | |||
| self.sparse_opt.add_prim_attr("primitive", "CPU") | |||
| self._ps_pull = P.Pull() | |||
| self._ps_push = P.Push("Adam", [0, 1, 2]) | |||
| self._ps_push.add_prim_attr("use_nesterov", use_nesterov) | |||
| @@ -260,6 +313,7 @@ class Adam(Optimizer): | |||
| moment2 = self.moment2 | |||
| gradients = self.decay_weight(gradients) | |||
| gradients = self.scale_grad(gradients) | |||
| gradients = self._grad_sparse_indices_deduplicate(gradients) | |||
| lr = self.get_lr() | |||
| beta1_power = self.beta1_power * self.beta1 | |||
| @@ -268,14 +322,26 @@ class Adam(Optimizer): | |||
| self.beta2_power = beta2_power | |||
| if self.is_group_lr: | |||
| success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, | |||
| self.use_locking, self.use_nesterov, self._is_device, | |||
| beta1_power, beta2_power, self.beta1, self.beta2, self.eps), | |||
| lr, gradients, params, moment1, moment2, self.ps_parameters) | |||
| else: | |||
| success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, | |||
| self.use_locking, self.use_nesterov, self._is_device, | |||
| beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr), | |||
| gradients, params, moment1, moment2, self.ps_parameters) | |||
| return success | |||
| @Optimizer.target.setter | |||
| def target(self, value): | |||
| """If the input value is set to "CPU", the parameters will be updated on the host using the Fused | |||
| optimizer operation.""" | |||
| if value not in ('CPU', 'Ascend'): | |||
| raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value)) | |||
| self._is_device = (value != 'CPU') | |||
| self._target = value | |||
| class AdamWeightDecay(Optimizer): | |||
| """ | |||
| @@ -89,7 +89,8 @@ class FTRL(Optimizer): | |||
| To improve parameter groups performance, the customized order of parameters can be supported. | |||
| The sparse strategy is applied while the SparseGatherV2 operator being used for forward network. | |||
| The sparse feature is under continuous development. The sparse behavior is currently performed on the CPU. | |||
| The sparse feature is under continuous development. If the sparse strategy wants to be executed on the host, | |||
| set the target to the CPU. | |||
| Args: | |||
| params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, | |||
| @@ -154,12 +155,14 @@ class FTRL(Optimizer): | |||
| self.linear = self.parameters.clone(prefix="linear", init='zeros') | |||
| self.l1 = l1 | |||
| self.l2 = l2 | |||
| self.lr = learning_rate | |||
| self.lr_power = lr_power | |||
| if not self.is_group: | |||
| self.decay_flags = tuple((lambda: True)() for x in self.parameters) | |||
| self.hyper_map = C.HyperMap() | |||
| self.opt = P.ApplyFtrl(use_locking=use_locking) | |||
| self.sparse_opt = P.FusedSparseFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking) | |||
| self.use_locking = use_locking | |||
| self.sparse_opt = P.SparseApplyFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking) | |||
| self._ps_pull = P.Pull() | |||
| self._ps_push = P.Push("Ftrl", [0, 1, 2]) | |||
| self._ps_push.add_prim_attr("init_accum", initial_accum) | |||
| @@ -174,9 +177,26 @@ class FTRL(Optimizer): | |||
| linear = self.linear | |||
| grads = self.decay_weight(grads) | |||
| grads = self.scale_grad(grads) | |||
| grads = self._grad_sparse_indices_deduplicate(grads) | |||
| lr = self.get_lr() | |||
| success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, | |||
| self.l1, self.l2, self.lr_power, lr), | |||
| linear, grads, params, moments, self.ps_parameters) | |||
| return success | |||
| @Optimizer.target.setter | |||
| def target(self, value): | |||
| """If the input value is set to "CPU", the parameters will be updated on the host using the Fused | |||
| optimizer operation.""" | |||
| if value not in ('CPU', 'Ascend'): | |||
| raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value)) | |||
| if value == 'CPU': | |||
| self.sparse_opt = P.FusedSparseFtrl(self.lr, self.l1, self.l2, self.lr_power, self.use_locking) | |||
| self.sparse_opt.add_prim_attr("primitive", "CPU") | |||
| else: | |||
| self.sparse_opt = P.SparseApplyFtrl(self.lr, self.l1, self.l2, self.lr_power, self.use_locking) | |||
| self._target = value | |||
| @@ -27,31 +27,57 @@ from .optimizer import Optimizer | |||
| _lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt") | |||
| @_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", | |||
| "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool") | |||
| def _run_opt_with_sparse(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps, | |||
| lr, gradient, params, moment1, moment2, ps_parameter): | |||
| @_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", | |||
| "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool") | |||
| def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, beta2_power, | |||
| beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter): | |||
| """Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse.""" | |||
| success = True | |||
| indices = gradient.indices | |||
| values = gradient.values | |||
| if ps_parameter: | |||
| op_shape = P.Shape() | |||
| shapes = (op_shape(params), op_shape(moment1), op_shape(moment2), | |||
| shapes = (op_shape(params), op_shape(m), op_shape(v), | |||
| op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1), | |||
| op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices)) | |||
| success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, | |||
| eps, values, indices), shapes), params)) | |||
| else: | |||
| success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, | |||
| return success | |||
| if not target: | |||
| success = F.depend(success, sparse_opt(params, m, v, beta1_power, beta2_power, lr, beta1, beta2, | |||
| eps, values, indices)) | |||
| else: | |||
| op_gather = P.GatherV2() | |||
| op_sqrt = P.Sqrt() | |||
| scatter_add = P.ScatterAdd(use_locking) | |||
| scatter_update = P.ScatterUpdate(use_locking) | |||
| m_slice = op_gather(m, indices, 0) | |||
| v_slice = op_gather(v, indices, 0) | |||
| next_m = m_slice * beta1 + values * (1 - beta1) | |||
| next_v = v_slice * beta2 + values * values * (1 - beta2) | |||
| lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power) | |||
| if use_nesterov: | |||
| m_temp = beta1 * next_m + values * (1 - beta1) | |||
| param_update = m_temp / (op_sqrt(next_v) + eps) | |||
| else: | |||
| param_update = next_m / (op_sqrt(next_v) + eps) | |||
| success = F.depend(success, scatter_add(params, indices, - lr_t * param_update)) | |||
| success = F.depend(success, scatter_update(m, indices, next_m)) | |||
| success = F.depend(success, scatter_update(v, indices, next_v)) | |||
| return success | |||
| @_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", | |||
| "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") | |||
| def _run_opt_with_one_number(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps, | |||
| lr, gradient, params, moment1, moment2, ps_parameter): | |||
| @_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", | |||
| "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") | |||
| def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, | |||
| beta2_power, beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter): | |||
| """Apply lazy adam optimizer to the weight parameter using Tensor.""" | |||
| success = True | |||
| if ps_parameter: | |||
| @@ -108,7 +134,7 @@ class LazyAdam(Optimizer): | |||
| The sparse strategy is applied while the SparseGatherV2 operator being used for forward network. | |||
| The sparse behavior, to be notice, is not equivalent to the | |||
| original Adam algorithm, as only the current indices parames will be updated. The sparse feature is under | |||
| continuous development. The sparse behavior is currently performed on the CPU. | |||
| continuous development. If the sparse strategy wants to be executed on the host, set the target to the CPU. | |||
| Args: | |||
| params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, | |||
| @@ -191,14 +217,14 @@ class LazyAdam(Optimizer): | |||
| self.eps = Tensor(eps, mstype.float32) | |||
| self.use_nesterov = use_nesterov | |||
| self.use_locking = use_locking | |||
| self._is_device = True | |||
| self.moment1 = self.parameters.clone(prefix="moment1", init='zeros') | |||
| self.moment2 = self.parameters.clone(prefix="moment2", init='zeros') | |||
| self.hyper_map = C.HyperMap() | |||
| self.opt = P.Adam(use_locking, use_nesterov) | |||
| self.sparse_opt = P.FusedSparseLazyAdam(use_locking, use_nesterov) | |||
| self.sparse_opt.add_prim_attr("primitive", "CPU") | |||
| self._ps_pull = P.Pull() | |||
| self._ps_push = P.Push("Adam", [0, 1, 2]) | |||
| self._ps_push.add_prim_attr("use_nesterov", use_nesterov) | |||
| @@ -206,6 +232,7 @@ class LazyAdam(Optimizer): | |||
| def construct(self, gradients): | |||
| gradients = self.decay_weight(gradients) | |||
| gradients = self.scale_grad(gradients) | |||
| gradients = self._grad_sparse_indices_deduplicate(gradients) | |||
| lr = self.get_lr() | |||
| self.beta1_power = self.beta1_power * self.beta1 | |||
| @@ -213,10 +240,22 @@ class LazyAdam(Optimizer): | |||
| if self.is_group_lr: | |||
| success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, | |||
| self.use_locking, self.use_nesterov, self._is_device, | |||
| self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps), | |||
| lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters) | |||
| else: | |||
| success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, | |||
| self.use_locking, self.use_nesterov, self._is_device, | |||
| self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps, lr), | |||
| gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters) | |||
| return success | |||
| @Optimizer.target.setter | |||
| def target(self, value): | |||
| """If the input value is set to "CPU", the parameters will be updated on the host using the Fused | |||
| optimizer operation.""" | |||
| if value not in ('CPU', 'Ascend'): | |||
| raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value)) | |||
| self._is_device = (value != 'CPU') | |||
| self._target = value | |||
| @@ -26,7 +26,6 @@ from mindspore.common.initializer import initializer | |||
| from mindspore.common.tensor import Tensor, RowTensor | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore._checkparam import Rel | |||
| from mindspore import log as logger | |||
| from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode | |||
| from mindspore.context import ParallelMode | |||
| @@ -105,6 +104,8 @@ class Optimizer(Cell): | |||
| weight_decay = self._preprocess_weight_decay(weight_decay) | |||
| self._unique = True | |||
| self._target = 'Ascend' | |||
| self.dynamic_lr = False | |||
| self.assignadd = None | |||
| self.global_step = None | |||
| @@ -173,6 +174,30 @@ class Optimizer(Cell): | |||
| else: | |||
| self.optim_filter = (True,) * self.param_length | |||
| @property | |||
| def unique(self): | |||
| """This method is to see whether to make unique,This method is read-only.""" | |||
| return self._unique | |||
| @unique.setter | |||
| def unique(self, value): | |||
| """Set whether the input value is unique.""" | |||
| if not isinstance(value, bool): | |||
| raise TypeError("The value type must be bool, but got value type is {}".format(type(value))) | |||
| self._unique = value | |||
| @property | |||
| def target(self): | |||
| """This method is used to determine the value of target and whether the parameter update is performed on | |||
| the host or device. This method is read-only.""" | |||
| return self._target | |||
| @target.setter | |||
| def target(self, value): | |||
| """If the input value is set to "CPU", the parameters will be updated on the host using the Fused | |||
| optimizer operation.""" | |||
| raise NotImplementedError | |||
| def decay_weight(self, gradients): | |||
| """ | |||
| Weight decay. | |||
| @@ -217,6 +242,12 @@ class Optimizer(Cell): | |||
| return gradients | |||
| def _grad_sparse_indices_deduplicate(self, gradients): | |||
| """ In the case of using big operators, de duplicate the 'indexes' in gradients.""" | |||
| if self._target != 'CPU' and self._unique: | |||
| gradients = self.map_(F.partial(_indices_deduplicate), gradients) | |||
| return gradients | |||
| def _preprocess_weight_decay(self, weight_decay): | |||
| """Check weight decay, and convert int to float.""" | |||
| if isinstance(weight_decay, (float, int)): | |||
| @@ -514,7 +545,7 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient): | |||
| _grad_scale = C.MultitypeFuncGraph("grad_scale") | |||
| _indices_deduplicate = C.MultitypeFuncGraph("indices_deduplicate") | |||
| @_grad_scale.register("Number", "Tensor") | |||
| def tensor_grad_scale(scale, grad): | |||
| @@ -532,6 +563,24 @@ def tensor_grad_scale_with_sparse(scale, grad): | |||
| return RowTensor(grad.indices, grad.values * scale, grad.dense_shape) | |||
| @_indices_deduplicate.register("RowTensor") | |||
| def rowtensor_deduplicate_indices_slices(grad): | |||
| """Unique the indices and sums the 'values' corresponding to the duplicate indices.""" | |||
| indices = grad.indices | |||
| values = grad.values | |||
| unique_indices, index_position = P.Unique()(indices) | |||
| summed_values = P.UnsortedSegmentSum()(values, index_position, P.DynamicShape()(unique_indices)[0]) | |||
| return RowTensor(unique_indices, summed_values, grad.dense_shape) | |||
| @_indices_deduplicate.register("Tensor") | |||
| def tensor_deduplicate_indice_slices(grad): | |||
| """Return the input gradient directly in the dense sences.""" | |||
| return grad | |||
| class _ConvertToCell(LearningRateSchedule): | |||
| """Inner api, convert learning rate of scalar to LearningRateSchedule.""" | |||
| def __init__(self, learning_rate): | |||
| @@ -17,7 +17,6 @@ from mindspore.ops import functional as F, composite as C, operations as P | |||
| from mindspore.common import Tensor | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore._checkparam import Rel | |||
| from .optimizer import Optimizer | |||
| _proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt") | |||
| @@ -66,8 +65,8 @@ class ProximalAdagrad(Optimizer): | |||
| To improve parameter groups performance, the customized order of parameters can be supported. | |||
| The sparse strategy is applied while the SparseGatherV2 operator being used for forward network. | |||
| The sparse feature is under continuous development. The sparse | |||
| behavior is currently performed on the CPU. | |||
| The sparse feature is under continuous development. If the sparse strategy wants to be executed on the host, | |||
| set the target to the CPU. | |||
| Args: | |||
| params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, | |||
| @@ -136,14 +135,16 @@ class ProximalAdagrad(Optimizer): | |||
| self.l1 = Tensor(l1, mstype.float32) | |||
| self.l2 = Tensor(l2, mstype.float32) | |||
| self.hyper_map = C.HyperMap() | |||
| self.use_locking = use_locking | |||
| self.opt = P.ApplyProximalAdagrad(use_locking=use_locking) | |||
| self.sparse_opt = P.FusedSparseProximalAdagrad(use_locking=use_locking) | |||
| self.sparse_opt = P.SparseApplyProximalAdagrad(use_locking=use_locking) | |||
| def construct(self, grads): | |||
| params = self.parameters | |||
| accum = self.accum | |||
| grads = self.decay_weight(grads) | |||
| grads = self.scale_grad(grads) | |||
| grads = self._grad_sparse_indices_deduplicate(grads) | |||
| lr = self.get_lr() | |||
| if self.is_group_lr: | |||
| success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), lr, | |||
| @@ -152,3 +153,18 @@ class ProximalAdagrad(Optimizer): | |||
| success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2, lr), | |||
| grads, params, accum) | |||
| return success | |||
| @Optimizer.target.setter | |||
| def target(self, value): | |||
| """If the input value is set to "CPU", the parameters will be updated on the host using the Fused | |||
| optimizer operation.""" | |||
| if value not in ('CPU', 'Ascend'): | |||
| raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value)) | |||
| if value == 'CPU': | |||
| self.sparse_opt = P.FusedSparseProximalAdagrad(self.use_locking).add_prim_attr("primitive", "CPU") | |||
| else: | |||
| self.sparse_opt = P.SparseApplyProximalAdagrad(self.use_locking) | |||
| self._target = value | |||
| @@ -345,8 +345,8 @@ class TrainStepWrap(nn.Cell): | |||
| self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens) | |||
| self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w, | |||
| l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens) | |||
| self.optimizer_w.sparse_opt.add_prim_attr("primitive_target", "CPU") | |||
| self.optimizer_d.sparse_opt.add_prim_attr("primitive_target", "CPU") | |||
| self.optimizer_w.target = "CPU" | |||
| self.optimizer_d.target = "CPU" | |||
| else: | |||
| self.optimizer_d = Adam( | |||
| self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens) | |||
| @@ -74,7 +74,7 @@ def do_sparse_embedding(ps=False): | |||
| net.embedding.embedding_table.set_param_ps() | |||
| optimizer = Adam(filter(lambda x: x.requires_grad, net.get_parameters())) | |||
| optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU") | |||
| optimizer.target = 'CPU' | |||
| criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | |||
| net_with_criterion = WithLossCell(net, criterion) | |||
| train_network = TrainOneStepCell(net_with_criterion, optimizer) | |||
| @@ -465,7 +465,7 @@ def test_embedding_lookup_with_mix_precision(): | |||
| criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean') | |||
| optimizer = nn.Adam(params=net.trainable_params(), learning_rate=0.1) | |||
| optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU") | |||
| optimizer.target = 'CPU' | |||
| train_network = ms.amp.build_train_network(net, optimizer, criterion, level="O2") | |||
| train_network.set_train() | |||
| for _ in range(2): | |||
| @@ -109,6 +109,19 @@ def test_sparse_adam_compile(): | |||
| net = NetWithSparseGatherV2() | |||
| net.set_train() | |||
| optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9) | |||
| optimizer.target = 'CPU' | |||
| train_network = TrainOneStepCell(net, optimizer) | |||
| _executor.compile(train_network, indices, label) | |||
| def test_sparse_adam(): | |||
| """ test_sparse_adam """ | |||
| indices = Tensor(np.array([0, 1]).astype(np.int32)) | |||
| label = Tensor(np.zeros([2, 1, 2]).astype(np.float32)) | |||
| net = NetWithSparseGatherV2() | |||
| net.set_train() | |||
| optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9) | |||
| train_network = TrainOneStepCell(net, optimizer) | |||
| _executor.compile(train_network, indices, label) | |||
| @@ -72,5 +72,19 @@ def test_spares_ftrl_compile(): | |||
| net.set_train() | |||
| optimizer = FTRL(net.trainable_params(), weight_decay=0.9, loss_scale=2.0) | |||
| optimizer.target = 'CPU' | |||
| train_network = TrainOneStepCell(net, optimizer) | |||
| _executor.compile(train_network, indices, label) | |||
| def test_spares_ftrl(): | |||
| """ test sparse ftrl""" | |||
| indices = Tensor(np.array([0, 1]).astype(np.int32)) | |||
| label = Tensor(np.zeros([2, 1, 2]).astype(np.float32)) | |||
| net = NetWithSparseGatherV2() | |||
| net.set_train() | |||
| optimizer = FTRL(net.trainable_params(), weight_decay=0.9, loss_scale=2.0) | |||
| optimizer.target = 'Ascend' | |||
| train_network = TrainOneStepCell(net, optimizer) | |||
| _executor.compile(train_network, indices, label) | |||
| @@ -76,6 +76,20 @@ def test_spares_lazy_adam_compile(): | |||
| net.set_train() | |||
| optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9, loss_scale=2.0) | |||
| optimizer.target = 'CPU' | |||
| train_network = TrainOneStepCell(net, optimizer) | |||
| _executor.compile(train_network, indices, label) | |||
| def test_spares_lazy_adam(): | |||
| """ test sparse adam""" | |||
| indices = Tensor(np.array([0, 1]).astype(np.int32)) | |||
| label = Tensor(np.zeros([2, 1, 2]).astype(np.float32)) | |||
| net = NetWithSparseGatherV2() | |||
| net.set_train() | |||
| optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9, loss_scale=2.0) | |||
| optimizer.target = 'Ascend' | |||
| train_network = TrainOneStepCell(net, optimizer) | |||
| _executor.compile(train_network, indices, label) | |||
| @@ -71,6 +71,19 @@ def test_spares_proximal_ada_grad_compile(): | |||
| net = NetWithSparseGatherV2() | |||
| net.set_train() | |||
| optimizer = ProximalAdagrad(net.trainable_params(), weight_decay=0.9, loss_scale=1024.0) | |||
| optimizer.target = 'CPU' | |||
| train_network = TrainOneStepCell(net, optimizer) | |||
| _executor.compile(train_network, indices, label) | |||
| def test_spares_proximal_ada_grad(): | |||
| """ test sparse proximal_ada_grad """ | |||
| indices = Tensor(np.array([0, 1]).astype(np.int32)) | |||
| label = Tensor(np.zeros([2, 1, 2]).astype(np.float32)) | |||
| net = NetWithSparseGatherV2() | |||
| net.set_train() | |||
| optimizer = ProximalAdagrad(net.trainable_params(), weight_decay=0.9, loss_scale=1024.0) | |||
| train_network = TrainOneStepCell(net, optimizer) | |||
| _executor.compile(train_network, indices, label) | |||
| @@ -0,0 +1,75 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test lazy adam """ | |||
| import numpy as np | |||
| from mindspore.nn.optim import LazyAdam, FTRL, Adam, ProximalAdagrad | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor, Parameter, context | |||
| from mindspore.ops import operations as P | |||
| context.set_context(enable_sparse=True) | |||
| class NetWithSparseGatherV2(nn.Cell): | |||
| """ NetWithSparseGatherV2 definition """ | |||
| def __init__(self): | |||
| super(NetWithSparseGatherV2, self).__init__() | |||
| self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1") | |||
| self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2") | |||
| self.axis = 0 | |||
| self.gather = P.SparseGatherV2() | |||
| def construct(self, indices, label): | |||
| return self.gather(self.weight1, indices, self.axis) + self.weight2 | |||
| def test_ftrl_target(): | |||
| """ test_ftrl_target """ | |||
| net = NetWithSparseGatherV2() | |||
| net.set_train() | |||
| optimizer = FTRL(net.trainable_params(), weight_decay=0.9, loss_scale=2.0) | |||
| if optimizer.target not in ('CPU', 'Ascend'): | |||
| raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(optimizer.target)) | |||
| def test_lazyadam_target(): | |||
| """ test_lazyadam_target """ | |||
| net = NetWithSparseGatherV2() | |||
| net.set_train() | |||
| optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9, loss_scale=2.0) | |||
| if optimizer.target not in ('CPU', 'Ascend'): | |||
| raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(optimizer.target)) | |||
| def test_adam_target(): | |||
| """ test_adam_target """ | |||
| net = NetWithSparseGatherV2() | |||
| net.set_train() | |||
| optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9) | |||
| if optimizer.target not in ('CPU', 'Ascend'): | |||
| raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(optimizer.target)) | |||
| def test_proximal_target(): | |||
| """ test_proximal_target """ | |||
| net = NetWithSparseGatherV2() | |||
| net.set_train() | |||
| optimizer = ProximalAdagrad(net.trainable_params(), weight_decay=0.9, loss_scale=1024.0) | |||
| if optimizer.target not in ('CPU', 'Ascend'): | |||
| raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(optimizer.target)) | |||