| @@ -70,11 +70,13 @@ static std::map<string, string> tbe_func_adapter_map = { | |||||
| {"strided_slice", "strided_slice_d"}, | {"strided_slice", "strided_slice_d"}, | ||||
| {"strided_slice_grad", "strided_slice_grad_d"}, | {"strided_slice_grad", "strided_slice_grad_d"}, | ||||
| {"sparse_apply_ftrl", "sparse_apply_ftrl_d"}, | {"sparse_apply_ftrl", "sparse_apply_ftrl_d"}, | ||||
| {"sparse_apply_ftrl_v2", "sparse_apply_ftrl_v2_d"}, | |||||
| {"apply_ada_max", "apply_ada_max_d"}, | {"apply_ada_max", "apply_ada_max_d"}, | ||||
| {"apply_adadelta", "apply_adadelta_d"}, | {"apply_adadelta", "apply_adadelta_d"}, | ||||
| {"apply_adagrad", "apply_adagrad_d"}, | {"apply_adagrad", "apply_adagrad_d"}, | ||||
| {"apply_adagrad_v2", "apply_adagradv2_d"}, | {"apply_adagrad_v2", "apply_adagradv2_d"}, | ||||
| {"sparse_apply_adagrad", "sparse_apply_adagrad_d"}, | {"sparse_apply_adagrad", "sparse_apply_adagrad_d"}, | ||||
| {"sparse_apply_adagrad_v2", "sparse_apply_adagrad_v2_d"}, | |||||
| {"apply_proximal_adagrad", "apply_proximal_adagrad_d"}, | {"apply_proximal_adagrad", "apply_proximal_adagrad_d"}, | ||||
| {"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"}, | {"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"}, | ||||
| {"apply_add_sign", "apply_add_sign_d"}, | {"apply_add_sign", "apply_add_sign_d"}, | ||||
| @@ -38,6 +38,8 @@ from .apply_add_sign import _apply_add_sign_tbe | |||||
| from .apply_power_sign import _apply_power_sign_tbe | from .apply_power_sign import _apply_power_sign_tbe | ||||
| from .apply_gradient_descent import _apply_gradient_descent_tbe | from .apply_gradient_descent import _apply_gradient_descent_tbe | ||||
| from .apply_proximal_gradient_descent import _apply_proximal_gradient_descent_tbe | from .apply_proximal_gradient_descent import _apply_proximal_gradient_descent_tbe | ||||
| from .sparse_apply_ftrl_v2 import _sparse_apply_ftrl_v2_tbe | |||||
| from .sparse_apply_adagrad_v2 import _sparse_apply_adagrad_v2_tbe | |||||
| from .approximate_equal import _approximate_equal_tbe | from .approximate_equal import _approximate_equal_tbe | ||||
| from .adam_apply_one import _adam_apply_one_tbe | from .adam_apply_one import _adam_apply_one_tbe | ||||
| from .assign import _assign_tbe | from .assign import _assign_tbe | ||||
| @@ -0,0 +1,48 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """SparseApplyAdagradV2D op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| sparse_apply_adagrad_v2_d_op_info = TBERegOp("SparseApplyAdagradV2") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("sparse_apply_adagrad_v2_d.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("sparse_apply_adagrad_v2_d") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("lr", "required", "float", "all") \ | |||||
| .attr("epsilon", "required", "float", "all") \ | |||||
| .attr("use_locking", "optional", "bool", "all") \ | |||||
| .attr("update_slots", "optional", "bool", "all") \ | |||||
| .input(0, "var", False, "required", "all") \ | |||||
| .input(1, "accum", False, "required", "all") \ | |||||
| .input(2, "grad", False, "required", "all") \ | |||||
| .input(3, "indices", False, "required", "all") \ | |||||
| .output(0, "var", False, "required", "all") \ | |||||
| .output(1, "accum", False, "required", "all") \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, | |||||
| DataType.F32_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, | |||||
| DataType.F32_NHWC, DataType.F32_NHWC) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(sparse_apply_adagrad_v2_d_op_info) | |||||
| def _sparse_apply_adagrad_v2_tbe(): | |||||
| """SparseApplyAdagradV2D TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,52 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """SparseApplyFtrlV2D op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| sparse_apply_ftrl_v2_d_op_info = TBERegOp("SparseApplyFtrlV2") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("sparse_apply_ftrl_v2_d.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("sparse_apply_ftrl_v2_d") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("lr", "required", "float", "all") \ | |||||
| .attr("l1", "required", "float", "all") \ | |||||
| .attr("l2", "required", "float", "all") \ | |||||
| .attr("l2_shrinkage", "required", "float", "all") \ | |||||
| .attr("lr_power", "required", "float", "all") \ | |||||
| .attr("use_locking", "optional", "bool", "true,false", "false") \ | |||||
| .input(0, "var", False, "required", "all") \ | |||||
| .input(1, "accum", False, "required", "all") \ | |||||
| .input(2, "linear", False, "required", "all") \ | |||||
| .input(3, "grad", False, "required", "all") \ | |||||
| .input(4, "indices", False, "required", "all") \ | |||||
| .output(0, "var", False, "required", "all") \ | |||||
| .output(1, "accum", False, "required", "all") \ | |||||
| .output(2, "linear", False, "required", "all") \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||||
| DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | |||||
| DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(sparse_apply_ftrl_v2_d_op_info) | |||||
| def _sparse_apply_ftrl_v2_tbe(): | |||||
| """SparseApplyFtrlV2D TBE register""" | |||||
| return | |||||
| @@ -72,7 +72,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl | |||||
| SoftmaxCrossEntropyWithLogits, ROIAlign, | SoftmaxCrossEntropyWithLogits, ROIAlign, | ||||
| SparseSoftmaxCrossEntropyWithLogits, Tanh, | SparseSoftmaxCrossEntropyWithLogits, Tanh, | ||||
| TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, | TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, | ||||
| ApplyProximalAdagrad, SparseApplyProximalAdagrad, | |||||
| ApplyProximalAdagrad, SparseApplyProximalAdagrad, SparseApplyAdagradV2, SparseApplyFtrlV2, | |||||
| ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, | ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, | ||||
| ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent, | ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent, | ||||
| ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK) | ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK) | ||||
| @@ -284,6 +284,7 @@ __all__ = [ | |||||
| "Abs", | "Abs", | ||||
| "BinaryCrossEntropy", | "BinaryCrossEntropy", | ||||
| "SparseApplyAdagrad", | "SparseApplyAdagrad", | ||||
| "SparseApplyAdagradV2", | |||||
| "SpaceToDepth", | "SpaceToDepth", | ||||
| "DepthToSpace", | "DepthToSpace", | ||||
| "Conv2DBackpropInput", | "Conv2DBackpropInput", | ||||
| @@ -294,6 +295,7 @@ __all__ = [ | |||||
| "ApplyFtrl", | "ApplyFtrl", | ||||
| "SpaceToBatch", | "SpaceToBatch", | ||||
| "SparseApplyFtrl", | "SparseApplyFtrl", | ||||
| "SparseApplyFtrlV2", | |||||
| "ApplyProximalAdagrad", | "ApplyProximalAdagrad", | ||||
| "SparseApplyProximalAdagrad", | "SparseApplyProximalAdagrad", | ||||
| "ApplyAdaMax", | "ApplyAdaMax", | ||||
| @@ -3600,6 +3600,88 @@ class SparseApplyAdagrad(PrimitiveWithInfer): | |||||
| return var_type, accum_type | return var_type, accum_type | ||||
| class SparseApplyAdagradV2(PrimitiveWithInfer): | |||||
| r""" | |||||
| Update relevant entries according to the adagrad scheme. | |||||
| .. math:: | |||||
| accum += grad * grad | |||||
| .. math:: | |||||
| var -= lr * grad * \frac{1}{\sqrt{accum} + \epsilon} | |||||
| Args: | |||||
| lr (float): Learning rate. | |||||
| epsilon (float): A small value added for numerical stability. | |||||
| use_locking (bool): If `True`, updating of the var and accum tensors will be protected. Default: False. | |||||
| update_slots (bool): If `True`, the computation logic will be different to `False`. Default: True. | |||||
| Inputs: | |||||
| - **var** (Parameter) - Variable to be updated. The type must be float32. | |||||
| - **accum** (Parameter) - Accum to be updated. The shape must be the same as `var`'s shape, | |||||
| the type must be float32. | |||||
| - **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape except first dimension, | |||||
| the type must be float32. | |||||
| - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. | |||||
| The shape of `indices` must be the same as `grad` in first dimension, the type must be int32. | |||||
| Outputs: | |||||
| Tuple of 2 Tensor, the updated parameters. | |||||
| - **var** (Tensor) - The same shape and data type as `var`. | |||||
| - **accum** (Tensor) - The same shape and data type as `accum`. | |||||
| Examples: | |||||
| >>> import numpy as np | |||||
| >>> import mindspore.nn as nn | |||||
| >>> from mindspore import Tensor, Parameter | |||||
| >>> from mindspore.ops import operations as P | |||||
| >>> import mindspore.common.dtype as mstype | |||||
| >>> class Net(nn.Cell): | |||||
| >>> def __init__(self): | |||||
| >>> super(Net, self).__init__() | |||||
| >>> self.sparse_apply_adagrad_v2 = P.SparseApplyAdagradV2(lr=1e-8, epsilon=1e-6) | |||||
| >>> self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var") | |||||
| >>> self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum") | |||||
| >>> | |||||
| >>> def construct(self, grad, indices): | |||||
| >>> out = self.sparse_apply_adagrad_v2(self.var, self.accum, grad, indices) | |||||
| >>> return out | |||||
| >>> net = Net() | |||||
| >>> grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32)) | |||||
| >>> indices = Tensor([0, 1, 2], mstype.int32) | |||||
| >>> result = net(grad, indices) | |||||
| """ | |||||
| __mindspore_signature__ = ( | |||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) | |||||
| ) | |||||
| @prim_attr_register | |||||
| def __init__(self, lr, epsilon, use_locking=False, update_slots=True): | |||||
| self.lr = validator.check_value_type("lr", lr, [float], self.name) | |||||
| self.epsilon = validator.check_value_type("epsilon", epsilon, [float], self.name) | |||||
| self.use_locking = validator.check_value_type("update_slots", update_slots, [bool], self.name) | |||||
| self.update_slots = validator.check_value_type("use_locking", use_locking, [bool], self.name) | |||||
| def infer_shape(self, var_shape, accum_shape, grad_shape, indices_shape): | |||||
| validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) | |||||
| validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name) | |||||
| if len(var_shape) > 1: | |||||
| validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) | |||||
| validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) | |||||
| validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) | |||||
| return var_shape, accum_shape | |||||
| def infer_dtype(self, var_type, accum_type, grad_type, indices_type): | |||||
| args = {'var': var_type, 'accum': accum_type, 'grad': grad_type} | |||||
| validator.check_tensor_type_same(args, [mstype.float32], self.name) | |||||
| validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name) | |||||
| return var_type, accum_type | |||||
| class ApplyProximalAdagrad(PrimitiveWithInfer): | class ApplyProximalAdagrad(PrimitiveWithInfer): | ||||
| r""" | r""" | ||||
| Update relevant entries according to the proximal adagrad algorithm. | Update relevant entries according to the proximal adagrad algorithm. | ||||
| @@ -3664,7 +3746,8 @@ class ApplyProximalAdagrad(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, use_locking=False): | def __init__(self, use_locking=False): | ||||
| self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'], outputs=['output']) | |||||
| self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'], | |||||
| outputs=['var', 'accum']) | |||||
| self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | ||||
| def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape): | def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape): | ||||
| @@ -4371,6 +4454,98 @@ class SparseApplyFtrl(PrimitiveWithInfer): | |||||
| return var_dtype, accum_dtype, linear_dtype | return var_dtype, accum_dtype, linear_dtype | ||||
| class SparseApplyFtrlV2(PrimitiveWithInfer): | |||||
| """ | |||||
| Update relevant entries according to the FTRL-proximal scheme. | |||||
| Args: | |||||
| lr (float): The learning rate value, must be positive. | |||||
| l1 (float): l1 regularization strength, must be greater than or equal to zero. | |||||
| l2 (float): l2 regularization strength, must be greater than or equal to zero. | |||||
| l2_shrinkage (float): L2 shrinkage regularization. | |||||
| lr_power (float): Learning rate power controls how the learning rate decreases during training, | |||||
| must be less than or equal to zero. Use fixed learning rate if `lr_power` is zero. | |||||
| use_locking (bool): If `True`, updating of the var and accum tensors will be protected. Default: False. | |||||
| Inputs: | |||||
| - **var** (Parameter) - The variable to be updated. The data type must be float32. | |||||
| - **accum** (Parameter) - The accum to be updated, must be same type and shape as `var`. | |||||
| - **linear** (Parameter) - The linear to be updated, must be same type and shape as `var`. | |||||
| - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. | |||||
| - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. | |||||
| The shape of `indices` must be the same as `grad` in first dimension. The type must be int32. | |||||
| Outputs: | |||||
| Tuple of 3 Tensor, the updated parameters. | |||||
| - **var** (Tensor): Tensor, has the same shape and type as `var`. | |||||
| - **accum** (Tensor): Tensor, has the same shape and type as `accum`. | |||||
| - **linear** (Tensor): Tensor, has the same shape and type as `linear`. | |||||
| Examples: | |||||
| >>> import mindspore | |||||
| >>> import mindspore.nn as nn | |||||
| >>> import numpy as np | |||||
| >>> from mindspore import Parameter | |||||
| >>> from mindspore import Tensor | |||||
| >>> from mindspore.ops import operations as P | |||||
| >>> class SparseApplyFtrlV2Net(nn.Cell): | |||||
| >>> def __init__(self): | |||||
| >>> super(SparseApplyFtrlV2Net, self).__init__() | |||||
| >>> self.sparse_apply_ftrl_v2 = P.SparseApplyFtrlV2(lr=0.01, l1=0.0, l2=0.0, | |||||
| l2_shrinkage=0.0, lr_power=-0.5) | |||||
| >>> self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var") | |||||
| >>> self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum") | |||||
| >>> self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear") | |||||
| >>> | |||||
| >>> def construct(self, grad, indices): | |||||
| >>> out = self.sparse_apply_ftrl_v2(self.var, self.accum, self.linear, grad, indices) | |||||
| >>> return out | |||||
| >>> | |||||
| >>> net = SparseApplyFtrlV2Net() | |||||
| >>> grad = Tensor(np.random.rand(3, 3).astype(np.float32)) | |||||
| >>> indices = Tensor(np.ones([3]), mindspore.int32) | |||||
| >>> output = net(grad, indices) | |||||
| """ | |||||
| __mindspore_signature__ = ( | |||||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||||
| ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) | |||||
| ) | |||||
| @prim_attr_register | |||||
| def __init__(self, lr, l1, l2, l2_shrinkage, lr_power, use_locking=False): | |||||
| validator.check_value_type("lr", lr, [float], self.name) | |||||
| validator.check_value_type("l1", l1, [float], self.name) | |||||
| validator.check_value_type("l2", l2, [float], self.name) | |||||
| validator.check_value_type("lr_power", lr_power, [float], self.name) | |||||
| self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name) | |||||
| self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name) | |||||
| self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name) | |||||
| self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) | |||||
| self.l2_shrinkage = validator.check_value_type("l2_shrinkage", l2_shrinkage, [float], self.name) | |||||
| self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | |||||
| def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape): | |||||
| validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) | |||||
| validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name) | |||||
| if len(var_shape) > 1: | |||||
| validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) | |||||
| validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) | |||||
| validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) | |||||
| return var_shape, accum_shape, linear_shape | |||||
| def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype): | |||||
| args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype, | |||||
| "linear_dtype": linear_dtype, "grad_dtype": grad_dtype} | |||||
| validator.check_tensor_type_same(args, [mstype.float32], self.name) | |||||
| validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name) | |||||
| return var_dtype, accum_dtype, linear_dtype | |||||
| class ConfusionMulGrad(PrimitiveWithInfer): | class ConfusionMulGrad(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| `output0` is the result of which input0 dot multily input1. | `output0` is the result of which input0 dot multily input1. | ||||
| @@ -306,6 +306,19 @@ class SparseApplyFtrlNet(nn.Cell): | |||||
| return out | return out | ||||
| class SparseApplyFtrlV2Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(SparseApplyFtrlV2Net, self).__init__() | |||||
| self.sparse_apply_ftrl_v2 = P.SparseApplyFtrlV2(lr=0.001, l1=0.0, l2=0.0, l2_shrinkage=0.0, lr_power=-0.5) | |||||
| self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var") | |||||
| self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum") | |||||
| self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear") | |||||
| def construct(self, grad, indices): | |||||
| out = self.sparse_apply_ftrl_v2(self.var, self.accum, self.linear, grad, indices) | |||||
| return out | |||||
| class SparseApplyProximalAdagradNet(nn.Cell): | class SparseApplyProximalAdagradNet(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(SparseApplyProximalAdagradNet, self).__init__() | super(SparseApplyProximalAdagradNet, self).__init__() | ||||
| @@ -467,6 +480,18 @@ class SparseApplyAdagradNet(nn.Cell): | |||||
| return out | return out | ||||
| class SparseApplyAdagradV2Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(SparseApplyAdagradV2Net, self).__init__() | |||||
| self.sparse_apply_adagrad_v2 = P.SparseApplyAdagradV2(lr=0.01, epsilon=0.001) | |||||
| self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var") | |||||
| self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum") | |||||
| def construct(self, grad, indices): | |||||
| out = self.sparse_apply_adagrad_v2(self.var, self.accum, grad, indices) | |||||
| return out | |||||
| class ApplyRMSNet(nn.Cell): | class ApplyRMSNet(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(ApplyRMSNet, self).__init__() | super(ApplyRMSNet, self).__init__() | ||||
| @@ -1376,10 +1401,18 @@ test_case_nn_ops = [ | |||||
| 'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))], | 'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))], | ||||
| 'desc_bprop': [[3, 3], [3, 3]], | 'desc_bprop': [[3, 3], [3, 3]], | ||||
| 'skip': ['backward']}), | 'skip': ['backward']}), | ||||
| ('SparseApplyAdagradV2', { | |||||
| 'block': SparseApplyAdagradV2Net(), | |||||
| 'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))], | |||||
| 'skip': ['backward']}), | |||||
| ('SparseApplyFtrl', { | ('SparseApplyFtrl', { | ||||
| 'block': SparseApplyFtrlNet(), | 'block': SparseApplyFtrlNet(), | ||||
| 'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))], | 'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))], | ||||
| 'skip': ['backward']}), | 'skip': ['backward']}), | ||||
| ('SparseApplyFtrlV2', { | |||||
| 'block': SparseApplyFtrlV2Net(), | |||||
| 'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))], | |||||
| 'skip': ['backward']}), | |||||
| ('ApplyProximalAdagrad', { | ('ApplyProximalAdagrad', { | ||||
| 'block': ApplyProximalAdagradNet(), | 'block': ApplyProximalAdagradNet(), | ||||
| 'desc_inputs': [[3, 3]], | 'desc_inputs': [[3, 3]], | ||||