| @@ -65,6 +65,7 @@ static std::map<string, string> tbe_func_adapter_map = { | |||||
| {"dropout_do_mask", "drop_out_do_mask"}, | {"dropout_do_mask", "drop_out_do_mask"}, | ||||
| {"strided_slice", "strided_slice_d"}, | {"strided_slice", "strided_slice_d"}, | ||||
| {"strided_slice_grad", "strided_slice_grad_d"}, | {"strided_slice_grad", "strided_slice_grad_d"}, | ||||
| {"sparse_apply_ftrl", "sparse_apply_ftrl_d"}, | |||||
| {"transpose", "transpose_d"}, | {"transpose", "transpose_d"}, | ||||
| {"fill", "fill_d"}, | {"fill", "fill_d"}, | ||||
| {"unsorted_segment_sum", "unsorted_segment_sum_d"}, | {"unsorted_segment_sum", "unsorted_segment_sum_d"}, | ||||
| @@ -112,6 +112,9 @@ from .softplus_grad import _softplus_grad_tbe | |||||
| from .softmax_grad_ext import _softmax_grad_ext_tbe | from .softmax_grad_ext import _softmax_grad_ext_tbe | ||||
| from .square import _square_tbe | from .square import _square_tbe | ||||
| from .sqrt import _sqrt_tbe | from .sqrt import _sqrt_tbe | ||||
| from .sparse_apply_ftrl_d import _sparse_apply_ftrl_d | |||||
| from .sparse_apply_proximal_adagrad import _sparse_apply_proximal_adagrad | |||||
| from .apply_proximal_adagrad import _apply_proximal_adagrad | |||||
| from .transpose_d import _transpose_d_tbe | from .transpose_d import _transpose_d_tbe | ||||
| from .unsorted_segment_sum import _unsorted_segment_sum_tbe | from .unsorted_segment_sum import _unsorted_segment_sum_tbe | ||||
| from .logsoftmax_grad import _logsoftmax_grad_tbe | from .logsoftmax_grad import _logsoftmax_grad_tbe | ||||
| @@ -0,0 +1,56 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ApplyProximalAdagrad op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| apply_proximal_adagrad_op_info = TBERegOp("ApplyProximalAdagrad") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("apply_proximal_adagrad.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("apply_proximal_adagrad") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("use_locking", "optional", "bool", "true,false", "false") \ | |||||
| .input(0, "var", False, "required", "all") \ | |||||
| .input(1, "accum", False, "required", "all") \ | |||||
| .input(2, "lr", False, "required", "all") \ | |||||
| .input(3, "l1", False, "required", "all") \ | |||||
| .input(4, "l2", False, "required", "all") \ | |||||
| .input(5, "grad", False, "required", "all") \ | |||||
| .output(0, "var", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default, | |||||
| DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default, | |||||
| DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default, | |||||
| DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \ | |||||
| .get_op_info() | |||||
| @op_info_register(apply_proximal_adagrad_op_info) | |||||
| def _apply_proximal_adagrad(): | |||||
| """ApplyProximalAdagrad TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,51 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """SparseApplyFtrl op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| sparse_apply_ftrl_d_op_info = TBERegOp("SparseApplyFtrl") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("sparse_apply_ftrl.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("sparse_apply_ftrl") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("lr", "required", "float", "all") \ | |||||
| .attr("l1", "required", "float", "all") \ | |||||
| .attr("l2", "required", "float", "all") \ | |||||
| .attr("lr_power", "required", "float", "all") \ | |||||
| .attr("use_locking", "optional", "bool", "true,false", "false") \ | |||||
| .input(0, "var", False, "required", "all") \ | |||||
| .input(1, "accum", False, "required", "all") \ | |||||
| .input(2, "linear", False, "required", "all") \ | |||||
| .input(3, "grad", False, "required", "all") \ | |||||
| .input(4, "indices", False, "required", "all") \ | |||||
| .output(0, "var", False, "required", "all") \ | |||||
| .output(1, "accum", False, "required", "all") \ | |||||
| .output(2, "linear", False, "required", "all") \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||||
| DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | |||||
| DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(sparse_apply_ftrl_d_op_info) | |||||
| def _sparse_apply_ftrl_d(): | |||||
| """SparseApplyFtrl TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,101 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """SparseApplyProximalAdagrad op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| sparse_apply_proximal_adagrad_op_info = TBERegOp("SparseApplyProximalAdagrad") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("sparse_apply_proximal_adagrad.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("sparse_apply_proximal_adagrad") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("use_locking", "optional", "bool", "true,false", "false") \ | |||||
| .input(0, "var", False, "required", "all") \ | |||||
| .input(1, "accum", False, "required", "all") \ | |||||
| .input(2, "lr", False, "required", "all") \ | |||||
| .input(3, "l1", False, "required", "all") \ | |||||
| .input(4, "l2", False, "required", "all") \ | |||||
| .input(5, "grad", False, "required", "all") \ | |||||
| .input(6, "indices", False, "required", "all") \ | |||||
| .output(0, "var", False, "required", "all") \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.I16_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||||
| DataType.F32_5HD, DataType.F32_5HD, DataType.I16_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | |||||
| DataType.F32_NHWC, DataType.F32_NHWC, DataType.I16_NHWC, DataType.F32_NHWC) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default, DataType.I16_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | |||||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.I16_FracZ, DataType.F32_FracZ) \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||||
| DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | |||||
| DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.F32_NHWC) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | |||||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||||
| DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | |||||
| DataType.F32_NHWC, DataType.F32_NHWC, DataType.I64_NHWC, DataType.F32_NHWC) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | |||||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ) \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.U16_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||||
| DataType.F32_5HD, DataType.F32_5HD, DataType.U16_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | |||||
| DataType.F32_NHWC, DataType.F32_NHWC, DataType.U16_NHWC, DataType.F32_NHWC) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default, DataType.U16_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | |||||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.U16_FracZ, DataType.F32_FracZ) \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.U32_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||||
| DataType.F32_5HD, DataType.F32_5HD, DataType.U32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | |||||
| DataType.F32_NHWC, DataType.F32_NHWC, DataType.U32_NHWC, DataType.F32_NHWC) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default, DataType.U32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | |||||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.U32_FracZ, DataType.F32_FracZ) \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.U64_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||||
| DataType.F32_5HD, DataType.F32_5HD, DataType.U64_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | |||||
| DataType.F32_NHWC, DataType.F32_NHWC, DataType.U64_NHWC, DataType.F32_NHWC) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default, DataType.U64_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | |||||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.U64_FracZ, DataType.F32_FracZ) \ | |||||
| .get_op_info() | |||||
| @op_info_register(sparse_apply_proximal_adagrad_op_info) | |||||
| def _sparse_apply_proximal_adagrad(): | |||||
| """SparseApplyProximalAdagrad TBE register""" | |||||
| return | |||||
| @@ -68,7 +68,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, | |||||
| SmoothL1Loss, Softmax, Softplus, | SmoothL1Loss, Softmax, Softplus, | ||||
| SoftmaxCrossEntropyWithLogits, ROIAlign, | SoftmaxCrossEntropyWithLogits, ROIAlign, | ||||
| SparseSoftmaxCrossEntropyWithLogits, Tanh, | SparseSoftmaxCrossEntropyWithLogits, Tanh, | ||||
| TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, | |||||
| TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, | |||||
| ApplyProximalAdagrad, SparseApplyProximalAdagrad, | |||||
| ApplyRMSProp, ApplyCenteredRMSProp) | ApplyRMSProp, ApplyCenteredRMSProp) | ||||
| from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop | from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop | ||||
| from . import _quant_ops | from . import _quant_ops | ||||
| @@ -265,6 +266,9 @@ __all__ = [ | |||||
| "Round", | "Round", | ||||
| "ApplyFtrl", | "ApplyFtrl", | ||||
| "SpaceToBatch", | "SpaceToBatch", | ||||
| "SparseApplyFtrl", | |||||
| "ApplyProximalAdagrad", | |||||
| "SparseApplyProximalAdagrad", | |||||
| "BatchToSpace", | "BatchToSpace", | ||||
| "Atan2", | "Atan2", | ||||
| "ApplyRMSProp", | "ApplyRMSProp", | ||||
| @@ -2807,6 +2807,126 @@ class SparseApplyAdagrad(PrimitiveWithInfer): | |||||
| return var_type | return var_type | ||||
| class ApplyProximalAdagrad(PrimitiveWithInfer): | |||||
| r""" | |||||
| Update relevant entries according to the proximal adagrad algorithm. | |||||
| .. math:: | |||||
| accum += grad * grad | |||||
| .. math:: | |||||
| prox_v = var - lr * grad * \frac{1}{\sqrt{accum}} | |||||
| .. math:: | |||||
| var = \frac{sign(prox_v)}{1 + lr * l2} * \max(\left| prox_v \right| - lr * l1, 0) | |||||
| Args: | |||||
| use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False. | |||||
| Inputs: | |||||
| - **var** (Tensor) - Variable to be updated. | |||||
| - **accum** (Tensor) - Accum to be updated. The shape must be the same as `var`'s shape. | |||||
| - **lr** (Union[Number, Tensor]): The learning rate value, must be positive. It should be | |||||
| a scalar tensor or number. | |||||
| - **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero. | |||||
| It should be a scalar tensor or number. | |||||
| - **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero. | |||||
| It should be a scalar tensor or number. | |||||
| - **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape. | |||||
| Outputs: | |||||
| Tensor, has the same shape and type as `var`. | |||||
| Examples: | |||||
| >>> var = Tensor(np.random.random((3, 3)), mindspore.float32) | |||||
| >>> accum = Tensor(np.random.random((3, 3)), mindspore.float32) | |||||
| >>> grad = Tensor(np.random.random((3, 3)), mindspore.float32) | |||||
| >>> lr = 0.01 | |||||
| >>> l1 = 0.0 | |||||
| >>> l2 = 0.0 | |||||
| >>> apply_proximal_ada_grad = P.ApplyProximalAdagrad() | |||||
| >>> output = apply_proximal_ada_grad(var, accum, lr, l1, l2, grad) | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, use_locking=False): | |||||
| self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'], outputs=['output']) | |||||
| self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | |||||
| def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape): | |||||
| validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) | |||||
| validator.check('var shape', var_shape, 'grad shape', grad_shape, Rel.EQ, self.name) | |||||
| return var_shape | |||||
| def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype): | |||||
| valid_types = [mstype.float16, mstype.float32] | |||||
| args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} | |||||
| validator.check_tensor_type_same(args, valid_types, self.name) | |||||
| scalar_args = {"lr": lr_dtype, "l1": l1_dtype, "l2": l2_dtype} | |||||
| validator.check_scalar_or_tensor_type_same(scalar_args, valid_types, self.name) | |||||
| return var_dtype | |||||
| class SparseApplyProximalAdagrad(PrimitiveWithInfer): | |||||
| r""" | |||||
| Update relevant entries according to the proximal adagrad algorithm. Compared with ApplyProximalAdagrad, | |||||
| an additional index tensor is input. | |||||
| .. math:: | |||||
| accum += grad * grad | |||||
| .. math:: | |||||
| prox_v = var - lr * grad * \frac{1}{\sqrt{accum}} | |||||
| .. math:: | |||||
| var = \frac{sign(prox_v)}{1 + lr * l2} * \max(\left| prox_v \right| - lr * l1, 0) | |||||
| Args: | |||||
| use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False. | |||||
| Inputs: | |||||
| - **var** (Tensor) - Variable tensor to be updated. | |||||
| - **accum** (Tensor) - Variable tensor to be updated. The shape must be the same as `var`'s shape. | |||||
| - **lr** (Union[Number, Tensor]): The learning rate value, must be positive. It should be | |||||
| a scalar tensor or number. | |||||
| - **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero. | |||||
| It should be a scalar tensor or number. | |||||
| - **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero. | |||||
| It should be a scalar tensor or number. | |||||
| - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. | |||||
| - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. | |||||
| Outputs: | |||||
| Tensor, has the same shape and type as `var`. | |||||
| Examples: | |||||
| >>> var = Tensor(np.random.random((3, 3)), mindspore.float32) | |||||
| >>> accum = Tensor(np.random.random((3, 3)), mindspore.float32) | |||||
| >>> grad = Tensor(np.random.random((3, 3)), mindspore.float32) | |||||
| >>> indices = Tensor(np.ones((3,), np.int32)) | |||||
| >>> lr = 0.01 | |||||
| >>> l1 = 0.0 | |||||
| >>> l2 = 0.0 | |||||
| >>> sparse_apply_proximal_ada_grad = P.SparseApplyProximalAdagrad() | |||||
| >>> output = sparse_apply_proximal_ada_grad(var, accum, lr, l1, l2, grad, indices) | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, use_locking=False): | |||||
| self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'], | |||||
| outputs=['output']) | |||||
| self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | |||||
| def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape): | |||||
| return var_shape | |||||
| def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): | |||||
| args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} | |||||
| validator.check_tensor_type_same(args, [mstype.float32], self.name) | |||||
| scalar_args = {"lr": lr_dtype, "l1": l1_dtype, "l2": l2_dtype} | |||||
| validator.check_scalar_or_tensor_type_same(scalar_args, [mstype.float32], self.name) | |||||
| valid_types = [mstype.int16, mstype.int32, mstype.int64, | |||||
| mstype.uint16, mstype.uint32, mstype.uint64] | |||||
| validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name) | |||||
| return var_dtype | |||||
| class LARSUpdate(PrimitiveWithInfer): | class LARSUpdate(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Conduct lars (layer-wise adaptive rate scaling) update on the square sum of gradient. | Conduct lars (layer-wise adaptive rate scaling) update on the square sum of gradient. | ||||
| @@ -2963,6 +3083,85 @@ class ApplyFtrl(PrimitiveWithInfer): | |||||
| return var_type | return var_type | ||||
| class SparseApplyFtrl(PrimitiveWithInfer): | |||||
| """ | |||||
| Update relevant entries according to the FTRL-proximal scheme. | |||||
| Args: | |||||
| lr (float): The learning rate value, must be positive. | |||||
| l1 (float): l1 regularization strength, must be greater than or equal to zero. | |||||
| l2 (float): l2 regularization strength, must be greater than or equal to zero. | |||||
| lr_power (float): Learning rate power controls how the learning rate decreases during training, | |||||
| must be less than or equal to zero. Use fixed learning rate if `lr_power` is zero. | |||||
| use_locking (bool): Use locks for update operation if True . Default: False. | |||||
| Inputs: | |||||
| - **var** (Tensor): The variable to be updated. | |||||
| - **accum** (Tensor): The accum to be updated, must be same type and shape as `var`. | |||||
| - **linear** (Tensor): The linear to be updated, must be same type and shape as `var`. | |||||
| - **grad** (Tensor): A tensor of the same type as `var`, for the gradient. | |||||
| - **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`. | |||||
| The shape of `indices` must be the same as `grad` in first dimension. The type must be int32. | |||||
| Outputs: | |||||
| - **var** (Tensor): Tensor, has the same shape and type as `var`. | |||||
| - **accum** (Tensor): Tensor, has the same shape and type as `accum`. | |||||
| - **linear** (Tensor): Tensor, has the same shape and type as `linear`. | |||||
| Examples: | |||||
| >>> import mindspore | |||||
| >>> import mindspore.nn as nn | |||||
| >>> import numpy as np | |||||
| >>> from mindspore import Parameter | |||||
| >>> from mindspore import Tensor | |||||
| >>> from mindspore.ops import operations as P | |||||
| >>> class SparseApplyFtrlNet(nn.Cell): | |||||
| >>> def __init__(self): | |||||
| >>> super(SparseApplyFtrlNet, self).__init__() | |||||
| >>> self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5) | |||||
| >>> self.var = Parameter(Tensor(np.random.random(3, 3).astype(np.float32)), name="var") | |||||
| >>> self.accum = Parameter(Tensor(np.random.random(3, 3).astype(np.float32)), name="accum") | |||||
| >>> self.linear = Parameter(Tensor(np.random.random(3, 3).astype(np.float32)), name="linear") | |||||
| >>> | |||||
| >>> def construct(self, grad, indices): | |||||
| >>> out = self.apply_ftrl(self.var, self.accum, self.linear, grad, indices) | |||||
| >>> return out | |||||
| >>> | |||||
| >>> net = SparseApplyFtrlNet() | |||||
| >>> grad = Tensor(np.random.random(3, 3).astype(np.float32)) | |||||
| >>> indices = Tnsor(np.ones([3]), mindspore.float32) | |||||
| >>> output = net(grad, indices) | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, lr, l1, l2, lr_power, use_locking=False): | |||||
| validator.check_value_type("lr", lr, [float], self.name) | |||||
| validator.check_value_type("l1", l1, [float], self.name) | |||||
| validator.check_value_type("l2", l2, [float], self.name) | |||||
| validator.check_value_type("lr_power", lr_power, [float], self.name) | |||||
| self.lr = validator.check_number("lr", lr, 0.0, Rel.GT, self.name) | |||||
| self.l1 = validator.check_number("l1", l1, 0.0, Rel.GE, self.name) | |||||
| self.l2 = validator.check_number("l2", l2, 0.0, Rel.GE, self.name) | |||||
| self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) | |||||
| self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | |||||
| def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape): | |||||
| validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) | |||||
| validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name) | |||||
| if len(var_shape) > 1: | |||||
| validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) | |||||
| validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) | |||||
| validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) | |||||
| return var_shape, accum_shape, linear_shape | |||||
| def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype): | |||||
| args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype, | |||||
| "linear_dtype": linear_dtype, "grad_dtype": grad_dtype} | |||||
| validator.check_tensor_type_same(args, [mstype.float32], self.name) | |||||
| validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name) | |||||
| return var_dtype, accum_dtype, linear_dtype | |||||
| class ConfusionMulGrad(PrimitiveWithInfer): | class ConfusionMulGrad(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| `output0` is the result of which input0 dot multily input1. | `output0` is the result of which input0 dot multily input1. | ||||
| @@ -3124,7 +3323,7 @@ class CTCLoss(PrimitiveWithInfer): | |||||
| >>> labels_indices = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int64) | >>> labels_indices = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int64) | ||||
| >>> labels_values = Tensor(np.array([2, 2]), mindspore.int32) | >>> labels_values = Tensor(np.array([2, 2]), mindspore.int32) | ||||
| >>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32) | >>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32) | ||||
| >>> ctc_loss = P.CTCloss() | |||||
| >>> ctc_loss = P.CTCLoss() | |||||
| >>> output = ctc_loss(inputs, labels_indices, labels_values, sequence_length) | >>> output = ctc_loss(inputs, labels_indices, labels_values, sequence_length) | ||||
| """ | """ | ||||
| @@ -225,6 +225,46 @@ class ApplyFtrlNet(nn.Cell): | |||||
| out = self.apply_ftrl(self.var, self.accum, self.linear, grad, self.lr, self.l1, self.l2, self.lr_power) | out = self.apply_ftrl(self.var, self.accum, self.linear, grad, self.lr, self.l1, self.l2, self.lr_power) | ||||
| return out | return out | ||||
| class SparseApplyFtrlNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(SparseApplyFtrlNet, self).__init__() | |||||
| self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5) | |||||
| self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var") | |||||
| self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum") | |||||
| self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear") | |||||
| def construct(self, grad, indices): | |||||
| out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices) | |||||
| return out | |||||
| class SparseApplyProximalAdagradNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(SparseApplyProximalAdagradNet, self).__init__() | |||||
| self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad() | |||||
| self.lr = 0.01 | |||||
| self.l1 = 0.0 | |||||
| self.l2 = 0.0 | |||||
| def construct(self, var, accum, grad, indices): | |||||
| out = self.sparse_apply_proximal_adagrad(var, accum, self.lr, self.l1, self.l2, grad, indices) | |||||
| return out | |||||
| class ApplyProximalAdagradNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(ApplyProximalAdagradNet, self).__init__() | |||||
| self.apply_proximal_adagrad = P.ApplyProximalAdagrad() | |||||
| self.lr = 0.01 | |||||
| self.l1 = 0.0 | |||||
| self.l2 = 0.0 | |||||
| def construct(self, var, accum, grad): | |||||
| out = self.apply_proximal_adagrad(var, accum, self.lr, self.l1, self.l2, grad) | |||||
| return out | |||||
| class ApplyRMSNet(nn.Cell): | class ApplyRMSNet(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(ApplyRMSNet, self).__init__() | super(ApplyRMSNet, self).__init__() | ||||
| @@ -970,6 +1010,18 @@ test_case_nn_ops = [ | |||||
| 'block': P.SparseApplyAdagrad(0.5), | 'block': P.SparseApplyAdagrad(0.5), | ||||
| 'desc_inputs': [[3, 3], [3, 3], [3, 3], Tensor(np.ones((3,), np.int32))], | 'desc_inputs': [[3, 3], [3, 3], [3, 3], Tensor(np.ones((3,), np.int32))], | ||||
| 'skip': ['backward']}), | 'skip': ['backward']}), | ||||
| ('SparseApplyFtrl', { | |||||
| 'block': SparseApplyFtrlNet(), | |||||
| 'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))], | |||||
| 'skip': ['backward']}), | |||||
| ('ApplyProximalAdagrad', { | |||||
| 'block': ApplyProximalAdagradNet(), | |||||
| 'desc_inputs': [[3, 3], [3, 3], [3, 3]], | |||||
| 'skip': ['backward']}), | |||||
| ('SparseApplyProximalAdagrad', { | |||||
| 'block': SparseApplyProximalAdagradNet(), | |||||
| 'desc_inputs': [[3, 3], [3, 3], [3, 3], Tensor(np.ones((3,), np.int32))], | |||||
| 'skip': ['backward']}), | |||||
| ('Flatten_1', { | ('Flatten_1', { | ||||
| 'block': NetForFlatten(), | 'block': NetForFlatten(), | ||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 4]).astype(np.int32)), Tensor(np.ones([2, 12]).astype(np.int32))], | 'desc_inputs': [Tensor(np.ones([2, 3, 4]).astype(np.int32)), Tensor(np.ones([2, 12]).astype(np.int32))], | ||||