
rename operators of sparse optimizer

Tag: v0.6.0-beta
Author: wangnan39@huawei.com, 5 years ago
Commit: 1cdf6a10ac
14 changed files with 197 additions and 227 deletions
  1. mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h (+1, -1)
  2. mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h (+1, -13)
  3. mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h (+1, -1)
  4. mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h (+1, -14)
  5. mindspore/nn/optim/adam.py (+1, -1)
  6. mindspore/nn/optim/ftrl.py (+1, -2)
  7. mindspore/nn/optim/lazyadam.py (+1, -1)
  8. mindspore/nn/optim/proximal_ada_grad.py (+1, -2)
  9. mindspore/ops/operations/__init__.py (+6, -3)
  10. mindspore/ops/operations/_inner_ops.py (+0, -180)
  11. mindspore/ops/operations/nn_ops.py (+180, -6)
  12. tests/st/ops/cpu/test_sparse_apply_adam_op.py (+1, -1)
  13. tests/st/ops/cpu/test_sparse_apply_ftrl_op.py (+1, -1)
  14. tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py (+1, -1)
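In short, the commit renames MindSpore's fused CPU sparse-optimizer primitives and points the nn optimizers at the new names. A minimal sketch of the old-to-new mapping, reconstructed from the per-file diffs below (editorial illustration, not part of the changed files; the dict and helper names are invented for the sketch):

# Old primitive name -> new public name, per the diffs below.
RENAMED_SPARSE_OPS = {
    "SparseApplyAdam": "FusedSparseAdam",
    "SparseApplyLazyAdam": "FusedSparseLazyAdam",
    "SparseApplyFtrlNoReturn": "FusedSparseFtrl",
    "SparseApplyProximalAdagradNoReturn": "FusedSparseProximalAdagrad",
}

def migrated_name(op_name):
    """Return the post-rename operator name, or the input unchanged if it was not renamed."""
    return RENAMED_SPARSE_OPS.get(op_name, op_name)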

mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h (+1, -1)

@@ -40,7 +40,7 @@ class SparseApplyAdamCPUKernel : public CPUKernel {
bool use_nesterov_{false};
};

- MS_REG_CPU_KERNEL(SparseApplyAdam,
+ MS_REG_CPU_KERNEL(FusedSparseAdam,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)


mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h (+1, -13)

@@ -42,19 +42,7 @@ class SparseApplyFtrlCPUKernel : public CPUKernel {
float lr_power_{0};
};

- MS_REG_CPU_KERNEL(SparseApplyFtrl,
- KernelAttr()
- .AddInputAttr(kNumberTypeFloat32)
- .AddInputAttr(kNumberTypeFloat32)
- .AddInputAttr(kNumberTypeFloat32)
- .AddInputAttr(kNumberTypeFloat32)
- .AddInputAttr(kNumberTypeInt32)
- .AddOutputAttr(kNumberTypeFloat32)
- .AddOutputAttr(kNumberTypeFloat32)
- .AddOutputAttr(kNumberTypeFloat32),
- SparseApplyFtrlCPUKernel);
-
- MS_REG_CPU_KERNEL(SparseApplyFtrlNoReturn,
+ MS_REG_CPU_KERNEL(FusedSparseFtrl,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)


mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h (+1, -1)

@@ -40,7 +40,7 @@ class SparseApplyLazyAdamCPUKernel : public CPUKernel {
bool use_nesterov_{false};
};

- MS_REG_CPU_KERNEL(SparseApplyLazyAdam,
+ MS_REG_CPU_KERNEL(FusedSparseLazyAdam,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)


mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h (+1, -14)

@@ -39,20 +39,7 @@ class SparseApplyProximalAdagradCPUKernel : public CPUKernel {
size_t var_outer_dim_size_{1};
};

- MS_REG_CPU_KERNEL(SparseApplyProximalAdagrad,
- KernelAttr()
- .AddInputAttr(kNumberTypeFloat32)
- .AddInputAttr(kNumberTypeFloat32)
- .AddInputAttr(kNumberTypeFloat32)
- .AddInputAttr(kNumberTypeFloat32)
- .AddInputAttr(kNumberTypeFloat32)
- .AddInputAttr(kNumberTypeFloat32)
- .AddInputAttr(kNumberTypeInt32)
- .AddOutputAttr(kNumberTypeFloat32)
- .AddOutputAttr(kNumberTypeFloat32),
- SparseApplyProximalAdagradCPUKernel);
-
- MS_REG_CPU_KERNEL(SparseApplyProximalAdagradNoReturn,
+ MS_REG_CPU_KERNEL(FusedSparseProximalAdagrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)


mindspore/nn/optim/adam.py (+1, -1)

@@ -299,7 +299,7 @@ class Adam(Optimizer):

self.hyper_map = C.HyperMap()
self.opt = P.Adam(use_locking, use_nesterov)
- self.sparse_opt = P.SparseApplyAdam(use_locking, use_nesterov)
+ self.sparse_opt = P.FusedSparseAdam(use_locking, use_nesterov)

def construct(self, gradients):
params = self.parameters
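The Adam optimizer keeps the dense P.Adam path and only swaps its sparse path to the renamed primitive. Below is a minimal sketch of calling P.FusedSparseAdam directly; it assumes the positional input order declared by init_prim_io_names in the nn_ops.py hunk further down and borrows shapes from the docstring example there (FusedSparseAdamNet and scalar are names invented for the sketch):

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

class FusedSparseAdamNet(nn.Cell):
    def __init__(self):
        super(FusedSparseAdamNet, self).__init__()
        self.fused_sparse_adam = P.FusedSparseAdam()
        # var, m and v are the optimizer state updated in place by the primitive.
        self.var = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="var")
        self.m = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="m")
        self.v = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="v")

    def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, indices):
        return self.fused_sparse_adam(self.var, self.m, self.v, beta1_power, beta2_power,
                                      lr, beta1, beta2, epsilon, grad, indices)

def scalar(x):
    # Hyperparameters passed per step as float32 scalar tensors.
    return Tensor(x, mstype.float32)

net = FusedSparseAdamNet()
grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
indices = Tensor(np.array([0, 1]).astype(np.int32))  # rows of var touched by grad
out = net(scalar(0.9), scalar(0.999), scalar(0.001), scalar(0.9), scalar(0.999), scalar(1e-8), grad, indices)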


mindspore/nn/optim/ftrl.py (+1, -2)

@@ -16,7 +16,6 @@
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common import Tensor
import mindspore.common.dtype as mstype
- from mindspore.ops.operations import _inner_ops as inner
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer, _apply_decay, _grad_scale
@@ -159,7 +158,7 @@ class FTRL(Optimizer):
self.decay_tf = tuple((lambda: True)() for x in self.parameters)
self.hyper_map = C.HyperMap()
self.opt = P.ApplyFtrl(use_locking=use_locking)
- self.sparse_opt = inner.SparseApplyFtrlNoReturn(learning_rate, l1, l2, lr_power, use_locking=use_locking)
+ self.sparse_opt = P.FusedSparseFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking)

def construct(self, grads):
params = self.parameters
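Note that, unlike the Adam case above, the FTRL sparse primitive takes its hyperparameters at construction time, so the optimizer passes learning_rate, l1, l2 and lr_power once and the per-step call carries only tensors. A minimal sketch of that call pattern (assuming the per-step input order declared by init_prim_io_names in the nn_ops.py hunk further down; sparse_opt and the shapes are invented for the sketch):

import numpy as np
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

# Hyperparameters are baked into the primitive, mirroring the ftrl.py change above.
sparse_opt = P.FusedSparseFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5, use_locking=False)

var = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="var")
accum = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="accum")
linear = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="linear")
grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
indices = Tensor(np.array([0, 1]).astype(np.int32))

# Per-step call: only the state parameters, the gradient and its row indices.
out = sparse_opt(var, accum, linear, grad, indices)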


mindspore/nn/optim/lazyadam.py (+1, -1)

@@ -182,7 +182,7 @@ class LazyAdam(Optimizer):

self.hyper_map = C.HyperMap()
self.opt = P.Adam(use_locking, use_nesterov)
- self.sparse_opt = P.SparseApplyLazyAdam(use_locking, use_nesterov)
+ self.sparse_opt = P.FusedSparseLazyAdam(use_locking, use_nesterov)

def construct(self, gradients):
gradients = self.decay_weight(gradients)


mindspore/nn/optim/proximal_ada_grad.py (+1, -2)

@@ -16,7 +16,6 @@
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common import Tensor
import mindspore.common.dtype as mstype
- from mindspore.ops.operations import _inner_ops as inner
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer
@@ -101,7 +100,7 @@ class ProximalAdagrad(Optimizer):
self.weight_decay = weight_decay
self.hyper_map = C.HyperMap()
self.opt = P.ApplyProximalAdagrad(use_locking=use_locking)
- self.sparse_opt = inner.SparseApplyProximalAdagradNoReturn(use_locking=use_locking)
+ self.sparse_opt = P.FusedSparseProximalAdagrad(use_locking=use_locking)

def construct(self, grads):
params = self.parameters


mindspore/ops/operations/__init__.py (+6, -3)

@@ -56,7 +56,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A

from .random_ops import (RandomChoiceWithMask, Normal, Gamma, Poisson, UniformInt, UniformReal,
RandomCategorical, Laplace)
- from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm,
+ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, ApplyMomentum, BatchNorm,
BiasAdd, Conv2D,
DepthwiseConv2dNative,
DropoutDoMask, DropoutGrad, Dropout,
@@ -74,6 +74,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
SparseSoftmaxCrossEntropyWithLogits, Tanh,
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
ApplyProximalAdagrad, SparseApplyProximalAdagrad, SparseApplyAdagradV2, SparseApplyFtrlV2,
+ FusedSparseFtrl, FusedSparseProximalAdagrad,
ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
@@ -114,8 +115,8 @@ __all__ = [
'MaxPool',
'TopK',
'Adam',
- 'SparseApplyAdam',
- 'SparseApplyLazyAdam',
+ 'FusedSparseAdam',
+ 'FusedSparseLazyAdam',
'Softplus',
'Softmax',
'Softsign',
@@ -311,8 +312,10 @@ __all__ = [
"SpaceToBatch",
"SparseApplyFtrl",
"SparseApplyFtrlV2",
"FusedSparseFtrl",
"ApplyProximalAdagrad",
"SparseApplyProximalAdagrad",
"FusedSparseProximalAdagrad",
"ApplyAdaMax",
"ApplyAdadelta",
"ApplyAdagrad",


mindspore/ops/operations/_inner_ops.py (+0, -180)

@@ -18,9 +18,6 @@
from ..._checkparam import Rel
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
- from ..._c_expression import signature_rw as sig_rw
- from ..._c_expression import signature_kind as sig_kind
- from ..._c_expression import signature_dtype as sig_dtype
from ..primitive import PrimitiveWithInfer, prim_attr_register


@@ -394,183 +391,6 @@ class Dequant(PrimitiveWithInfer):
return mstype.float16


class SparseApplyFtrlNoReturn(PrimitiveWithInfer):
"""
Update relevant entries according to the FTRL-proximal scheme.

Args:
lr (float): The learning rate value, must be positive.
l1 (float): l1 regularization strength, must be greater than or equal to zero.
l2 (float): l2 regularization strength, must be greater than or equal to zero.
lr_power (float): Learning rate power controls how the learning rate decreases during training,
must be less than or equal to zero. Use fixed learning rate if `lr_power` is zero.
use_locking (bool): Use locks for update operation if True . Default: False.

Inputs:
- **var** (Parameter): The variable to be updated. The data type must be float32.
- **accum** (Parameter): The accum to be updated, must be same type and shape as `var`.
- **linear** (Parameter): The linear to be updated, must be same type and shape as `var`.
- **grad** (Tensor): A tensor of the same type as `var`, for the gradient.
- **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`. The shape
of `indices` must be the same as `grad` in first dimension. The type must be int32.

Outputs:
Tuple of 3 Tensor, this operator will update the input parameters directly, the outputs are useless.

- **var** (Tensor) - A Tensor with shape (1,).
- **accum** (Tensor) - A Tensor with shape (1,).
- **linear** (Tensor) - A Tensor with shape (1,).

Examples:
>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>> from mindspore import Parameter
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> class SparseApplyFtrlNet(nn.Cell):
>>> def __init__(self):
>>> super(SparseApplyFtrlNet, self).__init__()
>>> self.sparse_apply_ftrl = P.SparseApplyFtrlV2(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
>>> self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var")
>>> self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum")
>>> self.linear = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="linear")
>>>
>>> def construct(self, grad, indices):
>>> out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
>>> return out
>>>
>>> net = SparseApplyFtrlNet()
>>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
>>> indices = Tensor(np.array([0, 1]).astype(np.int32))
>>> output = net(grad, indices)
"""
__mindspore_signature__ = (
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
)

@prim_attr_register
def __init__(self, lr, l1, l2, lr_power, use_locking=False):
self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'],
outputs=['output'])
validator.check_value_type("lr", lr, [float], self.name)
validator.check_value_type("l1", l1, [float], self.name)
validator.check_value_type("l2", l2, [float], self.name)
validator.check_value_type("lr_power", lr_power, [float], self.name)
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
self.add_prim_attr('primitive_target', 'CPU')

def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
if len(var_shape) > 1:
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
return [1], [1], [1]

def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
"linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
validator.check_tensor_type_same(args, [mstype.float32], self.name)
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
return var_dtype, accum_dtype, linear_dtype


class SparseApplyProximalAdagradNoReturn(PrimitiveWithInfer):
r"""
Updates relevant entries according to the proximal adagrad algorithm.

.. math::
accum += grad * grad
.. math::
\text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}}
.. math::
var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)

Args:
use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False.

Inputs:
- **var** (Parameter) - Variable tensor to be updated. The data type must be float32.
- **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
- **lr** (Tensor): The learning rate value. The data type must be float32.
- **l1** (Tensor): l1 regularization strength. The data type must be float32.
- **l2** (Tensor): l2 regularization strength. The data type must be float32.
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32.
- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The data type
must be int32.

Outputs:
Tuple of 2 Tensor, this operator will update the input parameters directly, the outputs are useless.

- **var** (Tensor) - A Tensor with shape (1,).
- **accum** (Tensor) - A Tensor with shape (1,).

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagradV2()
>>> self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var")
>>> self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum")
>>> self.lr = Tensor(0.01, mstype.float32)
>>> self.l1 = Tensor(0.0, mstype.float32)
>>> self.l2 = Tensor(0.0, mstype.float32)
>>> def construct(self, grad, indices):
>>> out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1,
>>> self.l2, grad, indices)
>>> return out
>>> net = Net()
>>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
>>> indices = Tensor(np.array([0, 1]).astype(np.int32))
>>> output = net(grad, indices)
"""
__mindspore_signature__ = (
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
)

@prim_attr_register
def __init__(self, use_locking=False):
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'],
outputs=['output'])
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
self.add_prim_attr('primitive_target', 'CPU')

def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
return [1], [1]

def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, [mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, [mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, [mstype.float32], self.name)
valid_types = [mstype.int16, mstype.int32, mstype.int64,
mstype.uint16, mstype.uint32, mstype.uint64]
validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
return var_dtype, accum_dtype


class LinSpace(PrimitiveWithInfer):
r"""
Generates values in an interval and returns the corresponding interpolation according to assist.


mindspore/ops/operations/nn_ops.py (+180, -6)

@@ -2917,7 +2917,7 @@ class Adam(PrimitiveWithInfer):
return var_dtype, m_dtype, v_dtype


- class SparseApplyAdam(PrimitiveWithInfer):
+ class FusedSparseAdam(PrimitiveWithInfer):
r"""
Merge the duplicate value of the gradient and then updates parameters by Adaptive Moment Estimation (Adam)
algorithm. This operator is used when the gradient is sparse.
@@ -2979,7 +2979,7 @@ class SparseApplyAdam(PrimitiveWithInfer):
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
- >>> self.sparse_apply_adam = P.SparseApplyAdam()
+ >>> self.sparse_apply_adam = P.FusedSparseAdam()
>>> self.var = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="var")
>>> self.m = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="m")
>>> self.v = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="v")
@@ -3025,7 +3025,6 @@ class SparseApplyAdam(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['var', 'm', 'v', 'beta1_power', 'beta2_power', 'lr', 'beta1', 'beta2',
'epsilon', 'grad', 'indices'],
outputs=['var', 'm', 'v'])
- self.add_prim_attr('primitive_target', 'CPU')

def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape,
beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape):
@@ -3051,7 +3050,7 @@ class SparseApplyAdam(PrimitiveWithInfer):
return var_dtype, m_dtype, v_dtype


- class SparseApplyLazyAdam(PrimitiveWithInfer):
+ class FusedSparseLazyAdam(PrimitiveWithInfer):
r"""
Merge the duplicate value of the gradient and then updates parameters by Adaptive Moment Estimation (Adam)
algorithm. This operator is used when the gradient is sparse. The behavior is not equivalent to the
@@ -3114,7 +3113,7 @@ class SparseApplyLazyAdam(PrimitiveWithInfer):
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
- >>> self.sparse_apply_lazyadam = P.SparseApplyLazyAdam()
+ >>> self.sparse_apply_lazyadam = P.FusedSparseLazyAdam()
>>> self.var = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="var")
>>> self.m = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="m")
>>> self.v = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="v")
@@ -3160,7 +3159,6 @@ class SparseApplyLazyAdam(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['var', 'm', 'v', 'beta1_power', 'beta2_power', 'lr', 'beta1', 'beta2',
'epsilon', 'grad', 'indices'],
outputs=['var', 'm', 'v'])
- self.add_prim_attr('primitive_target', 'CPU')

def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape,
beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape):
@@ -3187,6 +3185,182 @@ class SparseApplyLazyAdam(PrimitiveWithInfer):
return var_dtype, m_dtype, v_dtype


class FusedSparseFtrl(PrimitiveWithInfer):
"""
Merge the duplicate value of the gradient and then update relevant entries according to the FTRL-proximal scheme.

Args:
lr (float): The learning rate value, must be positive.
l1 (float): l1 regularization strength, must be greater than or equal to zero.
l2 (float): l2 regularization strength, must be greater than or equal to zero.
lr_power (float): Learning rate power controls how the learning rate decreases during training,
must be less than or equal to zero. Use fixed learning rate if `lr_power` is zero.
use_locking (bool): Use locks for update operation if True. Default: False.

Inputs:
- **var** (Parameter): The variable to be updated. The data type must be float32.
- **accum** (Parameter): The accum to be updated, must be same type and shape as `var`.
- **linear** (Parameter): The linear to be updated, must be same type and shape as `var`.
- **grad** (Tensor): A tensor of the same type as `var`, for the gradient.
- **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`. The shape
of `indices` must be the same as `grad` in first dimension. The type must be int32.

Outputs:
Tuple of 3 Tensors. This operator will update the input parameters directly; the outputs are useless.

- **var** (Tensor) - A Tensor with shape (1,).
- **accum** (Tensor) - A Tensor with shape (1,).
- **linear** (Tensor) - A Tensor with shape (1,).

Examples:
>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>> from mindspore import Parameter
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> class SparseApplyFtrlNet(nn.Cell):
>>> def __init__(self):
>>> super(SparseApplyFtrlNet, self).__init__()
>>> self.sparse_apply_ftrl = P.FusedSparseFtrl(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
>>> self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var")
>>> self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum")
>>> self.linear = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="linear")
>>>
>>> def construct(self, grad, indices):
>>> out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
>>> return out
>>>
>>> net = SparseApplyFtrlNet()
>>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
>>> indices = Tensor(np.array([0, 1]).astype(np.int32))
>>> output = net(grad, indices)
"""
__mindspore_signature__ = (
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
)

@prim_attr_register
def __init__(self, lr, l1, l2, lr_power, use_locking=False):
self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'],
outputs=['output'])
validator.check_value_type("lr", lr, [float], self.name)
validator.check_value_type("l1", l1, [float], self.name)
validator.check_value_type("l2", l2, [float], self.name)
validator.check_value_type("lr_power", lr_power, [float], self.name)
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)

def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
if len(var_shape) > 1:
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
return [1], [1], [1]

def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
"linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
validator.check_tensor_type_same(args, [mstype.float32], self.name)
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
return var_dtype, accum_dtype, linear_dtype


class FusedSparseProximalAdagrad(PrimitiveWithInfer):
r"""
Merge the duplicate value of the gradient and then update relevant entries according to the proximal adagrad
algorithm.

.. math::
accum += grad * grad
.. math::
\text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}}
.. math::
var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)

Args:
use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False.

Inputs:
- **var** (Parameter) - Variable tensor to be updated. The data type must be float32.
- **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
- **lr** (Tensor): The learning rate value. The data type must be float32.
- **l1** (Tensor): l1 regularization strength. The data type must be float32.
- **l2** (Tensor): l2 regularization strength. The data type must be float32.
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32.
- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The data type
must be int32.

Outputs:
Tuple of 2 Tensors. This operator will update the input parameters directly; the outputs are useless.

- **var** (Tensor) - A Tensor with shape (1,).
- **accum** (Tensor) - A Tensor with shape (1,).

Examples:
>>> import numpy as np
>>> import mindspore.common.dtype as mstype
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.sparse_apply_proximal_adagrad = P.FusedSparseProximalAdagrad()
>>> self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var")
>>> self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum")
>>> self.lr = Tensor(0.01, mstype.float32)
>>> self.l1 = Tensor(0.0, mstype.float32)
>>> self.l2 = Tensor(0.0, mstype.float32)
>>> def construct(self, grad, indices):
>>> out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1,
>>> self.l2, grad, indices)
>>> return out
>>> net = Net()
>>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
>>> indices = Tensor(np.array([0, 1]).astype(np.int32))
>>> output = net(grad, indices)
"""
__mindspore_signature__ = (
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
)

@prim_attr_register
def __init__(self, use_locking=False):
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'],
outputs=['output'])
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)

def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
return [1], [1]

def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, [mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, [mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, [mstype.float32], self.name)
valid_types = [mstype.int16, mstype.int32, mstype.int64,
mstype.uint16, mstype.uint32, mstype.uint64]
validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
return var_dtype, accum_dtype
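To make the three update formulas in the FusedSparseProximalAdagrad docstring concrete, here is a plain NumPy sketch of one step for a single dense entry (editorial illustration, not MindSpore code; the operator itself first merges duplicate gradient values and then applies this update only to the rows selected by indices; proximal_adagrad_step is a name invented for the sketch):

import numpy as np

def proximal_adagrad_step(var, accum, grad, lr, l1, l2):
    # accum += grad * grad
    accum = accum + grad * grad
    # prox_v = var - lr * grad / sqrt(accum)
    prox_v = var - lr * grad / np.sqrt(accum)
    # var = sign(prox_v) / (1 + lr * l2) * max(|prox_v| - lr * l1, 0)
    var = np.sign(prox_v) / (1.0 + lr * l2) * np.maximum(np.abs(prox_v) - lr * l1, 0.0)
    return var, accum

new_var, new_accum = proximal_adagrad_step(var=1.0, accum=1.0, grad=0.5, lr=0.01, l1=0.0, l2=0.0)
print(new_var, new_accum)  # roughly 0.99553 and 1.25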


class BinaryCrossEntropy(PrimitiveWithInfer):
r"""
Computes the Binary Cross Entropy between the target and the output.


tests/st/ops/cpu/test_sparse_apply_adam_op.py (+1, -1)

@@ -33,7 +33,7 @@ epsilon = 1e-8
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
- self.sparse_apply_adam = P.SparseApplyAdam()
+ self.sparse_apply_adam = P.FusedSparseAdam()
self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
self.m = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="m")
self.v = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="v")


tests/st/ops/cpu/test_sparse_apply_ftrl_op.py (+1, -1)

@@ -26,7 +26,7 @@ import mindspore.common.dtype as mstype
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
- self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5)
+ self.sparse_apply_ftrl = P.FusedSparseFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5)
self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum")
self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="linear")


tests/st/ops/cpu/test_sparse_apply_proximal_adagrad_op.py (+1, -1)

@@ -26,7 +26,7 @@ import mindspore.common.dtype as mstype
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
- self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad()
+ self.sparse_apply_proximal_adagrad = P.FusedSparseProximalAdagrad()
self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum")
self.lr = 0.01

