diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
index cbc31415ec..c38f48763e 100644
--- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
+++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
@@ -70,11 +70,13 @@ static std::map<string, string> tbe_func_adapter_map = {
   {"strided_slice", "strided_slice_d"},
   {"strided_slice_grad", "strided_slice_grad_d"},
   {"sparse_apply_ftrl", "sparse_apply_ftrl_d"},
+  {"sparse_apply_ftrl_v2", "sparse_apply_ftrl_v2_d"},
   {"apply_ada_max", "apply_ada_max_d"},
   {"apply_adadelta", "apply_adadelta_d"},
   {"apply_adagrad", "apply_adagrad_d"},
   {"apply_adagrad_v2", "apply_adagradv2_d"},
   {"sparse_apply_adagrad", "sparse_apply_adagrad_d"},
+  {"sparse_apply_adagrad_v2", "sparse_apply_adagrad_v2_d"},
   {"apply_proximal_adagrad", "apply_proximal_adagrad_d"},
   {"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"},
   {"apply_add_sign", "apply_add_sign_d"},
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index be3d313554..fa2be6d515 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -38,6 +38,8 @@ from .apply_add_sign import _apply_add_sign_tbe
 from .apply_power_sign import _apply_power_sign_tbe
 from .apply_gradient_descent import _apply_gradient_descent_tbe
 from .apply_proximal_gradient_descent import _apply_proximal_gradient_descent_tbe
+from .sparse_apply_ftrl_v2 import _sparse_apply_ftrl_v2_tbe
+from .sparse_apply_adagrad_v2 import _sparse_apply_adagrad_v2_tbe
 from .approximate_equal import _approximate_equal_tbe
 from .adam_apply_one import _adam_apply_one_tbe
 from .assign import _assign_tbe
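The two new map entries route the snake_case primitive names to their "_d" TBE kernel names. A minimal Python sketch of the lookup that tbe_adapter.cc performs; the dict name and the fallback rule here are illustrative stand-ins for the C++ implementation, not its real API:

# Illustrative stand-in for the C++ tbe_func_adapter_map in tbe_adapter.cc; only the
# entries added by this change are shown, and the fallback behaviour is an assumption.
TBE_FUNC_ADAPTER_MAP = {
    "sparse_apply_ftrl_v2": "sparse_apply_ftrl_v2_d",
    "sparse_apply_adagrad_v2": "sparse_apply_adagrad_v2_d",
}


def adapt_tbe_func_name(op_name):
    """Map a MindSpore op name to its TBE kernel name, keeping the name unchanged if unmapped."""
    return TBE_FUNC_ADAPTER_MAP.get(op_name, op_name)


assert adapt_tbe_func_name("sparse_apply_adagrad_v2") == "sparse_apply_adagrad_v2_d"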
diff --git a/mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py b/mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py
new file mode 100644
index 0000000000..088edb60d3
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py
@@ -0,0 +1,48 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""SparseApplyAdagradV2D op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+sparse_apply_adagrad_v2_d_op_info = TBERegOp("SparseApplyAdagradV2") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("sparse_apply_adagrad_v2_d.so") \
+    .compute_cost(10) \
+    .kernel_name("sparse_apply_adagrad_v2_d") \
+    .partial_flag(True) \
+    .attr("lr", "required", "float", "all") \
+    .attr("epsilon", "required", "float", "all") \
+    .attr("use_locking", "optional", "bool", "all") \
+    .attr("update_slots", "optional", "bool", "all") \
+    .input(0, "var", False, "required", "all") \
+    .input(1, "accum", False, "required", "all") \
+    .input(2, "grad", False, "required", "all") \
+    .input(3, "indices", False, "required", "all") \
+    .output(0, "var", False, "required", "all") \
+    .output(1, "accum", False, "required", "all") \
+    .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW,
+                  DataType.F32_NCHW, DataType.F32_NCHW) \
+    .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC,
+                  DataType.F32_NHWC, DataType.F32_NHWC) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default,
+                  DataType.F32_Default, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(sparse_apply_adagrad_v2_d_op_info)
+def _sparse_apply_adagrad_v2_tbe():
+    """SparseApplyAdagradV2D TBE register"""
+    return
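As a reading aid for the registration above, here is a NumPy sketch of the row-wise update the `sparse_apply_adagrad_v2_d` kernel is expected to compute, following the formulas documented for `SparseApplyAdagradV2` further down in this change. Treating `update_slots=False` as "skip the accumulator update" is an assumption borrowed from the identically named TensorFlow attribute; this is not a model of the actual TBE implementation.

import numpy as np


def sparse_apply_adagrad_v2_reference(var, accum, grad, indices, lr, epsilon, update_slots=True):
    """Row-wise Adagrad-V2 update on the rows selected by `indices` (in-place, NumPy sketch)."""
    for i, row in enumerate(indices):
        if update_slots:
            accum[row] += grad[i] * grad[i]
        var[row] -= lr * grad[i] / (np.sqrt(accum[row]) + epsilon)
    return var, accum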
diff --git a/mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py b/mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py
new file mode 100644
index 0000000000..518c524010
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py
@@ -0,0 +1,52 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""SparseApplyFtrlV2D op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+sparse_apply_ftrl_v2_d_op_info = TBERegOp("SparseApplyFtrlV2") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("sparse_apply_ftrl_v2_d.so") \
+    .compute_cost(10) \
+    .kernel_name("sparse_apply_ftrl_v2_d") \
+    .partial_flag(True) \
+    .attr("lr", "required", "float", "all") \
+    .attr("l1", "required", "float", "all") \
+    .attr("l2", "required", "float", "all") \
+    .attr("l2_shrinkage", "required", "float", "all") \
+    .attr("lr_power", "required", "float", "all") \
+    .attr("use_locking", "optional", "bool", "true,false", "false") \
+    .input(0, "var", False, "required", "all") \
+    .input(1, "accum", False, "required", "all") \
+    .input(2, "linear", False, "required", "all") \
+    .input(3, "grad", False, "required", "all") \
+    .input(4, "indices", False, "required", "all") \
+    .output(0, "var", False, "required", "all") \
+    .output(1, "accum", False, "required", "all") \
+    .output(2, "linear", False, "required", "all") \
+    .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
+                  DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
+    .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
+                  DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
+                  DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(sparse_apply_ftrl_v2_d_op_info)
+def _sparse_apply_ftrl_v2_tbe():
+    """SparseApplyFtrlV2D TBE register"""
+    return
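Likewise, a NumPy sketch of the per-row update behind `sparse_apply_ftrl_v2_d`. The formulas follow the TensorFlow FtrlV2 pseudo-code, where the `l2_shrinkage` term only enters the gradient used to update `linear`; that correspondence is an assumption, since the docstring added below does not spell the update out.

import numpy as np


def sparse_apply_ftrl_v2_reference(var, accum, linear, grad, indices,
                                   lr, l1, l2, l2_shrinkage, lr_power):
    """Row-wise FTRL-proximal update with L2 shrinkage (in-place, NumPy sketch)."""
    for i, row in enumerate(indices):
        g = grad[i]
        g_shrink = g + 2.0 * l2_shrinkage * var[row]
        accum_new = accum[row] + g * g
        linear[row] += g_shrink - (accum_new ** -lr_power - accum[row] ** -lr_power) / lr * var[row]
        quadratic = accum_new ** -lr_power / lr + 2.0 * l2
        var[row] = np.where(np.abs(linear[row]) > l1,
                            (np.sign(linear[row]) * l1 - linear[row]) / quadratic, 0.0)
        accum[row] = accum_new
    return var, accum, linear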
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index bc4edce193..06a19d2db7 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -72,7 +72,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
                      SoftmaxCrossEntropyWithLogits, ROIAlign,
                      SparseSoftmaxCrossEntropyWithLogits, Tanh,
                      TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
-                     ApplyProximalAdagrad, SparseApplyProximalAdagrad,
+                     ApplyProximalAdagrad, SparseApplyProximalAdagrad, SparseApplyAdagradV2, SparseApplyFtrlV2,
                      ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
                      ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
                      ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
@@ -284,6 +284,7 @@ __all__ = [
     "Abs",
     "BinaryCrossEntropy",
     "SparseApplyAdagrad",
+    "SparseApplyAdagradV2",
     "SpaceToDepth",
     "DepthToSpace",
     "Conv2DBackpropInput",
@@ -294,6 +295,7 @@ __all__ = [
     "ApplyFtrl",
     "SpaceToBatch",
     "SparseApplyFtrl",
+    "SparseApplyFtrlV2",
     "ApplyProximalAdagrad",
     "SparseApplyProximalAdagrad",
     "ApplyAdaMax",
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index c07f072f38..6320e9e011 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -3600,6 +3600,88 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
         return var_type, accum_type


+class SparseApplyAdagradV2(PrimitiveWithInfer):
+    r"""
+    Update relevant entries according to the adagrad scheme.
+
+    .. math::
+            accum += grad * grad
+    .. math::
+            var -= lr * grad * \frac{1}{\sqrt{accum} + \epsilon}
+
+    Args:
+        lr (float): Learning rate.
+        epsilon (float): A small value added for numerical stability.
+        use_locking (bool): If `True`, updating of the var and accum tensors will be protected. Default: False.
+        update_slots (bool): If `True`, `accum` will be updated; otherwise it is left unchanged. Default: True.
+
+    Inputs:
+        - **var** (Parameter) - Variable to be updated. The type must be float32.
+        - **accum** (Parameter) - Accum to be updated. The shape must be the same as `var`'s shape,
+          the type must be float32.
+        - **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape except the first dimension,
+          the type must be float32.
+        - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
+          The shape of `indices` must be the same as `grad` in the first dimension, the type must be int32.
+
+    Outputs:
+        Tuple of 2 Tensor, the updated parameters.
+
+        - **var** (Tensor) - The same shape and data type as `var`.
+        - **accum** (Tensor) - The same shape and data type as `accum`.
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor, Parameter
+        >>> from mindspore.ops import operations as P
+        >>> import mindspore.common.dtype as mstype
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.sparse_apply_adagrad_v2 = P.SparseApplyAdagradV2(lr=1e-8, epsilon=1e-6)
+        >>>         self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
+        >>>         self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum")
+        >>>
+        >>>     def construct(self, grad, indices):
+        >>>         out = self.sparse_apply_adagrad_v2(self.var, self.accum, grad, indices)
+        >>>         return out
+        >>> net = Net()
+        >>> grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
+        >>> indices = Tensor([0, 1, 2], mstype.int32)
+        >>> result = net(grad, indices)
+    """
+
+    __mindspore_signature__ = (
+        ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
+    )
+
+    @prim_attr_register
+    def __init__(self, lr, epsilon, use_locking=False, update_slots=True):
+        self.lr = validator.check_value_type("lr", lr, [float], self.name)
+        self.epsilon = validator.check_value_type("epsilon", epsilon, [float], self.name)
+        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
+        self.update_slots = validator.check_value_type("update_slots", update_slots, [bool], self.name)
+
+    def infer_shape(self, var_shape, accum_shape, grad_shape, indices_shape):
+        validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
+        validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name)
+        if len(var_shape) > 1:
+            validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
+        validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
+        validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
+        return var_shape, accum_shape
+
+    def infer_dtype(self, var_type, accum_type, grad_type, indices_type):
+        args = {'var': var_type, 'accum': accum_type, 'grad': grad_type}
+        validator.check_tensor_type_same(args, [mstype.float32], self.name)
+        validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name)
+        return var_type, accum_type
+
+
 class ApplyProximalAdagrad(PrimitiveWithInfer):
     r"""
     Update relevant entries according to the proximal adagrad algorithm.
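The `infer_shape` method of `SparseApplyAdagradV2` above encodes a simple contract: `var` and `accum` must have identical shapes, `indices` is a rank-1 int32 tensor, and `grad` must match `var` on every axis but the first, where it must match `len(indices)`. A small illustration with concrete, made-up shapes:

import numpy as np

var = np.ones((3, 3, 3), np.float32)
accum = np.ones((3, 3, 3), np.float32)          # same shape as var
indices = np.array([0, 2], np.int32)            # rank 1, int32
grad = np.ones((len(indices),) + var.shape[1:], np.float32)   # shape (2, 3, 3)

assert var.shape == accum.shape
assert grad.shape[1:] == var.shape[1:]
assert grad.shape[0] == indices.shape[0]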
@@ -3664,7 +3746,8 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):

     @prim_attr_register
     def __init__(self, use_locking=False):
-        self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'], outputs=['output'])
+        self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'],
+                                outputs=['var', 'accum'])
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)

     def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape):
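The hunk above only changes the declared I/O names: `ApplyProximalAdagrad` now advertises two outputs, `var` and `accum`, instead of a single `output`, which matches the tuple its documentation describes. A sketch of how the result is consumed after this change (the wrapper class and the way `lr`, `l1`, `l2` are fed in are illustrative choices, mirroring the nets in test_ops.py below):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P


class ProximalAdagradNet(nn.Cell):
    """Illustrative wrapper around ApplyProximalAdagrad."""
    def __init__(self):
        super(ProximalAdagradNet, self).__init__()
        self.apply_proximal_adagrad = P.ApplyProximalAdagrad()
        self.var = Parameter(Tensor(np.ones((3, 3), np.float32)), name="var")
        self.accum = Parameter(Tensor(np.ones((3, 3), np.float32)), name="accum")

    def construct(self, lr, l1, l2, grad):
        # The op now returns the (var, accum) pair declared by init_prim_io_names above.
        return self.apply_proximal_adagrad(self.var, self.accum, lr, l1, l2, grad)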
@@ -4371,6 +4454,98 @@ class SparseApplyFtrl(PrimitiveWithInfer):
         return var_dtype, accum_dtype, linear_dtype


+class SparseApplyFtrlV2(PrimitiveWithInfer):
+    """
+    Update relevant entries according to the FTRL-proximal scheme.
+
+    Args:
+        lr (float): The learning rate value, must be positive.
+        l1 (float): l1 regularization strength, must be greater than or equal to zero.
+        l2 (float): l2 regularization strength, must be greater than or equal to zero.
+        l2_shrinkage (float): L2 shrinkage regularization.
+        lr_power (float): Learning rate power controls how the learning rate decreases during training,
+            must be less than or equal to zero. Use fixed learning rate if `lr_power` is zero.
+        use_locking (bool): If `True`, updating of the var and accum tensors will be protected. Default: False.
+
+    Inputs:
+        - **var** (Parameter) - The variable to be updated. The data type must be float32.
+        - **accum** (Parameter) - The accum to be updated, must be the same type and shape as `var`.
+        - **linear** (Parameter) - The linear to be updated, must be the same type and shape as `var`.
+        - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.
+        - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
+          The shape of `indices` must be the same as `grad` in the first dimension. The type must be int32.
+
+    Outputs:
+        Tuple of 3 Tensor, the updated parameters.
+
+        - **var** (Tensor) - Tensor, has the same shape and type as `var`.
+        - **accum** (Tensor) - Tensor, has the same shape and type as `accum`.
+        - **linear** (Tensor) - Tensor, has the same shape and type as `linear`.
+
+    Examples:
+        >>> import mindspore
+        >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> from mindspore import Parameter
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import operations as P
+        >>> class SparseApplyFtrlV2Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(SparseApplyFtrlV2Net, self).__init__()
+        >>>         self.sparse_apply_ftrl_v2 = P.SparseApplyFtrlV2(lr=0.01, l1=0.0, l2=0.0,
+        >>>                                                         l2_shrinkage=0.0, lr_power=-0.5)
+        >>>         self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
+        >>>         self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
+        >>>         self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear")
+        >>>
+        >>>     def construct(self, grad, indices):
+        >>>         out = self.sparse_apply_ftrl_v2(self.var, self.accum, self.linear, grad, indices)
+        >>>         return out
+        >>>
+        >>> net = SparseApplyFtrlV2Net()
+        >>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
+        >>> indices = Tensor(np.ones([3]), mindspore.int32)
+        >>> output = net(grad, indices)
+    """
+
+    __mindspore_signature__ = (
+        ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
+    )
+
+    @prim_attr_register
+    def __init__(self, lr, l1, l2, l2_shrinkage, lr_power, use_locking=False):
+        validator.check_value_type("lr", lr, [float], self.name)
+        validator.check_value_type("l1", l1, [float], self.name)
+        validator.check_value_type("l2", l2, [float], self.name)
+        validator.check_value_type("lr_power", lr_power, [float], self.name)
+        self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
+        self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
+        self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
+        self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
+        self.l2_shrinkage = validator.check_value_type("l2_shrinkage", l2_shrinkage, [float], self.name)
+        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
+
+    def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
+        validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
+        validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
+        if len(var_shape) > 1:
+            validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
+        validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
+        validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
+        return var_shape, accum_shape, linear_shape
+
+    def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
+        args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
+                "linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
+        validator.check_tensor_type_same(args, [mstype.float32], self.name)
+        validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
+        return var_dtype, accum_dtype, linear_dtype
+
+
 class ConfusionMulGrad(PrimitiveWithInfer):
     """
     `output0` is the result of which input0 dot multily input1.
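The primitive above returns a `(var, accum, linear)` triple. A sketch of unpacking it, with shapes that satisfy the `infer_shape` checks; the wrapper name is illustrative, and actually running it requires an Ascend backend where the TBE kernel registered earlier is available:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P


class FtrlV2Net(nn.Cell):
    """Thin wrapper so the three outputs can be unpacked; mirrors the docstring example."""
    def __init__(self):
        super(FtrlV2Net, self).__init__()
        self.op = P.SparseApplyFtrlV2(lr=0.01, l1=0.0, l2=0.0, l2_shrinkage=0.0, lr_power=-0.5)
        self.var = Parameter(Tensor(np.ones((3, 3), np.float32)), name="var")
        self.accum = Parameter(Tensor(np.ones((3, 3), np.float32)), name="accum")
        self.linear = Parameter(Tensor(np.ones((3, 3), np.float32)), name="linear")

    def construct(self, grad, indices):
        return self.op(self.var, self.accum, self.linear, grad, indices)


grad = Tensor(np.ones((2, 3), np.float32))       # first dimension equals len(indices)
indices = Tensor(np.array([0, 2], np.int32))     # rank-1 int32
var_out, accum_out, linear_out = FtrlV2Net()(grad, indices)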
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 98a7b766e7..f55d42e28b 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -306,6 +306,19 @@ class SparseApplyFtrlNet(nn.Cell):
         return out


+class SparseApplyFtrlV2Net(nn.Cell):
+    def __init__(self):
+        super(SparseApplyFtrlV2Net, self).__init__()
+        self.sparse_apply_ftrl_v2 = P.SparseApplyFtrlV2(lr=0.001, l1=0.0, l2=0.0, l2_shrinkage=0.0, lr_power=-0.5)
+        self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
+        self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
+        self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear")
+
+    def construct(self, grad, indices):
+        out = self.sparse_apply_ftrl_v2(self.var, self.accum, self.linear, grad, indices)
+        return out
+
+
 class SparseApplyProximalAdagradNet(nn.Cell):
     def __init__(self):
         super(SparseApplyProximalAdagradNet, self).__init__()
@@ -467,6 +480,18 @@ class SparseApplyAdagradNet(nn.Cell):
         return out


+class SparseApplyAdagradV2Net(nn.Cell):
+    def __init__(self):
+        super(SparseApplyAdagradV2Net, self).__init__()
+        self.sparse_apply_adagrad_v2 = P.SparseApplyAdagradV2(lr=0.01, epsilon=0.001)
+        self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
+        self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
+
+    def construct(self, grad, indices):
+        out = self.sparse_apply_adagrad_v2(self.var, self.accum, grad, indices)
+        return out
+
+
 class ApplyRMSNet(nn.Cell):
     def __init__(self):
         super(ApplyRMSNet, self).__init__()
@@ -1376,10 +1401,18 @@ test_case_nn_ops = [
         'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))],
         'desc_bprop': [[3, 3], [3, 3]],
         'skip': ['backward']}),
+    ('SparseApplyAdagradV2', {
+        'block': SparseApplyAdagradV2Net(),
+        'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))],
+        'skip': ['backward']}),
     ('SparseApplyFtrl', {
         'block': SparseApplyFtrlNet(),
         'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))],
         'skip': ['backward']}),
+    ('SparseApplyFtrlV2', {
+        'block': SparseApplyFtrlV2Net(),
+        'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))],
+        'skip': ['backward']}),
     ('ApplyProximalAdagrad', {
         'block': ApplyProximalAdagradNet(),
         'desc_inputs': [[3, 3]],