@@ -75,6 +75,8 @@ static std::map<string, string> tbe_func_adapter_map = {
   {"apply_adagrad", "apply_adagrad_d"},
   {"apply_adagrad_v2", "apply_adagradv2_d"},
   {"sparse_apply_adagrad", "sparse_apply_adagrad_d"},
+  {"apply_proximal_adagrad", "apply_proximal_adagrad_d"},
+  {"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"},
   {"transpose", "transpose_d"},
   {"fill", "fill_d"},
   {"unsorted_segment_sum", "unsorted_segment_sum_d"},
@@ -391,7 +391,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)},
     {string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)},
     {string(kNameSparseApplyFtrlD), ADPT_DESC(SparseApplyFtrlD)},
-    {string(kNameApplyProximalAdagrad), ADPT_DESC(ApplyProximalAdagrad)},
+    {string(kNameApplyProximalAdagrad), ADPT_DESC(ApplyProximalAdagradD)},
     {string(kNameAcosh), ADPT_DESC(Acosh)},
     {string(kNameAcoshGrad), ADPT_DESC(AcoshGrad)},
     {string(kNameFloorMod), ADPT_DESC(FloorMod)},
@@ -1170,11 +1170,11 @@ ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits<float>())},
                                  {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
 OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}};
 
-// ApplyProximalAdagrad
-INPUT_MAP(ApplyProximalAdagrad) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)},
-                                   {4, INPUT_DESC(l1)}, {5, INPUT_DESC(l2)}, {6, INPUT_DESC(grad)}};
-ATTR_MAP(ApplyProximalAdagrad) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyProximalAdagrad) = {{0, OUTPUT_DESC(var)}};
+// ApplyProximalAdagradD
+INPUT_MAP(ApplyProximalAdagradD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)},
+                                    {4, INPUT_DESC(l1)}, {5, INPUT_DESC(l2)}, {6, INPUT_DESC(grad)}};
+ATTR_MAP(ApplyProximalAdagradD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
+OUTPUT_MAP(ApplyProximalAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
 
 // SparseApplyFtrlD
 INPUT_MAP(SparseApplyFtrlD) = {{1, INPUT_DESC(var)},
@@ -446,8 +446,8 @@ DECLARE_OP_ADAPTER(BinaryCrossEntropyGrad)
 DECLARE_OP_USE_OUTPUT(BinaryCrossEntropyGrad)
 DECLARE_OP_ADAPTER(SparseApplyAdagradD)
 DECLARE_OP_USE_OUTPUT(SparseApplyAdagradD)
-DECLARE_OP_ADAPTER(ApplyProximalAdagrad)
-DECLARE_OP_USE_OUTPUT(ApplyProximalAdagrad)
+DECLARE_OP_ADAPTER(ApplyProximalAdagradD)
+DECLARE_OP_USE_OUTPUT(ApplyProximalAdagradD)
 DECLARE_OP_ADAPTER(SpaceToDepth)
 DECLARE_OP_USE_OUTPUT(SpaceToDepth)
 DECLARE_OP_ADAPTER(DepthToSpace)
@@ -13,15 +13,15 @@
 # limitations under the License.
 # ============================================================================
-"""ApplyProximalAdagrad op"""
+"""ApplyProximalAdagradD op"""
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 
-apply_proximal_adagrad_op_info = TBERegOp("ApplyProximalAdagrad") \
+apply_proximal_adagrad_d_op_info = TBERegOp("ApplyProximalAdagrad") \
     .fusion_type("OPAQUE") \
     .async_flag(False) \
-    .binfile_name("apply_proximal_adagrad.so") \
+    .binfile_name("apply_proximal_adagrad_d.so") \
     .compute_cost(10) \
-    .kernel_name("apply_proximal_adagrad") \
+    .kernel_name("apply_proximal_adagrad_d") \
     .partial_flag(True) \
     .attr("use_locking", "optional", "bool", "true,false", "false") \
     .input(0, "var", False, "required", "all") \
@@ -31,26 +31,27 @@ apply_proximal_adagrad_op_info = TBERegOp("ApplyProximalAdagrad") \
     .input(4, "l2", False, "required", "all") \
     .input(5, "grad", False, "required", "all") \
     .output(0, "var", False, "required", "all") \
+    .output(1, "accum", False, "required", "all") \
     .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \
+                  DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
     .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
+                  DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
     .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \
+                  DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \
+                  DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
     .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+                  DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
+                  DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
     .get_op_info()
 
-@op_info_register(apply_proximal_adagrad_op_info)
+@op_info_register(apply_proximal_adagrad_d_op_info)
 def _apply_proximal_adagrad():
-    """ApplyProximalAdagrad TBE register"""
+    """ApplyProximalAdagradD TBE register"""
     return
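Aside, not part of the patch above: each `.dtype_format(...)` row appears to list one `DataType` per operator input and output, in the same order they were registered with `.input(...)` and `.output(...)`. Because the D variant registers `accum` as a second output, every row gains one trailing entry (7 entries become 8 here, and 8 become 9 in the sparse registration below). A rough Python sketch of that assumed correspondence, with made-up variable names:

# Illustration only: pair each positional DataType in one dtype_format row with the
# io slot it is assumed to describe (6 inputs followed by 2 outputs for the D variant).
io_names = ["var", "accum", "lr", "l1", "l2", "grad",   # inputs 0..5
            "var_out", "accum_out"]                     # outputs 0..1
row = ["F16_5HD", "F16_5HD", "F16_Default", "F16_Default",
       "F16_Default", "F16_5HD", "F16_5HD", "F16_5HD"]
print(dict(zip(io_names, row)))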
@@ -13,10 +13,10 @@
 # limitations under the License.
 # ============================================================================
-"""SparseApplyProximalAdagrad op"""
+"""SparseApplyProximalAdagradD op"""
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 
-sparse_apply_proximal_adagrad_op_info = TBERegOp("SparseApplyProximalAdagrad") \
+sparse_apply_proximal_adagrad_d_op_info = TBERegOp("SparseApplyProximalAdagrad") \
     .fusion_type("OPAQUE") \
     .async_flag(False) \
     .binfile_name("sparse_apply_proximal_adagrad.so") \
@@ -32,70 +32,101 @@ sparse_apply_proximal_adagrad_op_info = TBERegOp("SparseApplyProximalAdagrad") \
     .input(5, "grad", False, "required", "all") \
     .input(6, "indices", False, "required", "all") \
     .output(0, "var", False, "required", "all") \
+    .output(1, "accum", False, "required", "all") \
     .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
-                  DataType.F32_NCHW, DataType.F32_NCHW, DataType.I16_NCHW, DataType.F32_NCHW) \
+                  DataType.F32_NCHW, DataType.F32_NCHW, DataType.I16_NCHW, DataType.F32_NCHW,
+                  DataType.F32_NCHW) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
-                  DataType.F32_5HD, DataType.F32_5HD, DataType.I16_5HD, DataType.F32_5HD) \
+                  DataType.F32_5HD, DataType.F32_5HD, DataType.I16_5HD, DataType.F32_5HD,
+                  DataType.F32_5HD) \
     .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
-                  DataType.F32_NHWC, DataType.F32_NHWC, DataType.I16_NHWC, DataType.F32_NHWC) \
+                  DataType.F32_NHWC, DataType.F32_NHWC, DataType.I16_NHWC, DataType.F32_NHWC,
+                  DataType.F32_NHWC) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default, DataType.I16_Default, DataType.F32_Default) \
+                  DataType.F32_Default, DataType.F32_Default, DataType.I16_Default, DataType.F32_Default,
+                  DataType.F32_Default) \
     .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
-                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.I16_FracZ, DataType.F32_FracZ) \
+                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.I16_FracZ, DataType.F32_FracZ,
+                  DataType.F32_FracZ) \
     .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
-                  DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW) \
+                  DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW,
+                  DataType.F32_NCHW) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
-                  DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD) \
+                  DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD,
+                  DataType.F32_5HD) \
     .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
-                  DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.F32_NHWC) \
+                  DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.F32_NHWC,
+                  DataType.F32_NHWC) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
+                  DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default,
+                  DataType.F32_Default) \
     .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
-                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \
+                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ,
+                  DataType.F32_FracZ) \
     .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
-                  DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW) \
+                  DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW,
+                  DataType.F32_NCHW) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
-                  DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD) \
+                  DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD,
+                  DataType.F32_5HD) \
     .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
-                  DataType.F32_NHWC, DataType.F32_NHWC, DataType.I64_NHWC, DataType.F32_NHWC) \
+                  DataType.F32_NHWC, DataType.F32_NHWC, DataType.I64_NHWC, DataType.F32_NHWC,
+                  DataType.F32_NHWC) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
+                  DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default,
+                  DataType.F32_Default) \
     .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
-                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ) \
+                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ,
+                  DataType.F32_FracZ) \
     .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
-                  DataType.F32_NCHW, DataType.F32_NCHW, DataType.U16_NCHW, DataType.F32_NCHW) \
+                  DataType.F32_NCHW, DataType.F32_NCHW, DataType.U16_NCHW, DataType.F32_NCHW,
+                  DataType.F32_NCHW) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
-                  DataType.F32_5HD, DataType.F32_5HD, DataType.U16_5HD, DataType.F32_5HD) \
+                  DataType.F32_5HD, DataType.F32_5HD, DataType.U16_5HD, DataType.F32_5HD,
+                  DataType.F32_5HD) \
     .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
-                  DataType.F32_NHWC, DataType.F32_NHWC, DataType.U16_NHWC, DataType.F32_NHWC) \
+                  DataType.F32_NHWC, DataType.F32_NHWC, DataType.U16_NHWC, DataType.F32_NHWC,
+                  DataType.F32_NHWC) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default, DataType.U16_Default, DataType.F32_Default) \
+                  DataType.F32_Default, DataType.F32_Default, DataType.U16_Default, DataType.F32_Default,
+                  DataType.F32_Default) \
     .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
-                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.U16_FracZ, DataType.F32_FracZ) \
+                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.U16_FracZ, DataType.F32_FracZ,
+                  DataType.F32_FracZ) \
     .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
-                  DataType.F32_NCHW, DataType.F32_NCHW, DataType.U32_NCHW, DataType.F32_NCHW) \
+                  DataType.F32_NCHW, DataType.F32_NCHW, DataType.U32_NCHW, DataType.F32_NCHW,
+                  DataType.F32_NCHW) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
-                  DataType.F32_5HD, DataType.F32_5HD, DataType.U32_5HD, DataType.F32_5HD) \
+                  DataType.F32_5HD, DataType.F32_5HD, DataType.U32_5HD, DataType.F32_5HD,
+                  DataType.F32_5HD) \
     .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
-                  DataType.F32_NHWC, DataType.F32_NHWC, DataType.U32_NHWC, DataType.F32_NHWC) \
+                  DataType.F32_NHWC, DataType.F32_NHWC, DataType.U32_NHWC, DataType.F32_NHWC,
+                  DataType.F32_NHWC) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default, DataType.U32_Default, DataType.F32_Default) \
+                  DataType.F32_Default, DataType.F32_Default, DataType.U32_Default, DataType.F32_Default,
+                  DataType.F32_Default) \
     .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
-                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.U32_FracZ, DataType.F32_FracZ) \
+                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.U32_FracZ, DataType.F32_FracZ,
+                  DataType.F32_FracZ) \
     .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW,
-                  DataType.F32_NCHW, DataType.F32_NCHW, DataType.U64_NCHW, DataType.F32_NCHW) \
+                  DataType.F32_NCHW, DataType.F32_NCHW, DataType.U64_NCHW, DataType.F32_NCHW,
+                  DataType.F32_NCHW) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
-                  DataType.F32_5HD, DataType.F32_5HD, DataType.U64_5HD, DataType.F32_5HD) \
+                  DataType.F32_5HD, DataType.F32_5HD, DataType.U64_5HD, DataType.F32_5HD,
+                  DataType.F32_5HD) \
     .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC,
-                  DataType.F32_NHWC, DataType.F32_NHWC, DataType.U64_NHWC, DataType.F32_NHWC) \
+                  DataType.F32_NHWC, DataType.F32_NHWC, DataType.U64_NHWC, DataType.F32_NHWC,
+                  DataType.F32_NHWC) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default, DataType.U64_Default, DataType.F32_Default) \
+                  DataType.F32_Default, DataType.F32_Default, DataType.U64_Default, DataType.F32_Default,
+                  DataType.F32_Default) \
     .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
-                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.U64_FracZ, DataType.F32_FracZ) \
+                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.U64_FracZ, DataType.F32_FracZ,
+                  DataType.F32_FracZ) \
     .get_op_info()
 
-@op_info_register(sparse_apply_proximal_adagrad_op_info)
+@op_info_register(sparse_apply_proximal_adagrad_d_op_info)
 def _sparse_apply_proximal_adagrad():
-    """SparseApplyProximalAdagrad TBE register"""
+    """SparseApplyProximalAdagradD TBE register"""
     return
@@ -3142,7 +3142,7 @@ class ApplyAdaMax(PrimitiveWithInfer):
     .. math::
         \begin{array}{ll} \\
             m_{t} = \beta_1 * m_{t-1} + (1 - \beta_1) * g \\
-            v_{t} = \max(\beta_2 * v{t-1}, \left| g \right|) \\
+            v_{t} = \max(\beta_2 * v_{t-1}, \left| g \right|) \\
             var = var - \frac{l}{1 - \beta_1^t} * \frac{m_{t}}{v_{t} + \epsilon}
         \end{array}
@@ -3497,37 +3497,61 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
     .. math::
         accum += grad * grad
     .. math::
-        prox_v = var - lr * grad * \frac{1}{\sqrt{accum}}
+        \text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}}
     .. math::
-        var = \frac{sign(prox_v)}{1 + lr * l2} * \max(\left| prox_v \right| - lr * l1, 0)
+        var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)
     Args:
         use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False.
 
     Inputs:
-        - **var** (Tensor) - Variable to be updated.
-        - **accum** (Tensor) - Accum to be updated. The shape must be the same as `var`'s shape.
+        - **var** (Parameter) - Variable to be updated. The data type should be float.
+        - **accum** (Parameter) - Accum to be updated. Must have the same shape and dtype as `var`.
         - **lr** (Union[Number, Tensor]): The learning rate value. It should be a scalar tensor or number.
+          The data type should be float.
         - **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero.
-          It should be a scalar tensor or number.
+          It should be a scalar tensor or number. The data type should be float.
         - **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero.
-          It should be a scalar tensor or number.
-        - **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape.
+          It should be a scalar tensor or number. The data type should be float.
+        - **grad** (Tensor) - Gradient. Must have the same shape and dtype as `var`.
 
     Outputs:
-        Tensor, has the same shape and type as `var`.
+        Tuple of 2 Tensors, the updated parameters.
+
+        - **var** (Tensor) - The same shape and data type as `var`.
+        - **accum** (Tensor) - The same shape and data type as `accum`.
     Examples:
-        >>> var = Tensor(np.random.random((3, 3)), mindspore.float32)
-        >>> accum = Tensor(np.random.random((3, 3)), mindspore.float32)
-        >>> grad = Tensor(np.random.random((3, 3)), mindspore.float32)
-        >>> lr = 0.01
-        >>> l1 = 0.0
-        >>> l2 = 0.0
-        >>> apply_proximal_ada_grad = P.ApplyProximalAdagrad()
-        >>> output = apply_proximal_ada_grad(var, accum, lr, l1, l2, grad)
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor, Parameter
+        >>> from mindspore.ops import operations as P
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.apply_proximal_adagrad = P.ApplyProximalAdagrad()
+        >>>         self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
+        >>>         self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
+        >>>         self.lr = 0.01
+        >>>         self.l1 = 0.0
+        >>>         self.l2 = 0.0
+        >>>     def construct(self, grad):
+        >>>         out = self.apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, self.l2, grad)
+        >>>         return out
+        >>> net = Net()
+        >>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
+        >>> output = net(grad)
     """
+    __mindspore_signature__ = (
+        ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
+    )
+
     @prim_attr_register
     def __init__(self, use_locking=False):
         self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'], outputs=['output'])
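Aside, not part of the patch: the three formulas in the ApplyProximalAdagrad docstring translate almost literally into NumPy. The sketch below is only an illustration of the dense update rule under that reading (the helper name is made up; the real work is done by the TBE kernel registered earlier). Returning both arrays mirrors the new (var, accum) tuple output.

import numpy as np

def proximal_adagrad_step(var, accum, lr, l1, l2, grad):
    # accum += grad^2, then the proximal step on var, mirroring the docstring math.
    accum += grad * grad
    prox_v = var - lr * grad / np.sqrt(accum)
    var[:] = np.sign(prox_v) / (1 + lr * l2) * np.maximum(np.abs(prox_v) - lr * l1, 0)
    return var, accum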
@@ -3536,7 +3560,7 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
     def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape):
         validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
         validator.check('var shape', var_shape, 'grad shape', grad_shape, Rel.EQ, self.name)
-        return var_shape
+        return var_shape, accum_shape
 
     def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype):
         valid_types = [mstype.float16, mstype.float32]
@@ -3544,7 +3568,7 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
         validator.check_tensor_type_same(args, valid_types, self.name)
         scalar_args = {"lr": lr_dtype, "l1": l1_dtype, "l2": l2_dtype}
         validator.check_scalar_or_tensor_type_same(scalar_args, valid_types, self.name)
-        return var_dtype
+        return var_dtype, accum_dtype
 
 
 class SparseApplyProximalAdagrad(PrimitiveWithInfer):
@@ -3555,39 +3579,65 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
     .. math::
         accum += grad * grad
     .. math::
-        prox_v = var - lr * grad * \frac{1}{\sqrt{accum}}
+        \text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}}
     .. math::
-        var = \frac{sign(prox_v)}{1 + lr * l2} * \max(\left| prox_v \right| - lr * l1, 0)
+        var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)
     Args:
         use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False.
 
     Inputs:
-        - **var** (Tensor) - Variable tensor to be updated.
-        - **accum** (Tensor) - Variable tensor to be updated. The shape must be the same as `var`'s shape.
+        - **var** (Parameter) - Variable tensor to be updated. The data type must be float32.
+        - **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
         - **lr** (Union[Number, Tensor]): The learning rate value. It should be a scalar tensor or number.
+          The data type must be float32.
         - **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero.
-          It should be a scalar tensor or number.
+          It should be a scalar tensor or number. The data type must be float32.
         - **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero.
-          It should be a scalar tensor or number.
-        - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.
+          It should be a scalar tensor or number. The data type must be float32.
+        - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32.
         - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
 
     Outputs:
-        Tensor, has the same shape and type as `var`.
+        Tuple of 2 Tensors, the updated parameters.
+
+        - **var** (Tensor) - The same shape and data type as `var`.
+        - **accum** (Tensor) - The same shape and data type as `accum`.
     Examples:
-        >>> var = Tensor(np.random.random((3, 3)), mindspore.float32)
-        >>> accum = Tensor(np.random.random((3, 3)), mindspore.float32)
-        >>> grad = Tensor(np.random.random((3, 3)), mindspore.float32)
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor, Parameter
+        >>> from mindspore.ops import operations as P
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad()
+        >>>         self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
+        >>>         self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
+        >>>         self.lr = 0.01
+        >>>         self.l1 = 0.0
+        >>>         self.l2 = 0.0
+        >>>     def construct(self, grad, indices):
+        >>>         out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1,
+        >>>                                                  self.l2, grad, indices)
+        >>>         return out
+        >>> net = Net()
+        >>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
         >>> indices = Tensor(np.ones((3,), np.int32))
-        >>> lr = 0.01
-        >>> l1 = 0.0
-        >>> l2 = 0.0
-        >>> sparse_apply_proximal_ada_grad = P.SparseApplyProximalAdagrad()
-        >>> output = sparse_apply_proximal_ada_grad(var, accum, lr, l1, l2, grad, indices)
+        >>> output = net(grad, indices)
     """
+    __mindspore_signature__ = (
+        ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
+    )
+
     @prim_attr_register
     def __init__(self, use_locking=False):
         self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'],
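Aside, not part of the patch: the sparse variant is assumed to apply the same proximal update, but only to the rows of `var` and `accum` selected by `indices`, with one `grad` row per index. A rough NumPy sketch under that assumption, again with a made-up helper name:

import numpy as np

def sparse_proximal_adagrad_step(var, accum, lr, l1, l2, grad, indices):
    # Touch only the rows named in `indices`; grad[i] pairs with indices[i].
    for i, idx in enumerate(indices):
        accum[idx] += grad[i] * grad[i]
        prox_v = var[idx] - lr * grad[i] / np.sqrt(accum[idx])
        var[idx] = np.sign(prox_v) / (1 + lr * l2) * np.maximum(np.abs(prox_v) - lr * l1, 0)
    return var, accum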
@@ -3595,7 +3645,8 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
 
     def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
-        return var_shape
+        validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
+        return var_shape, accum_shape
 
     def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
         args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
@@ -3605,7 +3656,7 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
         valid_types = [mstype.int16, mstype.int32, mstype.int64,
                        mstype.uint16, mstype.uint32, mstype.uint64]
         validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
-        return var_dtype
+        return var_dtype, accum_dtype
 
 
 class LARSUpdate(PrimitiveWithInfer):
@@ -3858,8 +3909,8 @@ class ConfusionMulGrad(PrimitiveWithInfer):
         axis (Union[int, tuple[int], list[int]]): The dimensions to reduce.
             Default:(), reduce all dimensions. Only constant value is allowed.
         keep_dims (bool):
-            - If true, keep these reduced dimensions and the length is 1.
-            - If false, don't keep these dimensions. Default:False.
+            - If True, keep these reduced dimensions and the length is 1.
+            - If False, don't keep these dimensions. Default:False.
 
     Inputs:
         - **input_0** (Tensor) - The input Tensor.