[feat] [assistant] [I48OBG] add new operator ApplyAdamD [feat] [assistant] [I48OBG] add new operator ApplyAdamD [feat] [assistant] [I48OBG] add new operator ApplyAdamD [feat] [assistant] [I48OBG] add new operator ApplyAdamD [feat] [assistant] [I48OBG] add new operator ApplyAdamD [feat] [assistant] [I48OBG] add new operator ApplyAdamD [feat] [assistant] [I48OBG] add new operator ApplyAdamD [feat] [assistant] [I48OBG] add new operator ApplyAdamD [feat] [assistant] [I48OBG] add new operator ApplyAdamD [feat] [assistant] [I48OBG] add new operator ApplyAdamDtags/v1.6.0
| @@ -15,41 +15,68 @@ | |||
| */ | |||
| #include "ops/adam.h" | |||
| #include <set> | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| TuplePtr AdamInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| const int64_t input_num = 10; | |||
| CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, prim_name); | |||
| // infer shape | |||
| auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape]; | |||
| auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape]; | |||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape]; | |||
| auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex9]->GetShapeTrack())[kShape]; | |||
| CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name); | |||
| CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name); | |||
| CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name); | |||
| // infer type | |||
| auto var_type = input_args[kInputIndex0]->BuildType(); | |||
| auto m_type = input_args[kInputIndex1]->BuildType(); | |||
| auto v_type = input_args[kInputIndex2]->BuildType(); | |||
| auto beta1_power_type = input_args[kInputIndex3]->BuildType(); | |||
| auto beta2_power_type = input_args[kInputIndex4]->BuildType(); | |||
| auto lr_type = input_args[kInputIndex5]->BuildType(); | |||
| auto beta1_type = input_args[kInputIndex6]->BuildType(); | |||
| auto beta2_type = input_args[kInputIndex7]->BuildType(); | |||
| auto epsilon_type = input_args[kInputIndex8]->BuildType(); | |||
| auto grad_type = input_args[kInputIndex9]->BuildType(); | |||
| auto infer_var_type = CheckAndConvertUtils::CheckTensorTypeValid("var_type", var_type, common_valid_types, prim_name); | |||
| auto infer_m_type = CheckAndConvertUtils::CheckTensorTypeValid("m_type", m_type, common_valid_types, prim_name); | |||
| auto infer_v_type = CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, common_valid_types, prim_name); | |||
| (void)CheckAndConvertUtils::CheckTensorTypeValid("grad_type", grad_type, common_valid_types, prim_name); | |||
| auto output0 = std::make_shared<abstract::AbstractTensor>(infer_var_type, var_shape); | |||
| auto output1 = std::make_shared<abstract::AbstractTensor>(infer_m_type, m_shape); | |||
| auto output2 = std::make_shared<abstract::AbstractTensor>(infer_v_type, v_shape); | |||
| AbstractBasePtrList output = {output0, output1, output2}; | |||
| return std::make_shared<abstract::AbstractTuple>(output); | |||
| std::map<std::string, TypePtr> type_dict; | |||
| type_dict.emplace("var", var_type); | |||
| type_dict.emplace("m", m_type); | |||
| type_dict.emplace("v", v_type); | |||
| type_dict.emplace("grad", grad_type); | |||
| std::set<TypePtr> num_type = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, | |||
| kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128}; | |||
| (void)CheckAndConvertUtils::CheckTensorTypeSame(type_dict, num_type, prim_name); | |||
| std::map<std::string, TypePtr> type_dict1; | |||
| type_dict1.emplace("beta1_power", beta1_power_type); | |||
| type_dict1.emplace("beta2_power", beta2_power_type); | |||
| type_dict1.emplace("lr", lr_type); | |||
| type_dict1.emplace("beta1", beta1_type); | |||
| type_dict1.emplace("beta2", beta2_type); | |||
| type_dict1.emplace("epsilon", epsilon_type); | |||
| std::set<TypePtr> float_set = {kFloat16, kFloat32}; | |||
| (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(type_dict1, float_set, prim_name, true); | |||
| return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, m_type, v_type}); | |||
| } | |||
| abstract::TupleShapePtr AdamInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| auto var_shape_ptr = input_args[kInputIndex0]->BuildShape(); | |||
| auto m_shape_ptr = input_args[kInputIndex1]->BuildShape(); | |||
| auto v_shape_ptr = input_args[kInputIndex2]->BuildShape(); | |||
| auto grad_shape_ptr = input_args[kInputIndex9]->BuildShape(); | |||
| if (var_shape_ptr->IsDynamic() || m_shape_ptr->IsDynamic() || v_shape_ptr->IsDynamic() || | |||
| grad_shape_ptr->IsDynamic()) { | |||
| return std::make_shared<abstract::TupleShape>( | |||
| std::vector<abstract::BaseShapePtr>{var_shape_ptr, m_shape_ptr, v_shape_ptr}); | |||
| } | |||
| auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; | |||
| auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; | |||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; | |||
| auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex9]->BuildShape())[kShape]; | |||
| CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name); | |||
| CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name); | |||
| CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name); | |||
| return std::make_shared<abstract::TupleShape>( | |||
| std::vector<abstract::BaseShapePtr>{var_shape_ptr, m_shape_ptr, v_shape_ptr}); | |||
| } | |||
| } // namespace | |||
| void Adam::Init(const bool use_locking, const bool use_nesterov) { | |||
| @@ -73,8 +100,13 @@ bool Adam::get_use_nesterov() const { | |||
| AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(AdamInfer(primitive, input_args)); | |||
| auto prim_name = primitive->name(); | |||
| const int64_t kInputNum = 10; | |||
| CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name); | |||
| auto infer_shape = AdamInferShape(primitive, input_args); | |||
| auto infer_type = AdamInferType(primitive, input_args); | |||
| return abstract::MakeAbstract(infer_shape, infer_type); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameAdam, Adam); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Adam, prim::kPrimAdam, AdamInfer, nullptr, true); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -53,6 +53,7 @@ class MS_CORE_API Adam : public PrimitiveC { | |||
| }; | |||
| AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using kPrimAdamPtr = std::shared_ptr<Adam>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -39,6 +39,7 @@ from .apply_ftrl import _apply_ftrl_tbe | |||
| from .apply_keras_momentum import _apply_keras_momentum_tbe | |||
| from .apply_momentum import _apply_momentum_tbe | |||
| from .apply_adam import _apply_adam_tbe | |||
| from .apply_adam_ds import _apply_adam_ds_tbe | |||
| from .apply_ada_max import _apply_ada_max_tbe | |||
| from .apply_adadelta import _apply_adadelta_tbe | |||
| from .apply_adagrad import _apply_adagrad_tbe | |||
| @@ -0,0 +1,80 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ApplyAdam op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| apply_adam_ds_op_info = TBERegOp("Adam") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("apply_adam.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("apply_adam_d") \ | |||
| .partial_flag(True) \ | |||
| .dynamic_shape(True) \ | |||
| .attr("use_locking", "optional", "bool", "true,false", "false") \ | |||
| .attr("use_nesterov", "optional", "bool", "true,false", "false") \ | |||
| .input(0, "var", False, "required", "all") \ | |||
| .input(1, "m", False, "required", "all") \ | |||
| .input(2, "v", False, "required", "all") \ | |||
| .input(3, "beta1_power", False, "required", "all") \ | |||
| .input(4, "beta2_power", False, "required", "all") \ | |||
| .input(5, "lr", False, "required", "all") \ | |||
| .input(6, "beta1", False, "required", "all") \ | |||
| .input(7, "beta2", False, "required", "all") \ | |||
| .input(8, "epsilon", False, "required", "all") \ | |||
| .input(9, "grad", False, "required", "all") \ | |||
| .output(0, "var", False, "required", "all") \ | |||
| .output(1, "m", False, "required", "all") \ | |||
| .output(2, "v", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, | |||
| DataType.F16_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | |||
| DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, | |||
| DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, | |||
| DataType.F32_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | |||
| DataType.F32_FracZ) \ | |||
| .get_op_info() | |||
| @op_info_register(apply_adam_ds_op_info) | |||
| def _apply_adam_ds_tbe(): | |||
| """ApplyAdam TBE register""" | |||
| return | |||
| @@ -4368,7 +4368,7 @@ class ROIAlign(PrimitiveWithInfer): | |||
| return inputs_type | |||
| class Adam(PrimitiveWithInfer): | |||
| class Adam(Primitive): | |||
| r""" | |||
| Updates gradients by the Adaptive Moment Estimation (Adam) algorithm. | |||
| @@ -4464,23 +4464,6 @@ class Adam(PrimitiveWithInfer): | |||
| validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name) | |||
| self.add_prim_attr('side_effect_mem', True) | |||
| def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape, | |||
| beta1_shape, beta2_shape, epsilon_shape, grad_shape): | |||
| validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name) | |||
| validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name) | |||
| validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) | |||
| return var_shape, m_shape, v_shape | |||
| def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype, | |||
| beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype): | |||
| args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype} | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||
| args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype, | |||
| "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype} | |||
| validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True) | |||
| return var_dtype, m_dtype, v_dtype | |||
| class AdamWeightDecay(PrimitiveWithInfer): | |||
| r""" | |||