| @@ -75,6 +75,8 @@ static std::map<string, string> tbe_func_adapter_map = { | |||
| {"apply_adagrad", "apply_adagrad_d"}, | |||
| {"apply_adagrad_v2", "apply_adagradv2_d"}, | |||
| {"sparse_apply_adagrad", "sparse_apply_adagrad_d"}, | |||
| {"apply_proximal_adagrad", "apply_proximal_adagrad_d"}, | |||
| {"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"}, | |||
| {"transpose", "transpose_d"}, | |||
| {"fill", "fill_d"}, | |||
| {"unsorted_segment_sum", "unsorted_segment_sum_d"}, | |||
| @@ -391,7 +391,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||
| {string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)}, | |||
| {string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)}, | |||
| {string(kNameSparseApplyFtrlD), ADPT_DESC(SparseApplyFtrlD)}, | |||
| {string(kNameApplyProximalAdagrad), ADPT_DESC(ApplyProximalAdagrad)}, | |||
| {string(kNameApplyProximalAdagrad), ADPT_DESC(ApplyProximalAdagradD)}, | |||
| {string(kNameAcosh), ADPT_DESC(Acosh)}, | |||
| {string(kNameAcoshGrad), ADPT_DESC(AcoshGrad)}, | |||
| {string(kNameFloorMod), ADPT_DESC(FloorMod)}, | |||
| @@ -1170,11 +1170,11 @@ ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits<float>())}, | |||
| {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; | |||
| OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}}; | |||
| // ApplyProximalAdagrad | |||
| INPUT_MAP(ApplyProximalAdagrad) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, | |||
| {4, INPUT_DESC(l1)}, {5, INPUT_DESC(l2)}, {6, INPUT_DESC(grad)}}; | |||
| ATTR_MAP(ApplyProximalAdagrad) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; | |||
| OUTPUT_MAP(ApplyProximalAdagrad) = {{0, OUTPUT_DESC(var)}}; | |||
| // ApplyProximalAdagradD | |||
| INPUT_MAP(ApplyProximalAdagradD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, | |||
| {4, INPUT_DESC(l1)}, {5, INPUT_DESC(l2)}, {6, INPUT_DESC(grad)}}; | |||
| ATTR_MAP(ApplyProximalAdagradD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; | |||
| OUTPUT_MAP(ApplyProximalAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}}; | |||
| // SparseApplyFtrlD | |||
| INPUT_MAP(SparseApplyFtrlD) = {{1, INPUT_DESC(var)}, | |||
| @@ -446,8 +446,8 @@ DECLARE_OP_ADAPTER(BinaryCrossEntropyGrad) | |||
| DECLARE_OP_USE_OUTPUT(BinaryCrossEntropyGrad) | |||
| DECLARE_OP_ADAPTER(SparseApplyAdagradD) | |||
| DECLARE_OP_USE_OUTPUT(SparseApplyAdagradD) | |||
| DECLARE_OP_ADAPTER(ApplyProximalAdagrad) | |||
| DECLARE_OP_USE_OUTPUT(ApplyProximalAdagrad) | |||
| DECLARE_OP_ADAPTER(ApplyProximalAdagradD) | |||
| DECLARE_OP_USE_OUTPUT(ApplyProximalAdagradD) | |||
| DECLARE_OP_ADAPTER(SpaceToDepth) | |||
| DECLARE_OP_USE_OUTPUT(SpaceToDepth) | |||
| DECLARE_OP_ADAPTER(DepthToSpace) | |||
| @@ -13,15 +13,15 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ApplyProximalAdagrad op""" | |||
| """ApplyProximalAdagradD op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| apply_proximal_adagrad_op_info = TBERegOp("ApplyProximalAdagrad") \ | |||
| apply_proximal_adagrad_d_op_info = TBERegOp("ApplyProximalAdagrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("apply_proximal_adagrad.so") \ | |||
| .binfile_name("apply_proximal_adagrad_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("apply_proximal_adagrad") \ | |||
| .kernel_name("apply_proximal_adagrad_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("use_locking", "optional", "bool", "true,false", "false") \ | |||
| .input(0, "var", False, "required", "all") \ | |||
| @@ -31,26 +31,27 @@ apply_proximal_adagrad_op_info = TBERegOp("ApplyProximalAdagrad") \ | |||
| .input(4, "l2", False, "required", "all") \ | |||
| .input(5, "grad", False, "required", "all") \ | |||
| .output(0, "var", False, "required", "all") \ | |||
| .output(1, "accum", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ | |||
| DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \ | |||
| DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ | |||
| DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \ | |||
| DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \ | |||
| .get_op_info() | |||
| @op_info_register(apply_proximal_adagrad_op_info) | |||
| @op_info_register(apply_proximal_adagrad_d_op_info) | |||
| def _apply_proximal_adagrad(): | |||
| """ApplyProximalAdagrad TBE register""" | |||
| """ApplyProximalAdagradD TBE register""" | |||
| return | |||
| @@ -13,10 +13,10 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """SparseApplyProximalAdagrad op""" | |||
| """SparseApplyProximalAdagradD op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| sparse_apply_proximal_adagrad_op_info = TBERegOp("SparseApplyProximalAdagrad") \ | |||
| sparse_apply_proximal_adagrad_d_op_info = TBERegOp("SparseApplyProximalAdagrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("sparse_apply_proximal_adagrad.so") \ | |||
| @@ -32,70 +32,101 @@ sparse_apply_proximal_adagrad_op_info = TBERegOp("SparseApplyProximalAdagrad") \ | |||
| .input(5, "grad", False, "required", "all") \ | |||
| .input(6, "indices", False, "required", "all") \ | |||
| .output(0, "var", False, "required", "all") \ | |||
| .output(1, "accum", False, "required", "all") \ | |||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.I16_NCHW, DataType.F32_NCHW) \ | |||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.I16_NCHW, DataType.F32_NCHW, | |||
| DataType.F32_NCHW) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.I16_5HD, DataType.F32_5HD) \ | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.I16_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | |||
| DataType.F32_NHWC, DataType.F32_NHWC, DataType.I16_NHWC, DataType.F32_NHWC) \ | |||
| DataType.F32_NHWC, DataType.F32_NHWC, DataType.I16_NHWC, DataType.F32_NHWC, | |||
| DataType.F32_NHWC) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.I16_Default, DataType.F32_Default) \ | |||
| DataType.F32_Default, DataType.F32_Default, DataType.I16_Default, DataType.F32_Default, | |||
| DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | |||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.I16_FracZ, DataType.F32_FracZ) \ | |||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.I16_FracZ, DataType.F32_FracZ, | |||
| DataType.F32_FracZ) \ | |||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW) \ | |||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW, | |||
| DataType.F32_NCHW) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD) \ | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | |||
| DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.F32_NHWC) \ | |||
| DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.F32_NHWC, | |||
| DataType.F32_NHWC) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ | |||
| DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, | |||
| DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | |||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \ | |||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ, | |||
| DataType.F32_FracZ) \ | |||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW) \ | |||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW, | |||
| DataType.F32_NCHW) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD) \ | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | |||
| DataType.F32_NHWC, DataType.F32_NHWC, DataType.I64_NHWC, DataType.F32_NHWC) \ | |||
| DataType.F32_NHWC, DataType.F32_NHWC, DataType.I64_NHWC, DataType.F32_NHWC, | |||
| DataType.F32_NHWC) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ | |||
| DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, | |||
| DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | |||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ) \ | |||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ, | |||
| DataType.F32_FracZ) \ | |||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.U16_NCHW, DataType.F32_NCHW) \ | |||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.U16_NCHW, DataType.F32_NCHW, | |||
| DataType.F32_NCHW) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.U16_5HD, DataType.F32_5HD) \ | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.U16_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | |||
| DataType.F32_NHWC, DataType.F32_NHWC, DataType.U16_NHWC, DataType.F32_NHWC) \ | |||
| DataType.F32_NHWC, DataType.F32_NHWC, DataType.U16_NHWC, DataType.F32_NHWC, | |||
| DataType.F32_NHWC) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.U16_Default, DataType.F32_Default) \ | |||
| DataType.F32_Default, DataType.F32_Default, DataType.U16_Default, DataType.F32_Default, | |||
| DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | |||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.U16_FracZ, DataType.F32_FracZ) \ | |||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.U16_FracZ, DataType.F32_FracZ, | |||
| DataType.F32_FracZ) \ | |||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.U32_NCHW, DataType.F32_NCHW) \ | |||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.U32_NCHW, DataType.F32_NCHW, | |||
| DataType.F32_NCHW) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.U32_5HD, DataType.F32_5HD) \ | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.U32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | |||
| DataType.F32_NHWC, DataType.F32_NHWC, DataType.U32_NHWC, DataType.F32_NHWC) \ | |||
| DataType.F32_NHWC, DataType.F32_NHWC, DataType.U32_NHWC, DataType.F32_NHWC, | |||
| DataType.F32_NHWC) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.U32_Default, DataType.F32_Default) \ | |||
| DataType.F32_Default, DataType.F32_Default, DataType.U32_Default, DataType.F32_Default, | |||
| DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | |||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.U32_FracZ, DataType.F32_FracZ) \ | |||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.U32_FracZ, DataType.F32_FracZ, | |||
| DataType.F32_FracZ) \ | |||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.U64_NCHW, DataType.F32_NCHW) \ | |||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.U64_NCHW, DataType.F32_NCHW, | |||
| DataType.F32_NCHW) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.U64_5HD, DataType.F32_5HD) \ | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.U64_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | |||
| DataType.F32_NHWC, DataType.F32_NHWC, DataType.U64_NHWC, DataType.F32_NHWC) \ | |||
| DataType.F32_NHWC, DataType.F32_NHWC, DataType.U64_NHWC, DataType.F32_NHWC, | |||
| DataType.F32_NHWC) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.U64_Default, DataType.F32_Default) \ | |||
| DataType.F32_Default, DataType.F32_Default, DataType.U64_Default, DataType.F32_Default, | |||
| DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | |||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.U64_FracZ, DataType.F32_FracZ) \ | |||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.U64_FracZ, DataType.F32_FracZ, | |||
| DataType.F32_FracZ) \ | |||
| .get_op_info() | |||
| @op_info_register(sparse_apply_proximal_adagrad_op_info) | |||
| @op_info_register(sparse_apply_proximal_adagrad_d_op_info) | |||
| def _sparse_apply_proximal_adagrad(): | |||
| """SparseApplyProximalAdagrad TBE register""" | |||
| """SparseApplyProximalAdagradD TBE register""" | |||
| return | |||
| @@ -3142,7 +3142,7 @@ class ApplyAdaMax(PrimitiveWithInfer): | |||
| .. math:: | |||
| \begin{array}{ll} \\ | |||
| m_{t} = \beta_1 * m_{t-1} + (1 - \beta_1) * g \\ | |||
| v_{t} = \max(\beta_2 * v{t-1}, \left| g \right|) \\ | |||
| v_{t} = \max(\beta_2 * v_{t-1}, \left| g \right|) \\ | |||
| var = var - \frac{l}{1 - \beta_1^t} * \frac{m_{t}}{v_{t} + \epsilon} | |||
| \end{array} | |||
| @@ -3497,37 +3497,61 @@ class ApplyProximalAdagrad(PrimitiveWithInfer): | |||
| .. math:: | |||
| accum += grad * grad | |||
| .. math:: | |||
| prox_v = var - lr * grad * \frac{1}{\sqrt{accum}} | |||
| \text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}} | |||
| .. math:: | |||
| var = \frac{sign(prox_v)}{1 + lr * l2} * \max(\left| prox_v \right| - lr * l1, 0) | |||
| var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0) | |||
| Args: | |||
| use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False. | |||
| Inputs: | |||
| - **var** (Tensor) - Variable to be updated. | |||
| - **accum** (Tensor) - Accum to be updated. The shape must be the same as `var`'s shape. | |||
| - **var** (Parameter) - Variable to be updated. The data type should be float. | |||
| - **accum** (Parameter) - Accum to be updated. Must has the same shape and dtype as `var`. | |||
| - **lr** (Union[Number, Tensor]): The learning rate value. It should be a scalar tensor or number. | |||
| The data type should be float. | |||
| - **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero. | |||
| It should be a scalar tensor or number. | |||
| It should be a scalar tensor or number. The data type should be float. | |||
| - **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero. | |||
| It should be a scalar tensor or number. | |||
| - **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape. | |||
| It should be a scalar tensor or number. The data type should be float. | |||
| - **grad** (Tensor) - Gradient. Must has the same shape and dtype as `var`. | |||
| Outputs: | |||
| Tensor, has the same shape and type as `var`. | |||
| Tuple of 2 Tensor, the updated parameters. | |||
| - **var** (Tensor) - The same shape and data type as `var`. | |||
| - **accum** (Tensor) - The same shape and data type as `accum`. | |||
| Examples: | |||
| >>> var = Tensor(np.random.random((3, 3)), mindspore.float32) | |||
| >>> accum = Tensor(np.random.random((3, 3)), mindspore.float32) | |||
| >>> grad = Tensor(np.random.random((3, 3)), mindspore.float32) | |||
| >>> lr = 0.01 | |||
| >>> l1 = 0.0 | |||
| >>> l2 = 0.0 | |||
| >>> apply_proximal_ada_grad = P.ApplyProximalAdagrad() | |||
| >>> output = apply_proximal_ada_grad(var, accum, lr, l1, l2, grad) | |||
| >>> import numpy as np | |||
| >>> import mindspore.nn as nn | |||
| >>> from mindspore import Tensor, Parameter | |||
| >>> from mindspore.ops import operations as P | |||
| >>> class Net(nn.Cell): | |||
| >>> def __init__(self): | |||
| >>> super(Net, self).__init__() | |||
| >>> self.apply_proximal_adagrad = P.ApplyProximalAdagrad() | |||
| >>> self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var") | |||
| >>> self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum") | |||
| >>> self.lr = 0.01 | |||
| >>> self.l1 = 0.0 | |||
| >>> self.l2 = 0.0 | |||
| >>> def construct(self, grad): | |||
| >>> out = self.apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, self.l2, grad) | |||
| >>> return out | |||
| >>> net = Net() | |||
| >>> grad = Tensor(np.random.rand(3, 3).astype(np.float32)) | |||
| >>> output = net(grad) | |||
| """ | |||
| __mindspore_signature__ = ( | |||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||
| ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||
| ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||
| ('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||
| ('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) | |||
| ) | |||
| @prim_attr_register | |||
| def __init__(self, use_locking=False): | |||
| self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'], outputs=['output']) | |||
| @@ -3536,7 +3560,7 @@ class ApplyProximalAdagrad(PrimitiveWithInfer): | |||
| def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape): | |||
| validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) | |||
| validator.check('var shape', var_shape, 'grad shape', grad_shape, Rel.EQ, self.name) | |||
| return var_shape | |||
| return var_shape, accum_shape | |||
| def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype): | |||
| valid_types = [mstype.float16, mstype.float32] | |||
| @@ -3544,7 +3568,7 @@ class ApplyProximalAdagrad(PrimitiveWithInfer): | |||
| validator.check_tensor_type_same(args, valid_types, self.name) | |||
| scalar_args = {"lr": lr_dtype, "l1": l1_dtype, "l2": l2_dtype} | |||
| validator.check_scalar_or_tensor_type_same(scalar_args, valid_types, self.name) | |||
| return var_dtype | |||
| return var_dtype, accum_dtype | |||
| class SparseApplyProximalAdagrad(PrimitiveWithInfer): | |||
| @@ -3555,39 +3579,65 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer): | |||
| .. math:: | |||
| accum += grad * grad | |||
| .. math:: | |||
| prox_v = var - lr * grad * \frac{1}{\sqrt{accum}} | |||
| \text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}} | |||
| .. math:: | |||
| var = \frac{sign(prox_v)}{1 + lr * l2} * \max(\left| prox_v \right| - lr * l1, 0) | |||
| var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0) | |||
| Args: | |||
| use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False. | |||
| Inputs: | |||
| - **var** (Tensor) - Variable tensor to be updated. | |||
| - **accum** (Tensor) - Variable tensor to be updated. The shape must be the same as `var`'s shape. | |||
| - **var** (Parameter) - Variable tensor to be updated. The data type must be float32. | |||
| - **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`. | |||
| - **lr** (Union[Number, Tensor]): The learning rate value. It should be a scalar tensor or number. | |||
| The data type must be float32. | |||
| - **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero. | |||
| It should be a scalar tensor or number. | |||
| It should be a scalar tensor or number. The data type must be float32. | |||
| - **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero. | |||
| It should be a scalar tensor or number. | |||
| - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. | |||
| It should be a scalar tensor or number. The data type must be float32. | |||
| - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32. | |||
| - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. | |||
| Outputs: | |||
| Tensor, has the same shape and type as `var`. | |||
| Tuple of 2 Tensor, the updated parameters. | |||
| - **var** (Tensor) - The same shape and data type as `var`. | |||
| - **accum** (Tensor) - The same shape and data type as `accum`. | |||
| Examples: | |||
| >>> var = Tensor(np.random.random((3, 3)), mindspore.float32) | |||
| >>> accum = Tensor(np.random.random((3, 3)), mindspore.float32) | |||
| >>> grad = Tensor(np.random.random((3, 3)), mindspore.float32) | |||
| >>> import numpy as np | |||
| >>> import mindspore.nn as nn | |||
| >>> from mindspore import Tensor, Parameter | |||
| >>> from mindspore.ops import operations as P | |||
| >>> class Net(nn.Cell): | |||
| >>> def __init__(self): | |||
| >>> super(Net, self).__init__() | |||
| >>> self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad() | |||
| >>> self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var") | |||
| >>> self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum") | |||
| >>> self.lr = 0.01 | |||
| >>> self.l1 = 0.0 | |||
| >>> self.l2 = 0.0 | |||
| >>> def construct(self, grad, indices): | |||
| >>> out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, | |||
| self.l2, grad, indices) | |||
| >>> return out | |||
| >>> net = Net() | |||
| >>> grad = Tensor(np.random.rand(3, 3).astype(np.float32)) | |||
| >>> indices = Tensor(np.ones((3,), np.int32)) | |||
| >>> lr = 0.01 | |||
| >>> l1 = 0.0 | |||
| >>> l2 = 0.0 | |||
| >>> sparse_apply_proximal_ada_grad = P.SparseApplyProximalAdagrad() | |||
| >>> output = sparse_apply_proximal_ada_grad(var, accum, lr, l1, l2, grad, indices) | |||
| >>> output = net(grad, indices) | |||
| """ | |||
| __mindspore_signature__ = ( | |||
| ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||
| ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||
| ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||
| ('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||
| ('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||
| ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), | |||
| ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) | |||
| ) | |||
| @prim_attr_register | |||
| def __init__(self, use_locking=False): | |||
| self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'], | |||
| @@ -3595,7 +3645,8 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer): | |||
| self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | |||
| def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape): | |||
| return var_shape | |||
| validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) | |||
| return var_shape, accum_shape | |||
| def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): | |||
| args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} | |||
| @@ -3605,7 +3656,7 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer): | |||
| valid_types = [mstype.int16, mstype.int32, mstype.int64, | |||
| mstype.uint16, mstype.uint32, mstype.uint64] | |||
| validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name) | |||
| return var_dtype | |||
| return var_dtype, accum_dtype | |||
| class LARSUpdate(PrimitiveWithInfer): | |||
| @@ -3858,8 +3909,8 @@ class ConfusionMulGrad(PrimitiveWithInfer): | |||
| axis (Union[int, tuple[int], list[int]]): The dimensions to reduce. | |||
| Default:(), reduce all dimensions. Only constant value is allowed. | |||
| keep_dims (bool): | |||
| - If true, keep these reduced dimensions and the length is 1. | |||
| - If false, don't keep these dimensions. Default:False. | |||
| - If True, keep these reduced dimensions and the length is 1. | |||
| - If False, don't keep these dimensions. Default:False. | |||
| Inputs: | |||
| - **input_0** (Tensor) - The input Tensor. | |||