From c7c6f5736b5bef701ea3a65d97fa33b21da98448 Mon Sep 17 00:00:00 2001 From: liuxiao Date: Wed, 17 Jun 2020 12:21:46 +0800 Subject: [PATCH] Adapt ApplyProximalAdagrad and SparseApplyProximalAdagrad --- mindspore/ccsrc/kernel/tbe/tbe_adapter.cc | 2 + mindspore/ccsrc/transform/convert.cc | 2 +- mindspore/ccsrc/transform/op_declare.cc | 10 +- mindspore/ccsrc/transform/op_declare.h | 4 +- .../_op_impl/tbe/apply_proximal_adagrad.py | 29 ++-- .../tbe/sparse_apply_proximal_adagrad.py | 99 +++++++++----- mindspore/ops/operations/nn_ops.py | 129 ++++++++++++------ 7 files changed, 180 insertions(+), 95 deletions(-) diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index 8ff3bb2123..f5b2ca3253 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -75,6 +75,8 @@ static std::map tbe_func_adapter_map = { {"apply_adagrad", "apply_adagrad_d"}, {"apply_adagrad_v2", "apply_adagradv2_d"}, {"sparse_apply_adagrad", "sparse_apply_adagrad_d"}, + {"apply_proximal_adagrad", "apply_proximal_adagrad_d"}, + {"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"}, {"transpose", "transpose_d"}, {"fill", "fill_d"}, {"unsorted_segment_sum", "unsorted_segment_sum_d"}, diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc index a8c085c5fb..a5726b078a 100644 --- a/mindspore/ccsrc/transform/convert.cc +++ b/mindspore/ccsrc/transform/convert.cc @@ -391,7 +391,7 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma {string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)}, {string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)}, {string(kNameSparseApplyFtrlD), ADPT_DESC(SparseApplyFtrlD)}, - {string(kNameApplyProximalAdagrad), ADPT_DESC(ApplyProximalAdagrad)}, + {string(kNameApplyProximalAdagrad), ADPT_DESC(ApplyProximalAdagradD)}, {string(kNameAcosh), ADPT_DESC(Acosh)}, {string(kNameAcoshGrad), ADPT_DESC(AcoshGrad)}, {string(kNameFloorMod), ADPT_DESC(FloorMod)}, diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index 64ec062610..cac526f1fb 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -1170,11 +1170,11 @@ ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits())}, {"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}}; -// ApplyProximalAdagrad -INPUT_MAP(ApplyProximalAdagrad) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, - {4, INPUT_DESC(l1)}, {5, INPUT_DESC(l2)}, {6, INPUT_DESC(grad)}}; -ATTR_MAP(ApplyProximalAdagrad) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; -OUTPUT_MAP(ApplyProximalAdagrad) = {{0, OUTPUT_DESC(var)}}; +// ApplyProximalAdagradD +INPUT_MAP(ApplyProximalAdagradD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, + {4, INPUT_DESC(l1)}, {5, INPUT_DESC(l2)}, {6, INPUT_DESC(grad)}}; +ATTR_MAP(ApplyProximalAdagradD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ApplyProximalAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}}; // SparseApplyFtrlD INPUT_MAP(SparseApplyFtrlD) = {{1, INPUT_DESC(var)}, diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h index 6edc0e1884..f64dc7b671 100755 --- a/mindspore/ccsrc/transform/op_declare.h +++ b/mindspore/ccsrc/transform/op_declare.h @@ -446,8 +446,8 @@ DECLARE_OP_ADAPTER(BinaryCrossEntropyGrad) DECLARE_OP_USE_OUTPUT(BinaryCrossEntropyGrad) DECLARE_OP_ADAPTER(SparseApplyAdagradD) DECLARE_OP_USE_OUTPUT(SparseApplyAdagradD) -DECLARE_OP_ADAPTER(ApplyProximalAdagrad) -DECLARE_OP_USE_OUTPUT(ApplyProximalAdagrad) +DECLARE_OP_ADAPTER(ApplyProximalAdagradD) +DECLARE_OP_USE_OUTPUT(ApplyProximalAdagradD) DECLARE_OP_ADAPTER(SpaceToDepth) DECLARE_OP_USE_OUTPUT(SpaceToDepth) DECLARE_OP_ADAPTER(DepthToSpace) diff --git a/mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py b/mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py index 9099c6e24f..c9b8adf4f4 100644 --- a/mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +++ b/mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py @@ -13,15 +13,15 @@ # limitations under the License. # ============================================================================ -"""ApplyProximalAdagrad op""" +"""ApplyProximalAdagradD op""" from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType -apply_proximal_adagrad_op_info = TBERegOp("ApplyProximalAdagrad") \ +apply_proximal_adagrad_d_op_info = TBERegOp("ApplyProximalAdagrad") \ .fusion_type("OPAQUE") \ .async_flag(False) \ - .binfile_name("apply_proximal_adagrad.so") \ + .binfile_name("apply_proximal_adagrad_d.so") \ .compute_cost(10) \ - .kernel_name("apply_proximal_adagrad") \ + .kernel_name("apply_proximal_adagrad_d") \ .partial_flag(True) \ .attr("use_locking", "optional", "bool", "true,false", "false") \ .input(0, "var", False, "required", "all") \ @@ -31,26 +31,27 @@ apply_proximal_adagrad_op_info = TBERegOp("ApplyProximalAdagrad") \ .input(4, "l2", False, "required", "all") \ .input(5, "grad", False, "required", "all") \ .output(0, "var", False, "required", "all") \ + .output(1, "accum", False, "required", "all") \ .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default, - DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \ + DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default, - DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ + DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, - DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default, - DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \ + DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \ .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, - DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \ + DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, - DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ + DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, - DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \ + DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \ .get_op_info() -@op_info_register(apply_proximal_adagrad_op_info) +@op_info_register(apply_proximal_adagrad_d_op_info) def _apply_proximal_adagrad(): - """ApplyProximalAdagrad TBE register""" + """ApplyProximalAdagradD TBE register""" return diff --git a/mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py b/mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py index f665890c55..782be983fa 100644 --- a/mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +++ b/mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py @@ -13,10 +13,10 @@ # limitations under the License. # ============================================================================ -"""SparseApplyProximalAdagrad op""" +"""SparseApplyProximalAdagradD op""" from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType -sparse_apply_proximal_adagrad_op_info = TBERegOp("SparseApplyProximalAdagrad") \ +sparse_apply_proximal_adagrad_d_op_info = TBERegOp("SparseApplyProximalAdagrad") \ .fusion_type("OPAQUE") \ .async_flag(False) \ .binfile_name("sparse_apply_proximal_adagrad.so") \ @@ -32,70 +32,101 @@ sparse_apply_proximal_adagrad_op_info = TBERegOp("SparseApplyProximalAdagrad") \ .input(5, "grad", False, "required", "all") \ .input(6, "indices", False, "required", "all") \ .output(0, "var", False, "required", "all") \ + .output(1, "accum", False, "required", "all") \ .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, - DataType.F32_NCHW, DataType.F32_NCHW, DataType.I16_NCHW, DataType.F32_NCHW) \ + DataType.F32_NCHW, DataType.F32_NCHW, DataType.I16_NCHW, DataType.F32_NCHW, + DataType.F32_NCHW) \ .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, - DataType.F32_5HD, DataType.F32_5HD, DataType.I16_5HD, DataType.F32_5HD) \ + DataType.F32_5HD, DataType.F32_5HD, DataType.I16_5HD, DataType.F32_5HD, + DataType.F32_5HD) \ .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, - DataType.F32_NHWC, DataType.F32_NHWC, DataType.I16_NHWC, DataType.F32_NHWC) \ + DataType.F32_NHWC, DataType.F32_NHWC, DataType.I16_NHWC, DataType.F32_NHWC, + DataType.F32_NHWC) \ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_Default, DataType.F32_Default, DataType.I16_Default, DataType.F32_Default) \ + DataType.F32_Default, DataType.F32_Default, DataType.I16_Default, DataType.F32_Default, + DataType.F32_Default) \ .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, - DataType.F32_FracZ, DataType.F32_FracZ, DataType.I16_FracZ, DataType.F32_FracZ) \ + DataType.F32_FracZ, DataType.F32_FracZ, DataType.I16_FracZ, DataType.F32_FracZ, + DataType.F32_FracZ) \ .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, - DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW) \ + DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW, + DataType.F32_NCHW) \ .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, - DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD) \ + DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD, + DataType.F32_5HD) \ .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, - DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.F32_NHWC) \ + DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.F32_NHWC, + DataType.F32_NHWC) \ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ + DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, + DataType.F32_Default) \ .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, - DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \ + DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ, + DataType.F32_FracZ) \ .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, - DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW) \ + DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW, + DataType.F32_NCHW) \ .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, - DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD) \ + DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD, + DataType.F32_5HD) \ .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, - DataType.F32_NHWC, DataType.F32_NHWC, DataType.I64_NHWC, DataType.F32_NHWC) \ + DataType.F32_NHWC, DataType.F32_NHWC, DataType.I64_NHWC, DataType.F32_NHWC, + DataType.F32_NHWC) \ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ + DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, + DataType.F32_Default) \ .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, - DataType.F32_FracZ, DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ) \ + DataType.F32_FracZ, DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ, + DataType.F32_FracZ) \ .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, - DataType.F32_NCHW, DataType.F32_NCHW, DataType.U16_NCHW, DataType.F32_NCHW) \ + DataType.F32_NCHW, DataType.F32_NCHW, DataType.U16_NCHW, DataType.F32_NCHW, + DataType.F32_NCHW) \ .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, - DataType.F32_5HD, DataType.F32_5HD, DataType.U16_5HD, DataType.F32_5HD) \ + DataType.F32_5HD, DataType.F32_5HD, DataType.U16_5HD, DataType.F32_5HD, + DataType.F32_5HD) \ .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, - DataType.F32_NHWC, DataType.F32_NHWC, DataType.U16_NHWC, DataType.F32_NHWC) \ + DataType.F32_NHWC, DataType.F32_NHWC, DataType.U16_NHWC, DataType.F32_NHWC, + DataType.F32_NHWC) \ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_Default, DataType.F32_Default, DataType.U16_Default, DataType.F32_Default) \ + DataType.F32_Default, DataType.F32_Default, DataType.U16_Default, DataType.F32_Default, + DataType.F32_Default) \ .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, - DataType.F32_FracZ, DataType.F32_FracZ, DataType.U16_FracZ, DataType.F32_FracZ) \ + DataType.F32_FracZ, DataType.F32_FracZ, DataType.U16_FracZ, DataType.F32_FracZ, + DataType.F32_FracZ) \ .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, - DataType.F32_NCHW, DataType.F32_NCHW, DataType.U32_NCHW, DataType.F32_NCHW) \ + DataType.F32_NCHW, DataType.F32_NCHW, DataType.U32_NCHW, DataType.F32_NCHW, + DataType.F32_NCHW) \ .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, - DataType.F32_5HD, DataType.F32_5HD, DataType.U32_5HD, DataType.F32_5HD) \ + DataType.F32_5HD, DataType.F32_5HD, DataType.U32_5HD, DataType.F32_5HD, + DataType.F32_5HD) \ .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, - DataType.F32_NHWC, DataType.F32_NHWC, DataType.U32_NHWC, DataType.F32_NHWC) \ + DataType.F32_NHWC, DataType.F32_NHWC, DataType.U32_NHWC, DataType.F32_NHWC, + DataType.F32_NHWC) \ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_Default, DataType.F32_Default, DataType.U32_Default, DataType.F32_Default) \ + DataType.F32_Default, DataType.F32_Default, DataType.U32_Default, DataType.F32_Default, + DataType.F32_Default) \ .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, - DataType.F32_FracZ, DataType.F32_FracZ, DataType.U32_FracZ, DataType.F32_FracZ) \ + DataType.F32_FracZ, DataType.F32_FracZ, DataType.U32_FracZ, DataType.F32_FracZ, + DataType.F32_FracZ) \ .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, - DataType.F32_NCHW, DataType.F32_NCHW, DataType.U64_NCHW, DataType.F32_NCHW) \ + DataType.F32_NCHW, DataType.F32_NCHW, DataType.U64_NCHW, DataType.F32_NCHW, + DataType.F32_NCHW) \ .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, - DataType.F32_5HD, DataType.F32_5HD, DataType.U64_5HD, DataType.F32_5HD) \ + DataType.F32_5HD, DataType.F32_5HD, DataType.U64_5HD, DataType.F32_5HD, + DataType.F32_5HD) \ .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, - DataType.F32_NHWC, DataType.F32_NHWC, DataType.U64_NHWC, DataType.F32_NHWC) \ + DataType.F32_NHWC, DataType.F32_NHWC, DataType.U64_NHWC, DataType.F32_NHWC, + DataType.F32_NHWC) \ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_Default, DataType.F32_Default, DataType.U64_Default, DataType.F32_Default) \ + DataType.F32_Default, DataType.F32_Default, DataType.U64_Default, DataType.F32_Default, + DataType.F32_Default) \ .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, - DataType.F32_FracZ, DataType.F32_FracZ, DataType.U64_FracZ, DataType.F32_FracZ) \ + DataType.F32_FracZ, DataType.F32_FracZ, DataType.U64_FracZ, DataType.F32_FracZ, + DataType.F32_FracZ) \ .get_op_info() -@op_info_register(sparse_apply_proximal_adagrad_op_info) +@op_info_register(sparse_apply_proximal_adagrad_d_op_info) def _sparse_apply_proximal_adagrad(): - """SparseApplyProximalAdagrad TBE register""" + """SparseApplyProximalAdagradD TBE register""" return diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 78c777f0f9..8e0b65247c 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -3142,7 +3142,7 @@ class ApplyAdaMax(PrimitiveWithInfer): .. math:: \begin{array}{ll} \\ m_{t} = \beta_1 * m_{t-1} + (1 - \beta_1) * g \\ - v_{t} = \max(\beta_2 * v{t-1}, \left| g \right|) \\ + v_{t} = \max(\beta_2 * v_{t-1}, \left| g \right|) \\ var = var - \frac{l}{1 - \beta_1^t} * \frac{m_{t}}{v_{t} + \epsilon} \end{array} @@ -3497,37 +3497,61 @@ class ApplyProximalAdagrad(PrimitiveWithInfer): .. math:: accum += grad * grad .. math:: - prox_v = var - lr * grad * \frac{1}{\sqrt{accum}} + \text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}} .. math:: - var = \frac{sign(prox_v)}{1 + lr * l2} * \max(\left| prox_v \right| - lr * l1, 0) + var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0) Args: use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False. Inputs: - - **var** (Tensor) - Variable to be updated. - - **accum** (Tensor) - Accum to be updated. The shape must be the same as `var`'s shape. + - **var** (Parameter) - Variable to be updated. The data type should be float. + - **accum** (Parameter) - Accum to be updated. Must has the same shape and dtype as `var`. - **lr** (Union[Number, Tensor]): The learning rate value. It should be a scalar tensor or number. + The data type should be float. - **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero. - It should be a scalar tensor or number. + It should be a scalar tensor or number. The data type should be float. - **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero. - It should be a scalar tensor or number. - - **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape. + It should be a scalar tensor or number. The data type should be float. + - **grad** (Tensor) - Gradient. Must has the same shape and dtype as `var`. Outputs: - Tensor, has the same shape and type as `var`. + Tuple of 2 Tensor, the updated parameters. + + - **var** (Tensor) - The same shape and data type as `var`. + - **accum** (Tensor) - The same shape and data type as `accum`. Examples: - >>> var = Tensor(np.random.random((3, 3)), mindspore.float32) - >>> accum = Tensor(np.random.random((3, 3)), mindspore.float32) - >>> grad = Tensor(np.random.random((3, 3)), mindspore.float32) - >>> lr = 0.01 - >>> l1 = 0.0 - >>> l2 = 0.0 - >>> apply_proximal_ada_grad = P.ApplyProximalAdagrad() - >>> output = apply_proximal_ada_grad(var, accum, lr, l1, l2, grad) + >>> import numpy as np + >>> import mindspore.nn as nn + >>> from mindspore import Tensor, Parameter + >>> from mindspore.ops import operations as P + >>> class Net(nn.Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self.apply_proximal_adagrad = P.ApplyProximalAdagrad() + >>> self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var") + >>> self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum") + >>> self.lr = 0.01 + >>> self.l1 = 0.0 + >>> self.l2 = 0.0 + >>> def construct(self, grad): + >>> out = self.apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, self.l2, grad) + >>> return out + >>> net = Net() + >>> grad = Tensor(np.random.rand(3, 3).astype(np.float32)) + >>> output = net(grad) """ + __mindspore_signature__ = ( + ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) + ) + @prim_attr_register def __init__(self, use_locking=False): self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'], outputs=['output']) @@ -3536,7 +3560,7 @@ class ApplyProximalAdagrad(PrimitiveWithInfer): def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape): validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) validator.check('var shape', var_shape, 'grad shape', grad_shape, Rel.EQ, self.name) - return var_shape + return var_shape, accum_shape def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype): valid_types = [mstype.float16, mstype.float32] @@ -3544,7 +3568,7 @@ class ApplyProximalAdagrad(PrimitiveWithInfer): validator.check_tensor_type_same(args, valid_types, self.name) scalar_args = {"lr": lr_dtype, "l1": l1_dtype, "l2": l2_dtype} validator.check_scalar_or_tensor_type_same(scalar_args, valid_types, self.name) - return var_dtype + return var_dtype, accum_dtype class SparseApplyProximalAdagrad(PrimitiveWithInfer): @@ -3555,39 +3579,65 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer): .. math:: accum += grad * grad .. math:: - prox_v = var - lr * grad * \frac{1}{\sqrt{accum}} + \text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}} .. math:: - var = \frac{sign(prox_v)}{1 + lr * l2} * \max(\left| prox_v \right| - lr * l1, 0) + var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0) Args: use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False. Inputs: - - **var** (Tensor) - Variable tensor to be updated. - - **accum** (Tensor) - Variable tensor to be updated. The shape must be the same as `var`'s shape. + - **var** (Parameter) - Variable tensor to be updated. The data type must be float32. + - **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`. - **lr** (Union[Number, Tensor]): The learning rate value. It should be a scalar tensor or number. + The data type must be float32. - **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero. - It should be a scalar tensor or number. + It should be a scalar tensor or number. The data type must be float32. - **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero. - It should be a scalar tensor or number. - - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. + It should be a scalar tensor or number. The data type must be float32. + - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32. - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. Outputs: - Tensor, has the same shape and type as `var`. + Tuple of 2 Tensor, the updated parameters. + + - **var** (Tensor) - The same shape and data type as `var`. + - **accum** (Tensor) - The same shape and data type as `accum`. Examples: - >>> var = Tensor(np.random.random((3, 3)), mindspore.float32) - >>> accum = Tensor(np.random.random((3, 3)), mindspore.float32) - >>> grad = Tensor(np.random.random((3, 3)), mindspore.float32) + >>> import numpy as np + >>> import mindspore.nn as nn + >>> from mindspore import Tensor, Parameter + >>> from mindspore.ops import operations as P + >>> class Net(nn.Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad() + >>> self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var") + >>> self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum") + >>> self.lr = 0.01 + >>> self.l1 = 0.0 + >>> self.l2 = 0.0 + >>> def construct(self, grad, indices): + >>> out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, + self.l2, grad, indices) + >>> return out + >>> net = Net() + >>> grad = Tensor(np.random.rand(3, 3).astype(np.float32)) >>> indices = Tensor(np.ones((3,), np.int32)) - >>> lr = 0.01 - >>> l1 = 0.0 - >>> l2 = 0.0 - >>> sparse_apply_proximal_ada_grad = P.SparseApplyProximalAdagrad() - >>> output = sparse_apply_proximal_ada_grad(var, accum, lr, l1, l2, grad, indices) + >>> output = net(grad, indices) """ + __mindspore_signature__ = ( + ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) + ) + @prim_attr_register def __init__(self, use_locking=False): self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'], @@ -3595,7 +3645,8 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer): self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape): - return var_shape + validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) + return var_shape, accum_shape def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} @@ -3605,7 +3656,7 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer): valid_types = [mstype.int16, mstype.int32, mstype.int64, mstype.uint16, mstype.uint32, mstype.uint64] validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name) - return var_dtype + return var_dtype, accum_dtype class LARSUpdate(PrimitiveWithInfer): @@ -3858,8 +3909,8 @@ class ConfusionMulGrad(PrimitiveWithInfer): axis (Union[int, tuple[int], list[int]]): The dimensions to reduce. Default:(), reduce all dimensions. Only constant value is allowed. keep_dims (bool): - - If true, keep these reduced dimensions and the length is 1. - - If false, don't keep these dimensions. Default:False. + - If True, keep these reduced dimensions and the length is 1. + - If False, don't keep these dimensions. Default:False. Inputs: - **input_0** (Tensor) - The input Tensor.