
Add optimizer operators for VM.

tags/v0.6.0-beta
liuxiao 5 years ago
commit 2097a0e90a
9 changed files with 665 additions and 9 deletions
  1. +2   -0   mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
  2. +4   -0   mindspore/ops/_op_impl/tbe/__init__.py
  3. +65  -0   mindspore/ops/_op_impl/tbe/apply_add_sign.py
  4. +44  -0   mindspore/ops/_op_impl/tbe/apply_gradient_descent.py
  5. +65  -0   mindspore/ops/_op_impl/tbe/apply_power_sign.py
  6. +54  -0   mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py
  7. +5   -0   mindspore/ops/operations/__init__.py
  8. +352 -9   mindspore/ops/operations/nn_ops.py
  9. +74  -0   tests/ut/python/ops/test_ops.py

mindspore/ccsrc/kernel/tbe/tbe_adapter.cc (+2, -0)

@@ -77,6 +77,8 @@ static std::map<string, string> tbe_func_adapter_map = {
{"sparse_apply_adagrad", "sparse_apply_adagrad_d"},
{"apply_proximal_adagrad", "apply_proximal_adagrad_d"},
{"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"},
{"apply_add_sign", "apply_add_sign_d"},
{"apply_power_sign", "apply_power_sign_d"},
{"transpose", "transpose_d"},
{"fill", "fill_d"},
{"unsorted_segment_sum", "unsorted_segment_sum_d"},


mindspore/ops/_op_impl/tbe/__init__.py (+4, -0)

@@ -34,6 +34,10 @@ from .apply_ada_max import _apply_ada_max_tbe
from .apply_adadelta import _apply_adadelta_tbe
from .apply_adagrad import _apply_adagrad_tbe
from .apply_adagrad_v2 import _apply_adagrad_v2_tbe
from .apply_add_sign import _apply_add_sign_tbe
from .apply_power_sign import _apply_power_sign_tbe
from .apply_gradient_descent import _apply_gradient_descent_tbe
from .apply_proximal_gradient_descent import _apply_proximal_gradient_descent_tbe
from .approximate_equal import _approximate_equal_tbe
from .adam_apply_one import _adam_apply_one_tbe
from .assign import _assign_tbe


mindspore/ops/_op_impl/tbe/apply_add_sign.py (+65, -0)

@@ -0,0 +1,65 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""ApplyAddSignD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

apply_add_sign_d_op_info = TBERegOp("ApplyAddSign") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("apply_add_sign_d.so") \
.compute_cost(10) \
.kernel_name("apply_add_sign_d") \
.partial_flag(True) \
.input(0, "var", False, "required", "all") \
.input(1, "m", False, "required", "all") \
.input(2, "lr", False, "required", "all") \
.input(3, "alpha", False, "required", "all") \
.input(4, "sign_decay", False, "required", "all") \
.input(5, "beta", False, "required", "all") \
.input(6, "grad", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "m", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD,
DataType.F16_5HD) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ,
DataType.F16_FracZ) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.get_op_info()


@op_info_register(apply_add_sign_d_op_info)
def _apply_add_sign_tbe():
"""ApplyAddSignD TBE register"""
return

mindspore/ops/_op_impl/tbe/apply_gradient_descent.py (+44, -0)

@@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""ApplyGradientDescent op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

apply_gradient_descent_op_info = TBERegOp("ApplyGradientDescent") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("apply_gradient_descent.so") \
.compute_cost(10) \
.kernel_name("apply_gradient_descent") \
.partial_flag(True) \
.input(0, "var", False, "required", "all") \
.input(1, "alpha", False, "required", "all") \
.input(2, "delta", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()


@op_info_register(apply_gradient_descent_op_info)
def _apply_gradient_descent_tbe():
"""ApplyGradientDescent TBE register"""
return

mindspore/ops/_op_impl/tbe/apply_power_sign.py (+65, -0)

@@ -0,0 +1,65 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""ApplyPowerSignD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

apply_power_sign_d_op_info = TBERegOp("ApplyPowerSign") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("apply_power_sign_d.so") \
.compute_cost(10) \
.kernel_name("apply_power_sign_d") \
.partial_flag(True) \
.input(0, "var", False, "required", "all") \
.input(1, "m", False, "required", "all") \
.input(2, "lr", False, "required", "all") \
.input(3, "logbase", False, "required", "all") \
.input(4, "sign_decay", False, "required", "all") \
.input(5, "beta", False, "required", "all") \
.input(6, "grad", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "m", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD,
DataType.F16_5HD) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ,
DataType.F16_FracZ) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.get_op_info()


@op_info_register(apply_power_sign_d_op_info)
def _apply_power_sign_tbe():
"""ApplyPowerSignD TBE register"""
return

mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py (+54, -0)

@@ -0,0 +1,54 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""ApplyProximalGradientDescent op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

apply_proximal_gradient_descent_op_info = TBERegOp("ApplyProximalGradientDescent") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("apply_proximal_gradient_descent.so") \
.compute_cost(10) \
.kernel_name("apply_proximal_gradient_descent") \
.partial_flag(True) \
.input(0, "var", False, "required", "all") \
.input(1, "alpha", False, "required", "all") \
.input(2, "l1", False, "required", "all") \
.input(3, "l2", False, "required", "all") \
.input(4, "delta", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default) \
.get_op_info()


@op_info_register(apply_proximal_gradient_descent_op_info)
def _apply_proximal_gradient_descent_tbe():
"""ApplyProximalGradientDescent TBE register"""
return

mindspore/ops/operations/__init__.py (+5, -0)

@@ -74,6 +74,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
ApplyProximalAdagrad, SparseApplyProximalAdagrad,
ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
CheckValid, MakeRefKey, Partial, Depend, CheckBprop)
@@ -295,6 +296,10 @@ __all__ = [
"ApplyAdadelta",
"ApplyAdagrad",
"ApplyAdagradV2",
"ApplyAddSign",
"ApplyPowerSign",
"ApplyGradientDescent",
"ApplyProximalGradientDescent",
"BatchToSpace",
"Atan2",
"ApplyRMSProp",


mindspore/ops/operations/nn_ops.py (+352, -9)

@@ -3382,7 +3382,7 @@ class ApplyAdagrad(PrimitiveWithInfer):
- **var** (Parameter) - Variable to be updated. With float32 or float16 data type.
- **accum** (Parameter) - Accum to be updated. The shape and dtype should be the same as `var`.
With float32 or float16 data type.
- - **lr** (Union[Number, Tensor]): The learning rate value, should be scalar. With float32 or float16 data type.
+ - **lr** (Union[Number, Tensor]) - The learning rate value, should be scalar. With float32 or float16 data type.
- **grad** (Tensor) - A tensor for gradient. The shape and dtype should be the same as `var`.
With float32 or float16 data type.

@@ -3458,7 +3458,7 @@ class ApplyAdagradV2(PrimitiveWithInfer):
- **var** (Parameter) - Variable to be updated. With float32 or float16 data type.
- **accum** (Parameter) - Accum to be updated. The shape and dtype should be the same as `var`.
With float32 or float16 data type.
- - **lr** (Union[Number, Tensor]): The learning rate value, should be scalar. With float32 or float16 data type.
+ - **lr** (Union[Number, Tensor]) - The learning rate value, should be scalar. With float32 or float16 data type.
- **grad** (Tensor) - A tensor for gradient. The shape and dtype should be the same as `var`.
With float32 or float16 data type.

@@ -3615,11 +3615,11 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
Inputs:
- **var** (Parameter) - Variable to be updated. The data type should be float16 or float32.
- **accum** (Parameter) - Accum to be updated. Must has the same shape and dtype as `var`.
- - **lr** (Union[Number, Tensor]): The learning rate value, should be scalar. The data type should be
+ - **lr** (Union[Number, Tensor]) - The learning rate value, should be scalar. The data type should be
float16 or float32.
- - **l1** (Union[Number, Tensor]): l1 regularization strength, should be scalar. The data type should be
+ - **l1** (Union[Number, Tensor]) - l1 regularization strength, should be scalar. The data type should be
float16 or float32.
- - **l2** (Union[Number, Tensor]): l2 regularization strength, should be scalar. The data type should be
+ - **l2** (Union[Number, Tensor]) - l2 regularization strength, should be scalar. The data type should be
float16 or float32.
- **grad** (Tensor) - Gradient. Must has the same shape and dtype as `var`.

@@ -3710,9 +3710,9 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
Inputs:
- **var** (Parameter) - Variable tensor to be updated. The data type must be float32.
- **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
- - **lr** (Union[Number, Tensor]): The learning rate value. The data type must be float32.
- - **l1** (Union[Number, Tensor]): l1 regularization strength. The data type must be float32.
- - **l2** (Union[Number, Tensor]): l2 regularization strength. The data type must be float32.
+ - **lr** (Union[Number, Tensor]) - The learning rate value. The data type must be float32.
+ - **l1** (Union[Number, Tensor]) - l1 regularization strength. The data type must be float32.
+ - **l2** (Union[Number, Tensor]) - l2 regularization strength. The data type must be float32.
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32.
- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.

@@ -3759,7 +3759,7 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, use_locking=False):
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'],
- outputs=['output'])
+ outputs=['var', 'accum'])
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)

def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
@@ -3778,6 +3778,349 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
return var_dtype, accum_dtype


class ApplyAddSign(PrimitiveWithInfer):
r"""
Update relevant entries according to the AddSign algorithm.

.. math::
\begin{array}{ll} \\
m_{t} = \beta * m_{t-1} + (1 - \beta) * g \\
\text{update} = (\alpha + \text{sign_decay} * sign(g) * sign(m)) * g \\
var = var - lr_{t} * \text{update}
\end{array}

:math:`t` represents the updating step, :math:`m` represents the 1st moment vector, :math:`m_{t-1}`
is the last moment of :math:`m_{t}`, :math:`lr` represents the scaling factor `lr`, and :math:`g` represents `grad`.

Inputs:
- **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type.
- **m** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
- **lr** (Union[Number, Tensor]) - The learning rate value, should be a scalar.
With float32 or float16 data type.
- **alpha** (Union[Number, Tensor]) - Should be a scalar. With float32 or float16 data type.
- **sign_decay** (Union[Number, Tensor]) - Should be a scalar. With float32 or float16 data type.
- **beta** (Union[Number, Tensor]) - The exponential decay rate, should be a scalar.
With float32 or float16 data type.
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.

Outputs:
Tuple of 2 Tensor, the updated parameters.

- **var** (Tensor) - The same shape and data type as `var`.
- **m** (Tensor) - The same shape and data type as `m`.

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.apply_add_sign = P.ApplyAddSign()
>>> self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>> self.m = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="m")
>>> self.lr = 0.001
>>> self.alpha = 1.0
>>> self.sign_decay = 0.99
>>> self.beta = 0.9
>>> def construct(self, grad):
>>> out = self.apply_add_sign(self.var, self.m, self.lr, self.alpha, self.sign_decay, self.beta, grad)
>>> return out
>>> net = Net()
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> output = net(grad)
"""

__mindspore_signature__ = (
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
('alpha', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2),
('sign_decay', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
sig_dtype.T3),
('beta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T4),
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)

@prim_attr_register
def __init__(self):
"init ApplyAddSign"

def infer_shape(self, var_shape, m_shape, lr_shape, alpha_shape, sign_decay_shape, beta_shape, grad_shape):
validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name)
validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name)
lr_shape_len = len(lr_shape)
validator.check_integer("lr's rank", lr_shape_len, 1, Rel.LE, self.name)
if lr_shape_len == 1:
validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name)
alpha_shape_len = len(alpha_shape)
validator.check_integer("alpha's rank", alpha_shape_len, 1, Rel.LE, self.name)
if alpha_shape_len == 1:
validator.check_integer("alpha_shape[0]", alpha_shape[0], 1, Rel.EQ, self.name)
sign_decay_shape_len = len(sign_decay_shape)
validator.check_integer("sign_decay's rank", sign_decay_shape_len, 1, Rel.LE, self.name)
if sign_decay_shape_len == 1:
validator.check_integer("sign_decay_shape[0]", sign_decay_shape[0], 1, Rel.EQ, self.name)
beta_shape_len = len(beta_shape)
validator.check_integer("beta's rank", beta_shape_len, 1, Rel.LE, self.name)
if beta_shape_len == 1:
validator.check_integer("beta_shape[0]", beta_shape[0], 1, Rel.EQ, self.name)
return var_shape, m_shape

def infer_dtype(self, var_dtype, m_dtype, lr_dtype, alpha_dtype, sign_decay_dtype, beta_dtype, grad_dtype):
valid_types = [mstype.float16, mstype.float32]
args = {'var': var_dtype, 'm': m_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"alpha": alpha_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"sign_decay": sign_decay_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"beta": beta_dtype}, valid_types, self.name)
return var_dtype, m_dtype
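# A minimal NumPy sketch (illustrative only, not the TBE kernel) of the AddSign update
# documented above; it assumes scalar lr/alpha/sign_decay/beta, same-shaped var/m/grad
# arrays, and the helper name is hypothetical.
def _add_sign_numpy_sketch(var, m, grad, lr, alpha, sign_decay, beta):
    import numpy as np  # local import keeps the sketch self-contained
    m_new = beta * m + (1.0 - beta) * grad                                 # update 1st moment
    update = (alpha + sign_decay * np.sign(grad) * np.sign(m_new)) * grad  # AddSign step
    return var - lr * update, m_new                                        # new var, new m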


class ApplyPowerSign(PrimitiveWithInfer):
r"""
Update relevant entries according to the PowerSign algorithm.

.. math::
\begin{array}{ll} \\
m_{t} = \beta * m_{t-1} + (1 - \beta) * g \\
\text{update} = \exp(\text{logbase} * \text{sign_decay} * sign(g) * sign(m)) * g \\
var = var - lr_{t} * \text{update}
\end{array}

:math:`t` represents the updating step, :math:`m` represents the 1st moment vector, :math:`m_{t-1}`
is the last moment of :math:`m_{t}`, :math:`lr` represents the scaling factor `lr`, and :math:`g` represents `grad`.

Inputs:
- **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type.
- **m** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
- **lr** (Union[Number, Tensor]) - The learning rate value, should be a scalar.
With float32 or float16 data type.
- **logbase** (Union[Number, Tensor]) - Should be a scalar. With float32 or float16 data type.
- **sign_decay** (Union[Number, Tensor]) - Should be a scalar. With float32 or float16 data type.
- **beta** (Union[Number, Tensor]) - The exponential decay rate, should be a scalar.
With float32 or float16 data type.
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.

Outputs:
Tuple of 2 Tensor, the updated parameters.

- **var** (Tensor) - The same shape and data type as `var`.
- **m** (Tensor) - The same shape and data type as `m`.

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.apply_power_sign = P.ApplyPowerSign()
>>> self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>> self.m = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="m")
>>> self.lr = 0.001
>>> self.logbase = np.e
>>> self.sign_decay = 0.99
>>> self.beta = 0.9
>>> def construct(self, grad):
>>> out = self.apply_power_sign(self.var, self.m, self.lr, self.logbase,
>>>                                         self.sign_decay, self.beta, grad)
>>> return out
>>> net = Net()
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> output = net(grad)
"""

__mindspore_signature__ = (
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
('logbase', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2),
('sign_decay', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
sig_dtype.T3),
('beta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T4),
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)

@prim_attr_register
def __init__(self):
"init ApplyPowerSign"

def infer_shape(self, var_shape, m_shape, lr_shape, logbase_shape, sign_decay_shape, beta_shape, grad_shape):
validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name)
validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name)
lr_shape_len = len(lr_shape)
validator.check_integer("lr's rank", lr_shape_len, 1, Rel.LE, self.name)
if lr_shape_len == 1:
validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name)
logbase_shape_len = len(logbase_shape)
validator.check_integer("logbase's rank", logbase_shape_len, 1, Rel.LE, self.name)
if logbase_shape_len == 1:
validator.check_integer("logbase_shape[0]", logbase_shape[0], 1, Rel.EQ, self.name)
sign_decay_shape_len = len(sign_decay_shape)
validator.check_integer("sign_decay's rank", sign_decay_shape_len, 1, Rel.LE, self.name)
if sign_decay_shape_len == 1:
validator.check_integer("sign_decay_shape[0]", sign_decay_shape[0], 1, Rel.EQ, self.name)
beta_shape_len = len(beta_shape)
validator.check_integer("beta's rank", beta_shape_len, 1, Rel.LE, self.name)
if beta_shape_len == 1:
validator.check_integer("beta_shape[0]", beta_shape[0], 1, Rel.EQ, self.name)
return var_shape, m_shape

def infer_dtype(self, var_dtype, m_dtype, lr_dtype, logbase_dtype, sign_decay_dtype, beta_dtype, grad_dtype):
valid_types = [mstype.float16, mstype.float32]
args = {'var': var_dtype, 'm': m_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"logbase": logbase_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"sign_decay": sign_decay_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"beta": beta_dtype}, valid_types, self.name)
return var_dtype, m_dtype
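# A minimal NumPy sketch (illustrative only, not the TBE kernel) of the PowerSign update
# documented above; it assumes scalar lr/logbase/sign_decay/beta, same-shaped var/m/grad
# arrays, and the helper name is hypothetical.
def _power_sign_numpy_sketch(var, m, grad, lr, logbase, sign_decay, beta):
    import numpy as np  # local import keeps the sketch self-contained
    m_new = beta * m + (1.0 - beta) * grad                                          # update 1st moment
    update = np.exp(logbase * sign_decay * np.sign(grad) * np.sign(m_new)) * grad   # PowerSign step
    return var - lr * update, m_new                                                 # new var, new m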


class ApplyGradientDescent(PrimitiveWithInfer):
r"""
Update relevant entries according to the following formula.

.. math::
var = var - \alpha * \delta

Inputs:
- **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type.
- **alpha** (Union[Number, Tensor]) - Scaling factor, should be a scalar. With float32 or float16 data type.
- **delta** (Tensor) - A tensor for the change. Has the same type as `var`.

Outputs:
Tensor, representing the updated var.

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.apply_gradient_descent = P.ApplyGradientDescent()
>>> self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>> self.alpha = 0.001
>>> def construct(self, delta):
>>> out = self.apply_gradient_descent(self.var, self.alpha, delta)
>>> return out
>>> net = Net()
>>> delta = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> output = net(delta)
"""

__mindspore_signature__ = (
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('alpha', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
('delta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)

@prim_attr_register
def __init__(self):
"init ApplyGradientDescent"

def infer_shape(self, var_shape, alpha_shape, delta_shape):
validator.check('delta shape', delta_shape, 'var shape', var_shape, Rel.EQ, self.name)
alpha_shape_len = len(alpha_shape)
validator.check_integer("alpha's rank", alpha_shape_len, 1, Rel.LE, self.name)
if alpha_shape_len == 1:
validator.check_integer("alpha_shape[0]", alpha_shape[0], 1, Rel.EQ, self.name)
return var_shape

def infer_dtype(self, var_dtype, alpha_dtype, delta_dtype):
valid_types = [mstype.float16, mstype.float32]
args = {'var': var_dtype, 'delta': delta_dtype}
validator.check_tensor_type_same(args, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"alpha": alpha_dtype}, valid_types, self.name)
return var_dtype
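# A minimal NumPy-style sketch (illustrative only, not the TBE kernel) of the update
# documented above, var = var - alpha * delta; alpha is assumed to be a scalar and the
# helper name is hypothetical.
def _gradient_descent_numpy_sketch(var, alpha, delta):
    return var - alpha * delta  # updated var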


class ApplyProximalGradientDescent(PrimitiveWithInfer):
r"""
Update relevant entries according to the FOBOS (Forward Backward Splitting) algorithm.

.. math::
\text{prox_v} = var - \alpha * \delta
.. math::
var = \frac{sign(\text{prox_v})}{1 + \alpha * l2} * \max(\left| \text{prox_v} \right| - \alpha * l1, 0)

Inputs:
- **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type.
- **alpha** (Union[Number, Tensor]) - Scaling factor, should be a scalar. With float32 or float16 data type.
- **l1** (Union[Number, Tensor]) - l1 regularization strength, should be scalar.
With float32 or float16 data type.
- **l2** (Union[Number, Tensor]) - l2 regularization strength, should be scalar.
With float32 or float16 data type.
- **delta** (Tensor) - A tensor for the change. Has the same type as `var`.

Outputs:
Tensor, representing the updated var.

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.apply_proximal_gradient_descent = P.ApplyProximalGradientDescent()
>>> self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>> self.alpha = 0.001
>>> self.l1 = 0.0
>>> self.l2 = 0.0
>>> def construct(self, delta):
>>> out = self.apply_proximal_gradient_descent(self.var, self.alpha, self.l1, self.l2, delta)
>>> return out
>>> net = Net()
>>> delta = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> output = net(delta)
"""

__mindspore_signature__ = (
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('alpha', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2),
('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T3),
('delta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)

@prim_attr_register
def __init__(self):
"init ApplyGradientDescent"

def infer_shape(self, var_shape, alpha_shape, l1_shape, l2_shape, delta_shape):
validator.check('delta shape', delta_shape, 'var shape', var_shape, Rel.EQ, self.name)
alpha_shape_len = len(alpha_shape)
validator.check_integer("alpha's rank", alpha_shape_len, 1, Rel.LE, self.name)
if alpha_shape_len == 1:
validator.check_integer("alpha_shape[0]", alpha_shape[0], 1, Rel.EQ, self.name)
l1_shape_len = len(l1_shape)
validator.check_integer("l1's rank", l1_shape_len, 1, Rel.LE, self.name)
if l1_shape_len == 1:
validator.check_integer("l1_shape[0]", l1_shape[0], 1, Rel.EQ, self.name)
l2_shape_len = len(l2_shape)
validator.check_integer("l2's rank", l2_shape_len, 1, Rel.LE, self.name)
if l2_shape_len == 1:
validator.check_integer("l2_shape[0]", l2_shape[0], 1, Rel.EQ, self.name)
return var_shape

def infer_dtype(self, var_dtype, alpha_dtype, l1_dtype, l2_dtype, delta_dtype):
valid_types = [mstype.float16, mstype.float32]
args = {'var': var_dtype, 'delta': delta_dtype}
validator.check_tensor_type_same(args, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"alpha": alpha_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, valid_types, self.name)
return var_dtype
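# A minimal NumPy sketch (illustrative only, not the TBE kernel) of the FOBOS update
# documented above; alpha/l1/l2 are assumed scalars, delta matches var's shape, and the
# helper name is hypothetical.
def _proximal_gradient_descent_numpy_sketch(var, alpha, l1, l2, delta):
    import numpy as np  # local import keeps the sketch self-contained
    prox_v = var - alpha * delta                            # unregularized step
    shrunk = np.maximum(np.abs(prox_v) - alpha * l1, 0.0)   # soft-threshold with l1
    return np.sign(prox_v) * shrunk / (1.0 + alpha * l2)    # scale down by the l2 term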


class LARSUpdate(PrimitiveWithInfer):
"""
Conduct lars (layer-wise adaptive rate scaling) update on the square sum of gradient.


tests/ut/python/ops/test_ops.py (+74, -0)

@@ -351,6 +351,64 @@ class ApplyAdagradV2Net(nn.Cell):
return out


class ApplyAddSignNet(nn.Cell):
def __init__(self):
super(ApplyAddSignNet, self).__init__()
self.apply_add_sign = P.ApplyAddSign()
self.lr = 0.001
self.alpha = 1.0
self.sign_decay = 0.99
self.beta = 0.99
self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
self.m = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="m")

def construct(self, grad):
out = self.apply_add_sign(self.var, self.m, self.lr, self.alpha, self.sign_decay, self.beta, grad)
return out


class ApplyPowerSignNet(nn.Cell):
def __init__(self):
super(ApplyPowerSignNet, self).__init__()
self.apply_power_sign = P.ApplyPowerSign()
self.lr = 0.001
self.logbase = np.e
self.sign_decay = 0.99
self.beta = 0.99
self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
self.m = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="m")

def construct(self, grad):
out = self.apply_power_sign(self.var, self.m, self.lr, self.logbase, self.sign_decay, self.beta, grad)
return out


class ApplyGradientDescentNet(nn.Cell):
def __init__(self):
super(ApplyGradientDescentNet, self).__init__()
self.apply_gradient_descent = P.ApplyGradientDescent()
self.alpha = 0.001
self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")

def construct(self, delta):
out = self.apply_gradient_descent(self.var, self.alpha, delta)
return out


class ApplyProximalGradientDescentNet(nn.Cell):
def __init__(self):
super(ApplyProximalGradientDescentNet, self).__init__()
self.apply_proximal_gradient_descent = P.ApplyProximalGradientDescent()
self.alpha = 0.001
self.l1 = 0.0
self.l2 = 0.0
self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")

def construct(self, delta):
out = self.apply_proximal_gradient_descent(self.var, self.alpha, self.l1, self.l2, delta)
return out


class SparseApplyAdagradNet(nn.Cell):
def __init__(self):
super(SparseApplyAdagradNet, self).__init__()
@@ -1241,6 +1299,22 @@ test_case_nn_ops = [
'block': ApplyAdagradV2Net(),
'desc_inputs': [[3, 3]],
'skip': ['backward']}),
('ApplyAddSign', {
'block': ApplyAddSignNet(),
'desc_inputs': [[3, 3]],
'skip': ['backward']}),
('ApplyPowerSign', {
'block': ApplyPowerSignNet(),
'desc_inputs': [[3, 3]],
'skip': ['backward']}),
('ApplyGradientDescent', {
'block': ApplyGradientDescentNet(),
'desc_inputs': [[3, 3]],
'skip': ['backward']}),
('ApplyProximalGradientDescent', {
'block': ApplyProximalGradientDescentNet(),
'desc_inputs': [[3, 3]],
'skip': ['backward']}),
('Flatten_1', {
'block': NetForFlatten(),
'desc_inputs': [Tensor(np.ones([2, 3, 4]).astype(np.int32)), Tensor(np.ones([2, 12]).astype(np.int32))],

