From 1cdf6a10acf839a0a1bd7f7c47630953c8e28a64 Mon Sep 17 00:00:00 2001 From: "wangnan39@huawei.com" Date: Mon, 20 Jul 2020 08:51:37 +0800 Subject: [PATCH] rename operators of sparse optimizer --- .../cpu/sparse_apply_adam_cpu_kernel.h | 2 +- .../cpu/sparse_apply_ftrl_cpu_kernel.h | 14 +- .../cpu/sparse_apply_lazy_adam_cpu_kernel.h | 2 +- ...sparse_apply_proximal_adagrad_cpu_kernel.h | 15 +- mindspore/nn/optim/adam.py | 2 +- mindspore/nn/optim/ftrl.py | 3 +- mindspore/nn/optim/lazyadam.py | 2 +- mindspore/nn/optim/proximal_ada_grad.py | 3 +- mindspore/ops/operations/__init__.py | 9 +- mindspore/ops/operations/_inner_ops.py | 180 ----------------- mindspore/ops/operations/nn_ops.py | 186 +++++++++++++++++- tests/st/ops/cpu/test_sparse_apply_adam_op.py | 2 +- tests/st/ops/cpu/test_sparse_apply_ftrl_op.py | 2 +- .../test_sparse_apply_proximal_adagrad_op.py | 2 +- 14 files changed, 197 insertions(+), 227 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h index 5d3d4193f7..3a7a449246 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h @@ -40,7 +40,7 @@ class SparseApplyAdamCPUKernel : public CPUKernel { bool use_nesterov_{false}; }; -MS_REG_CPU_KERNEL(SparseApplyAdam, +MS_REG_CPU_KERNEL(FusedSparseAdam, KernelAttr() .AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h index af8796d8a5..c24ce8c703 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h @@ -42,19 +42,7 @@ class SparseApplyFtrlCPUKernel : public CPUKernel { float lr_power_{0}; }; -MS_REG_CPU_KERNEL(SparseApplyFtrl, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SparseApplyFtrlCPUKernel); - -MS_REG_CPU_KERNEL(SparseApplyFtrlNoReturn, +MS_REG_CPU_KERNEL(FusedSparseFtrl, KernelAttr() .AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h index ee95db8f33..e588702aea 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h @@ -40,7 +40,7 @@ class SparseApplyLazyAdamCPUKernel : public CPUKernel { bool use_nesterov_{false}; }; -MS_REG_CPU_KERNEL(SparseApplyLazyAdam, +MS_REG_CPU_KERNEL(FusedSparseLazyAdam, KernelAttr() .AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h index 56b180ec0b..616fb9b954 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h +++ 
b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h @@ -39,20 +39,7 @@ class SparseApplyProximalAdagradCPUKernel : public CPUKernel { size_t var_outer_dim_size_{1}; }; -MS_REG_CPU_KERNEL(SparseApplyProximalAdagrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SparseApplyProximalAdagradCPUKernel); - -MS_REG_CPU_KERNEL(SparseApplyProximalAdagradNoReturn, +MS_REG_CPU_KERNEL(FusedSparseProximalAdagrad, KernelAttr() .AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32) diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index 1dbfb940ee..61666f7f86 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -299,7 +299,7 @@ class Adam(Optimizer): self.hyper_map = C.HyperMap() self.opt = P.Adam(use_locking, use_nesterov) - self.sparse_opt = P.SparseApplyAdam(use_locking, use_nesterov) + self.sparse_opt = P.FusedSparseAdam(use_locking, use_nesterov) def construct(self, gradients): params = self.parameters diff --git a/mindspore/nn/optim/ftrl.py b/mindspore/nn/optim/ftrl.py index fd755d703a..c9f12dc6d4 100644 --- a/mindspore/nn/optim/ftrl.py +++ b/mindspore/nn/optim/ftrl.py @@ -16,7 +16,6 @@ from mindspore.ops import functional as F, composite as C, operations as P from mindspore.common import Tensor import mindspore.common.dtype as mstype -from mindspore.ops.operations import _inner_ops as inner from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from .optimizer import Optimizer, _apply_decay, _grad_scale @@ -159,7 +158,7 @@ class FTRL(Optimizer): self.decay_tf = tuple((lambda: True)() for x in self.parameters) self.hyper_map = C.HyperMap() self.opt = P.ApplyFtrl(use_locking=use_locking) - self.sparse_opt = inner.SparseApplyFtrlNoReturn(learning_rate, l1, l2, lr_power, use_locking=use_locking) + self.sparse_opt = P.FusedSparseFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking) def construct(self, grads): params = self.parameters diff --git a/mindspore/nn/optim/lazyadam.py b/mindspore/nn/optim/lazyadam.py index 756200c41b..7b5be61268 100644 --- a/mindspore/nn/optim/lazyadam.py +++ b/mindspore/nn/optim/lazyadam.py @@ -182,7 +182,7 @@ class LazyAdam(Optimizer): self.hyper_map = C.HyperMap() self.opt = P.Adam(use_locking, use_nesterov) - self.sparse_opt = P.SparseApplyLazyAdam(use_locking, use_nesterov) + self.sparse_opt = P.FusedSparseLazyAdam(use_locking, use_nesterov) def construct(self, gradients): gradients = self.decay_weight(gradients) diff --git a/mindspore/nn/optim/proximal_ada_grad.py b/mindspore/nn/optim/proximal_ada_grad.py index 2b965fc5b5..948868322e 100644 --- a/mindspore/nn/optim/proximal_ada_grad.py +++ b/mindspore/nn/optim/proximal_ada_grad.py @@ -16,7 +16,6 @@ from mindspore.ops import functional as F, composite as C, operations as P from mindspore.common import Tensor import mindspore.common.dtype as mstype -from mindspore.ops.operations import _inner_ops as inner from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from .optimizer import Optimizer @@ -101,7 +100,7 @@ class ProximalAdagrad(Optimizer): self.weight_decay = weight_decay self.hyper_map = C.HyperMap() self.opt = 
P.ApplyProximalAdagrad(use_locking=use_locking) - self.sparse_opt = inner.SparseApplyProximalAdagradNoReturn(use_locking=use_locking) + self.sparse_opt = P.FusedSparseProximalAdagrad(use_locking=use_locking) def construct(self, grads): params = self.parameters diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 267593ab1a..6c409433a0 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -56,7 +56,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A from .random_ops import (RandomChoiceWithMask, Normal, Gamma, Poisson, UniformInt, UniformReal, RandomCategorical, Laplace) -from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm, +from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, ApplyMomentum, BatchNorm, BiasAdd, Conv2D, DepthwiseConv2dNative, DropoutDoMask, DropoutGrad, Dropout, @@ -74,6 +74,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl SparseSoftmaxCrossEntropyWithLogits, Tanh, TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, ApplyProximalAdagrad, SparseApplyProximalAdagrad, SparseApplyAdagradV2, SparseApplyFtrlV2, + FusedSparseFtrl, FusedSparseProximalAdagrad, ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent, ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK) @@ -114,8 +115,8 @@ __all__ = [ 'MaxPool', 'TopK', 'Adam', - 'SparseApplyAdam', - 'SparseApplyLazyAdam', + 'FusedSparseAdam', + 'FusedSparseLazyAdam', 'Softplus', 'Softmax', 'Softsign', @@ -311,8 +312,10 @@ __all__ = [ "SpaceToBatch", "SparseApplyFtrl", "SparseApplyFtrlV2", + "FusedSparseFtrl", "ApplyProximalAdagrad", "SparseApplyProximalAdagrad", + "FusedSparseProximalAdagrad", "ApplyAdaMax", "ApplyAdadelta", "ApplyAdagrad", diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index 7ae59a6467..8e57265991 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -18,9 +18,6 @@ from ..._checkparam import Rel from ..._checkparam import Validator as validator from ...common import dtype as mstype -from ..._c_expression import signature_rw as sig_rw -from ..._c_expression import signature_kind as sig_kind -from ..._c_expression import signature_dtype as sig_dtype from ..primitive import PrimitiveWithInfer, prim_attr_register @@ -394,183 +391,6 @@ class Dequant(PrimitiveWithInfer): return mstype.float16 -class SparseApplyFtrlNoReturn(PrimitiveWithInfer): - """ - Update relevant entries according to the FTRL-proximal scheme. - - Args: - lr (float): The learning rate value, must be positive. - l1 (float): l1 regularization strength, must be greater than or equal to zero. - l2 (float): l2 regularization strength, must be greater than or equal to zero. - lr_power (float): Learning rate power controls how the learning rate decreases during training, - must be less than or equal to zero. Use fixed learning rate if `lr_power` is zero. - use_locking (bool): Use locks for update operation if True . Default: False. - - Inputs: - - **var** (Parameter): The variable to be updated. The data type must be float32. - - **accum** (Parameter): The accum to be updated, must be same type and shape as `var`. - - **linear** (Parameter): The linear to be updated, must be same type and shape as `var`. 
- - **grad** (Tensor): A tensor of the same type as `var`, for the gradient. - - **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`. The shape - of `indices` must be the same as `grad` in first dimension. The type must be int32. - - Outputs: - Tuple of 3 Tensor, this operator will update the input parameters directly, the outputs are useless. - - - **var** (Tensor) - A Tensor with shape (1,). - - **accum** (Tensor) - A Tensor with shape (1,). - - **linear** (Tensor) - A Tensor with shape (1,). - - Examples: - >>> import mindspore - >>> import mindspore.nn as nn - >>> import numpy as np - >>> from mindspore import Parameter - >>> from mindspore import Tensor - >>> from mindspore.ops import operations as P - >>> class SparseApplyFtrlNet(nn.Cell): - >>> def __init__(self): - >>> super(SparseApplyFtrlNet, self).__init__() - >>> self.sparse_apply_ftrl = P.SparseApplyFtrlV2(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5) - >>> self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var") - >>> self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum") - >>> self.linear = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="linear") - >>> - >>> def construct(self, grad, indices): - >>> out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices) - >>> return out - >>> - >>> net = SparseApplyFtrlNet() - >>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32)) - >>> indices = Tensor(np.array([0, 1]).astype(np.int32)) - >>> output = net(grad, indices) - """ - __mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) - ) - - @prim_attr_register - def __init__(self, lr, l1, l2, lr_power, use_locking=False): - self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'], - outputs=['output']) - validator.check_value_type("lr", lr, [float], self.name) - validator.check_value_type("l1", l1, [float], self.name) - validator.check_value_type("l2", l2, [float], self.name) - validator.check_value_type("lr_power", lr_power, [float], self.name) - self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name) - self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name) - self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name) - self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) - self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) - self.add_prim_attr('primitive_target', 'CPU') - - def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape): - validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) - validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name) - if len(var_shape) > 1: - validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) - 
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) - validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) - return [1], [1], [1] - - def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype): - args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype, - "linear_dtype": linear_dtype, "grad_dtype": grad_dtype} - validator.check_tensor_type_same(args, [mstype.float32], self.name) - validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name) - return var_dtype, accum_dtype, linear_dtype - - -class SparseApplyProximalAdagradNoReturn(PrimitiveWithInfer): - r""" - Updates relevant entries according to the proximal adagrad algorithm. - - .. math:: - accum += grad * grad - .. math:: - \text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}} - .. math:: - var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0) - - Args: - use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False. - - Inputs: - - **var** (Parameter) - Variable tensor to be updated. The data type must be float32. - - **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`. - - **lr** (Tensor): The learning rate value. The data type must be float32. - - **l1** (Tensor): l1 regularization strength. The data type must be float32. - - **l2** (Tensor): l2 regularization strength. The data type must be float32. - - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32. - - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The data type - must be int32. - - Outputs: - Tuple of 2 Tensor, this operator will update the input parameters directly, the outputs are useless. - - - **var** (Tensor) - A Tensor with shape (1,). - - **accum** (Tensor) - A Tensor with shape (1,). 
- - Examples: - >>> import numpy as np - >>> import mindspore.nn as nn - >>> from mindspore import Tensor, Parameter - >>> from mindspore.ops import operations as P - >>> class Net(nn.Cell): - >>> def __init__(self): - >>> super(Net, self).__init__() - >>> self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagradV2() - >>> self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var") - >>> self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum") - >>> self.lr = Tensor(0.01, mstype.float32) - >>> self.l1 = Tensor(0.0, mstype.float32) - >>> self.l2 = Tensor(0.0, mstype.float32) - >>> def construct(self, grad, indices): - >>> out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, - >>> self.l2, grad, indices) - >>> return out - >>> net = Net() - >>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32)) - >>> indices = Tensor(np.array([0, 1]).astype(np.int32)) - >>> output = net(grad, indices) - """ - __mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) - ) - - @prim_attr_register - def __init__(self, use_locking=False): - self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'], - outputs=['output']) - self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) - self.add_prim_attr('primitive_target', 'CPU') - - def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape): - validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) - return [1], [1] - - def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): - args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} - validator.check_tensor_type_same(args, [mstype.float32], self.name) - validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float32], self.name) - validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, [mstype.float32], self.name) - validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, [mstype.float32], self.name) - valid_types = [mstype.int16, mstype.int32, mstype.int64, - mstype.uint16, mstype.uint32, mstype.uint64] - validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name) - return var_dtype, accum_dtype - - class LinSpace(PrimitiveWithInfer): r""" Generates values in an interval. And return the corresponding interpolation accroding to assist. 
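Review context, not part of the patch: with the two internal primitives removed from `_inner_ops.py` above, the sparse-update path goes through the renamed public operator instead. Below is a minimal sketch of driving `P.FusedSparseFtrl`, assembled from the docstring and kernel registration in this patch; the parameter shapes and hyperparameter values are illustrative assumptions.

```python
# Sketch only: exercises P.FusedSparseFtrl (formerly inner.SparseApplyFtrlNoReturn)
# the way the FTRL optimizer's sparse path does after this patch.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P


class FusedSparseFtrlNet(nn.Cell):
    def __init__(self):
        super(FusedSparseFtrlNet, self).__init__()
        # lr must be positive, l1/l2 non-negative, lr_power <= 0 (validated in __init__).
        self.sparse_apply_ftrl = P.FusedSparseFtrl(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
        self.var = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="var")
        self.accum = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="accum")
        self.linear = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="linear")

    def construct(self, grad, indices):
        # Updates var/accum/linear in place; the returned tensors are (1,)-shaped placeholders.
        return self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)


net = FusedSparseFtrlNet()
grad = Tensor(np.ones([2, 1, 2]).astype(np.float32))   # grad.shape[0] must match indices.shape[0]
indices = Tensor(np.array([0, 1]).astype(np.int32))    # rows of var/accum/linear to update
net(grad, indices)
```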
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index e2f6888c88..2ec1a99a07 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2917,7 +2917,7 @@ class Adam(PrimitiveWithInfer): return var_dtype, m_dtype, v_dtype -class SparseApplyAdam(PrimitiveWithInfer): +class FusedSparseAdam(PrimitiveWithInfer): r""" Merge the duplicate value of the gradient and then updates parameters by Adaptive Moment Estimation (Adam) algorithm. This operator is used when the gradient is sparse. @@ -2979,7 +2979,7 @@ class SparseApplyAdam(PrimitiveWithInfer): >>> class Net(nn.Cell): >>> def __init__(self): >>> super(Net, self).__init__() - >>> self.sparse_apply_adam = P.SparseApplyAdam() + >>> self.sparse_apply_adam = P.FusedSparseAdam() >>> self.var = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="var") >>> self.m = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="m") >>> self.v = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="v") @@ -3025,7 +3025,6 @@ class SparseApplyAdam(PrimitiveWithInfer): self.init_prim_io_names(inputs=['var', 'm', 'v', 'beta1_power', 'beta2_power', 'lr', 'beta1', 'beta2', 'epsilon', 'grad', 'indices'], outputs=['var', 'm', 'v']) - self.add_prim_attr('primitive_target', 'CPU') def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape, beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape): @@ -3051,7 +3050,7 @@ class SparseApplyAdam(PrimitiveWithInfer): return var_dtype, m_dtype, v_dtype -class SparseApplyLazyAdam(PrimitiveWithInfer): +class FusedSparseLazyAdam(PrimitiveWithInfer): r""" Merge the duplicate value of the gradient and then updates parameters by Adaptive Moment Estimation (Adam) algorithm. This operator is used when the gradient is sparse. The behavior is not equivalent to the @@ -3114,7 +3113,7 @@ class SparseApplyLazyAdam(PrimitiveWithInfer): >>> class Net(nn.Cell): >>> def __init__(self): >>> super(Net, self).__init__() - >>> self.sparse_apply_lazyadam = P.SparseApplyLazyAdam() + >>> self.sparse_apply_lazyadam = P.FusedSparseLazyAdam() >>> self.var = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="var") >>> self.m = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="m") >>> self.v = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="v") @@ -3160,7 +3159,6 @@ class SparseApplyLazyAdam(PrimitiveWithInfer): self.init_prim_io_names(inputs=['var', 'm', 'v', 'beta1_power', 'beta2_power', 'lr', 'beta1', 'beta2', 'epsilon', 'grad', 'indices'], outputs=['var', 'm', 'v']) - self.add_prim_attr('primitive_target', 'CPU') def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape, beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape): @@ -3187,6 +3185,182 @@ class SparseApplyLazyAdam(PrimitiveWithInfer): return var_dtype, m_dtype, v_dtype +class FusedSparseFtrl(PrimitiveWithInfer): + """ + Merge the duplicate value of the gradient and then update relevant entries according to the FTRL-proximal scheme. + + Args: + lr (float): The learning rate value, must be positive. + l1 (float): l1 regularization strength, must be greater than or equal to zero. + l2 (float): l2 regularization strength, must be greater than or equal to zero. + lr_power (float): Learning rate power controls how the learning rate decreases during training, + must be less than or equal to zero. Use fixed learning rate if `lr_power` is zero. 
+ use_locking (bool): Use locks for update operation if True . Default: False. + + Inputs: + - **var** (Parameter): The variable to be updated. The data type must be float32. + - **accum** (Parameter): The accum to be updated, must be same type and shape as `var`. + - **linear** (Parameter): The linear to be updated, must be same type and shape as `var`. + - **grad** (Tensor): A tensor of the same type as `var`, for the gradient. + - **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`. The shape + of `indices` must be the same as `grad` in first dimension. The type must be int32. + + Outputs: + Tuple of 3 Tensor, this operator will update the input parameters directly, the outputs are useless. + + - **var** (Tensor) - A Tensor with shape (1,). + - **accum** (Tensor) - A Tensor with shape (1,). + - **linear** (Tensor) - A Tensor with shape (1,). + + Examples: + >>> import mindspore + >>> import mindspore.nn as nn + >>> import numpy as np + >>> from mindspore import Parameter + >>> from mindspore import Tensor + >>> from mindspore.ops import operations as P + >>> class SparseApplyFtrlNet(nn.Cell): + >>> def __init__(self): + >>> super(SparseApplyFtrlNet, self).__init__() + >>> self.sparse_apply_ftrl = P.FusedSparseFtrl(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5) + >>> self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var") + >>> self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum") + >>> self.linear = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="linear") + >>> + >>> def construct(self, grad, indices): + >>> out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices) + >>> return out + >>> + >>> net = SparseApplyFtrlNet() + >>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32)) + >>> indices = Tensor(np.array([0, 1]).astype(np.int32)) + >>> output = net(grad, indices) + """ + __mindspore_signature__ = ( + ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) + ) + + @prim_attr_register + def __init__(self, lr, l1, l2, lr_power, use_locking=False): + self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'], + outputs=['output']) + validator.check_value_type("lr", lr, [float], self.name) + validator.check_value_type("l1", l1, [float], self.name) + validator.check_value_type("l2", l2, [float], self.name) + validator.check_value_type("lr_power", lr_power, [float], self.name) + self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name) + self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name) + self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name) + self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) + self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) + + def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape): + 
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) + validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name) + if len(var_shape) > 1: + validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) + validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) + validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) + return [1], [1], [1] + + def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype): + args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype, + "linear_dtype": linear_dtype, "grad_dtype": grad_dtype} + validator.check_tensor_type_same(args, [mstype.float32], self.name) + validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name) + return var_dtype, accum_dtype, linear_dtype + + +class FusedSparseProximalAdagrad(PrimitiveWithInfer): + r""" + Merge the duplicate value of the gradient and then Updates relevant entries according to the proximal adagrad + algorithm. + + .. math:: + accum += grad * grad + .. math:: + \text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}} + .. math:: + var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0) + + Args: + use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False. + + Inputs: + - **var** (Parameter) - Variable tensor to be updated. The data type must be float32. + - **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`. + - **lr** (Tensor): The learning rate value. The data type must be float32. + - **l1** (Tensor): l1 regularization strength. The data type must be float32. + - **l2** (Tensor): l2 regularization strength. The data type must be float32. + - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32. + - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The data type + must be int32. + + Outputs: + Tuple of 2 Tensor, this operator will update the input parameters directly, the outputs are useless. + + - **var** (Tensor) - A Tensor with shape (1,). + - **accum** (Tensor) - A Tensor with shape (1,). 
+ + Examples: + >>> import numpy as np + >>> import mindspore.nn as nn + >>> from mindspore import Tensor, Parameter + >>> from mindspore.ops import operations as P + >>> class Net(nn.Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self.sparse_apply_proximal_adagrad = P.FusedSparseProximalAdagrad() + >>> self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var") + >>> self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum") + >>> self.lr = Tensor(0.01, mstype.float32) + >>> self.l1 = Tensor(0.0, mstype.float32) + >>> self.l2 = Tensor(0.0, mstype.float32) + >>> def construct(self, grad, indices): + >>> out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, + >>> self.l2, grad, indices) + >>> return out + >>> net = Net() + >>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32)) + >>> indices = Tensor(np.array([0, 1]).astype(np.int32)) + >>> output = net(grad, indices) + """ + __mindspore_signature__ = ( + ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) + ) + + @prim_attr_register + def __init__(self, use_locking=False): + self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'], + outputs=['output']) + self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) + + def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape): + validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) + return [1], [1] + + def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): + args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} + validator.check_tensor_type_same(args, [mstype.float32], self.name) + validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float32], self.name) + validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, [mstype.float32], self.name) + validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, [mstype.float32], self.name) + valid_types = [mstype.int16, mstype.int32, mstype.int64, + mstype.uint16, mstype.uint32, mstype.uint64] + validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name) + return var_dtype, accum_dtype + + class BinaryCrossEntropy(PrimitiveWithInfer): r""" Computes the Binary Cross Entropy between the target and the output. 
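Review context, not part of the patch: the new `FusedSparseProximalAdagrad` mirrors the removed internal primitive, including the (1,)-shaped placeholder outputs. A minimal usage sketch based on the docstring example added above; the values below are illustrative assumptions.

```python
# Sketch only: P.FusedSparseProximalAdagrad updates var and accum in place;
# lr, l1 and l2 are passed as float32 tensors rather than operator attributes.
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P


class ProximalAdagradNet(nn.Cell):
    def __init__(self):
        super(ProximalAdagradNet, self).__init__()
        self.sparse_apply_proximal_adagrad = P.FusedSparseProximalAdagrad()
        self.var = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="var")
        self.accum = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="accum")
        self.lr = Tensor(0.01, mstype.float32)
        self.l1 = Tensor(0.0, mstype.float32)
        self.l2 = Tensor(0.0, mstype.float32)

    def construct(self, grad, indices):
        # Returned tensors are placeholders; the real results live in var and accum.
        return self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr,
                                                  self.l1, self.l2, grad, indices)


net = ProximalAdagradNet()
grad = Tensor(np.ones([2, 1, 2]).astype(np.float32))
indices = Tensor(np.array([0, 1]).astype(np.int32))   # int32 vector indexing the first dim of var
net(grad, indices)
```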
diff --git a/tests/st/ops/cpu/test_sparse_apply_adam_op.py b/tests/st/ops/cpu/test_sparse_apply_adam_op.py index 6dd866e96c..e57b8b515d 100644 --- a/tests/st/ops/cpu/test_sparse_apply_adam_op.py +++ b/tests/st/ops/cpu/test_sparse_apply_adam_op.py @@ -33,7 +33,7 @@ epsilon = 1e-8 class Net(nn.Cell): def __init__(self): super(Net, self).__init__() - self.sparse_apply_adam = P.SparseApplyAdam() + self.sparse_apply_adam = P.FusedSparseAdam() self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var") self.m = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="m") self.v = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="v") diff --git a/tests/st/ops/cpu/test_sparse_apply_ftrl_op.py b/tests/st/ops/cpu/test_sparse_apply_ftrl_op.py index dca5cf7a77..31826b3fab 100644 --- a/tests/st/ops/cpu/test_sparse_apply_ftrl_op.py +++ b/tests/st/ops/cpu/test_sparse_apply_ftrl_op.py @@ -26,7 +26,7 @@ import mindspore.common.dtype as mstype class Net(nn.Cell): def __init__(self): super(Net, self).__init__() - self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5) + self.sparse_apply_ftrl = P.FusedSparseFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5) self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var") self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum") self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="linear") diff --git a/tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py b/tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py index 5d52e71896..696f3bf016 100644 --- a/tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py +++ b/tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py @@ -26,7 +26,7 @@ import mindspore.common.dtype as mstype class Net(nn.Cell): def __init__(self): super(Net, self).__init__() - self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad() + self.sparse_apply_proximal_adagrad = P.FusedSparseProximalAdagrad() self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var") self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum") self.lr = 0.01
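Review context, not part of the patch: for code that constructs these primitives directly, the renames in this commit map as summarized in the sketch below. Constructor arguments follow the hunks above; the hyperparameter values are taken from the updated tests and are otherwise illustrative.

```python
# Migration summary (sketch, based on the hunks in this patch):
#   P.SparseApplyAdam                        -> P.FusedSparseAdam
#   P.SparseApplyLazyAdam                    -> P.FusedSparseLazyAdam
#   inner.SparseApplyFtrlNoReturn            -> P.FusedSparseFtrl
#   inner.SparseApplyProximalAdagradNoReturn -> P.FusedSparseProximalAdagrad
from mindspore.ops import operations as P

sparse_adam = P.FusedSparseAdam()            # defaults, as in the docstring example above
sparse_lazy_adam = P.FusedSparseLazyAdam()
sparse_ftrl = P.FusedSparseFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5)
sparse_prox_adagrad = P.FusedSparseProximalAdagrad()
```

Note also that the renamed operators no longer call `add_prim_attr('primitive_target', 'CPU')` in `__init__`, so they are not explicitly pinned to the CPU backend the way the removed internal versions were.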