Merge pull request !3085 from liuxiao93/adapt-ApplyCenteredRmsProptags/v0.6.0-beta
| @@ -81,6 +81,7 @@ static std::map<string, string> tbe_func_adapter_map = { | |||
| {"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"}, | |||
| {"apply_add_sign", "apply_add_sign_d"}, | |||
| {"apply_power_sign", "apply_power_sign_d"}, | |||
| {"apply_centered_rms_prop", "apply_centered_rms_prop_d"}, | |||
| {"transpose", "transpose_d"}, | |||
| {"fill", "fill_d"}, | |||
| {"unsorted_segment_sum", "unsorted_segment_sum_d"}, | |||
| @@ -409,7 +409,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||
| {string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)}, | |||
| {string(kNameAtan2), ADPT_DESC(Atan2)}, | |||
| {string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)}, | |||
| {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}, | |||
| {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSPropD)}, | |||
| {string(kNameL2Loss), ADPT_DESC(L2Loss)}, | |||
| {string(kNameCTCLoss), ADPT_DESC(CTCLoss)}, | |||
| {string(kNameRange), ADPT_DESC(RangeD)}, | |||
| @@ -1284,12 +1284,13 @@ INPUT_ATTR_MAP(ApplyRMSPropD) = {{6, ATTR_DESC(rho, AnyTraits<float>())}, | |||
| ATTR_MAP(ApplyRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; | |||
| OUTPUT_MAP(ApplyRMSPropD) = {{0, OUTPUT_DESC(var)}}; | |||
| // ApplyCenteredRMSProp | |||
| INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)}, | |||
| {4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)}, | |||
| {7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}}; | |||
| ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; | |||
| OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}}; | |||
| // ApplyCenteredRMSPropD | |||
| INPUT_MAP(ApplyCenteredRMSPropD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)}, | |||
| {4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)}, | |||
| {7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}}; | |||
| ATTR_MAP(ApplyCenteredRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; | |||
| OUTPUT_MAP(ApplyCenteredRMSPropD) = { | |||
| {0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(mg)}, {2, OUTPUT_DESC(ms)}, {3, OUTPUT_DESC(mom)}}; | |||
| // L2Loss | |||
| INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}}; | |||
| @@ -486,8 +486,8 @@ DECLARE_OP_USE_OUTPUT(Atan2) | |||
| DECLARE_OP_ADAPTER(ApplyRMSPropD) | |||
| DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD) | |||
| DECLARE_OP_USE_OUTPUT(ApplyRMSPropD) | |||
| DECLARE_OP_ADAPTER(ApplyCenteredRMSProp) | |||
| DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp) | |||
| DECLARE_OP_ADAPTER(ApplyCenteredRMSPropD) | |||
| DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSPropD) | |||
| DECLARE_OP_ADAPTER(L2Loss) | |||
| DECLARE_OP_USE_OUTPUT(L2Loss) | |||
| DECLARE_OP_ADAPTER(CTCLoss) | |||
| @@ -13,15 +13,15 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ApplyCenteredRMSProp op""" | |||
| """ApplyCenteredRMSPropD op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| apply_centered_rms_prop_op_info = TBERegOp("ApplyCenteredRMSProp") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("apply_centered_rms_prop.so") \ | |||
| .binfile_name("apply_centered_rms_prop_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("apply_centered_rms_prop") \ | |||
| .kernel_name("apply_centered_rms_prop_d") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "var", False, "required", "all") \ | |||
| .input(1, "mg", False, "required", "all") \ | |||
| @@ -33,34 +33,45 @@ apply_centered_rms_prop_op_info = TBERegOp("ApplyCenteredRMSProp") \ | |||
| .input(7, "epsilon", False, "required", "all") \ | |||
| .input(8, "grad", False, "required", "all") \ | |||
| .output(0, "var", False, "required", "all") \ | |||
| .output(1, "mg", False, "required", "all") \ | |||
| .output(2, "ms", False, "required", "all") \ | |||
| .output(3, "mom", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_5HD, DataType.F16_5HD) \ | |||
| DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | |||
| DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_FracZ, DataType.F16_FracZ) \ | |||
| DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, | |||
| DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ | |||
| DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, | |||
| DataType.F16_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default) \ | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_FracZ, DataType.F32_FracZ) \ | |||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | |||
| DataType.F32_FracZ) \ | |||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ | |||
| DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, | |||
| DataType.F32_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(apply_centered_rms_prop_op_info) | |||
| def _apply_centered_rms_prop_tbe(): | |||
| """ApplyCenteredRMSProp TBE register""" | |||
| """ApplyCenteredRMSPropD TBE register""" | |||
| return | |||
| @@ -1962,6 +1962,7 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, use_locking=False): | |||
| self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | |||
| self.is_ascend = context.get_context("device_target") == "Ascend" | |||
| def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape, | |||
| learning_rate_shape, decay_shape, momentum_shape, epsilon_shape): | |||
| @@ -1969,6 +1970,8 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): | |||
| validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name) | |||
| validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name) | |||
| validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) | |||
| if self.is_ascend: | |||
| return var_shape, mean_gradient_shape, mean_square_shape, moment_shape | |||
| return var_shape | |||
| def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype, | |||
| @@ -1982,6 +1985,8 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): | |||
| validator.check_type_same(args_rho, valid_types, self.name) | |||
| args_lr = {"learning_rate": learning_rate_dtype, "rho": rho_dtype} | |||
| validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True) | |||
| if self.is_ascend: | |||
| return var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype | |||
| return var_dtype | |||