From d2c696667bd28d203576c8eb4b0f8941b748b1c0 Mon Sep 17 00:00:00 2001 From: lianliguang Date: Sat, 17 Apr 2021 10:23:04 +0800 Subject: [PATCH] change process of infer value --- .../ccsrc/backend/optimizer/common/helper.cc | 8 +++---- .../pipeline/jit/static_analysis/prim.cc | 21 ++++++++++++------- .../jit/static_analysis/static_analysis.cc | 2 +- mindspore/core/abstract/primitive_infer_map.h | 14 ++++++------- mindspore/core/ops/dtype.cc | 5 ++--- mindspore/core/ops/primitive_c.cc | 2 +- mindspore/core/ops/shape.cc | 3 +-- mindspore/core/ops/zeros.cc | 20 +++++++++--------- .../core/utils/tensor_construct_utils.cc | 8 +++---- mindspore/core/utils/tensor_construct_utils.h | 8 +++---- tests/ut/cpp/abstract/abstract_test.cc | 4 ++-- 11 files changed, 49 insertions(+), 46 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc index b7f465b7b2..2997444d7f 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.cc +++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc @@ -924,16 +924,16 @@ AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrLis auto ret = prim_eval_implement_map.find(prim); if (ret != prim_eval_implement_map.end()) { // fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr - MS_EXCEPTION_IF_NULL(ret->second.infer_shape_dtype_impl_); + MS_EXCEPTION_IF_NULL(ret->second.infer_shape_impl_); auto infer_spec_list = RectifyAbstract(prim, args_spec_list); - return ret->second.infer_shape_dtype_impl_(nullptr, prim, infer_spec_list); + return ret->second.infer_shape_impl_(nullptr, prim, infer_spec_list); } else { // if the infer function has been not founded in the front infer map find it in the backend infer map instead auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap(); auto ret_backend = prim_backend_eval_impl_map.find(prim); if (ret_backend != prim_backend_eval_impl_map.end()) { - MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_dtype_impl_); - return ret_backend->second.infer_shape_dtype_impl_(nullptr, prim, args_spec_list); + MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_impl_); + return ret_backend->second.infer_shape_impl_(nullptr, prim, args_spec_list); } } MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name() diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 85e7b40c36..65a92ae4f4 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -576,7 +576,7 @@ EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &en prim_py->RunCheck(py_args); prim_->BeginRecordAddAttr(); - AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args); + AbstractBasePtr abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args); prim_->EndRecordAddAttr(); auto added_attrs = prim_->evaluate_added_attrs(); @@ -602,16 +602,21 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); bool need_infer_value = eval_impl_.in_white_list_ || (context->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode); + AbstractBasePtr abs_base = nullptr; + ValuePtr value = nullptr; prim_->BeginRecordAddAttr(); - AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args); - prim_->EndRecordAddAttr(); - auto added_attrs = prim_->evaluate_added_attrs(); - if (need_infer_value) { - if (eval_impl_.infer_value_func_ != nullptr) { - auto value = eval_impl_.infer_value_func_(prim_, args, abs_base); - abs_base->set_value(value); + if (need_infer_value && eval_impl_.infer_value_impl_ != nullptr) { + value = eval_impl_.infer_value_impl_(prim_, args); + if (value != nullptr) { + abs_base = value->ToAbstract(); + prim_->EndRecordAddAttr(); + auto added_attrs = prim_->evaluate_added_attrs(); + return std::make_shared(abs_base, std::make_shared(added_attrs)); } } + abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args); + prim_->EndRecordAddAttr(); + auto added_attrs = prim_->evaluate_added_attrs(); auto eval_result = std::make_shared(abs_base, std::make_shared(added_attrs)); return eval_result; } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index 9f05e8e466..142c01dec8 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -359,7 +359,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr // find prim infer function in the prim function map return a standard evaluator auto eval_impl = GetPrimitiveInferImpl(prim); - if (eval_impl.infer_shape_dtype_impl_ != nullptr) { + if (eval_impl.infer_shape_impl_ != nullptr) { return std::make_shared(prim, eval_impl); } diff --git a/mindspore/core/abstract/primitive_infer_map.h b/mindspore/core/abstract/primitive_infer_map.h index e4b0710dbb..fe10017d98 100644 --- a/mindspore/core/abstract/primitive_infer_map.h +++ b/mindspore/core/abstract/primitive_infer_map.h @@ -28,13 +28,13 @@ namespace mindspore { namespace abstract { -using InferShapeAndTypeImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &); -using InferValueEvalImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &, const AbstractBasePtr &); +using InferShapeImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &); +using InferValueImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &); struct StandardPrimitiveImplReg { - InferShapeAndTypeImpl infer_shape_dtype_impl_; // Implement function of Primitive - InferValueEvalImpl infer_value_func_; // infer value of primitive + InferShapeImpl infer_shape_impl_; // infer shape and type for ops + InferValueImpl infer_value_impl_; // infer value for ops // in_white_list_ is true means this primitive can be executed by vm backend // else will be optimized by frontend bool in_white_list_; @@ -55,8 +55,8 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard class RegisterStandardPrimitiveEvalHelper { public: - RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeAndTypeImpl &infer_impl, - const InferValueEvalImpl &infer_value_impl, const bool is_wight_list = true) { + RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeImpl &infer_impl, + const InferValueImpl &infer_value_impl, const bool is_wight_list = true) { const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_wight_list}; RegisterStandardPrimitiveImpl(primitive, impl_reg); } diff --git a/mindspore/core/ops/dtype.cc b/mindspore/core/ops/dtype.cc index f75976f2dd..f24c607fd4 100644 --- a/mindspore/core/ops/dtype.cc +++ b/mindspore/core/ops/dtype.cc @@ -27,8 +27,7 @@ namespace mindspore { namespace ops { -ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector &input_args, - const AbstractBasePtr &infer) { +ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); CheckAndConvertUtils::CheckInteger("dtype infer", input_args.size(), kEqual, 1, op_name); @@ -41,7 +40,7 @@ ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector &input_args) { - auto value = DTypeInferValue(primitive, input_args, nullptr); + auto value = DTypeInferValue(primitive, input_args); MS_EXCEPTION_IF_NULL(value); auto type = value->cast(); MS_EXCEPTION_IF_NULL(type); diff --git a/mindspore/core/ops/primitive_c.cc b/mindspore/core/ops/primitive_c.cc index 94d2c5f23b..d4e5443799 100644 --- a/mindspore/core/ops/primitive_c.cc +++ b/mindspore/core/ops/primitive_c.cc @@ -32,7 +32,7 @@ AbstractBasePtr PrimitiveC::Infer(const AbstractBasePtrList &abstract_list) { if (iter == infer_map.end()) { MS_EXCEPTION(NotExistsError) << "Cannot find the " << this->name() << "infer function in the infer map!"; } - auto infer_function = iter->second.infer_shape_dtype_impl_; + auto infer_function = iter->second.infer_shape_impl_; return infer_function(nullptr, shared_from_base(), abstract_list); } diff --git a/mindspore/core/ops/shape.cc b/mindspore/core/ops/shape.cc index 35795daf24..dd61504a2d 100644 --- a/mindspore/core/ops/shape.cc +++ b/mindspore/core/ops/shape.cc @@ -45,8 +45,7 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP return abs; } -ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector &input_args, - const AbstractBasePtr &infer) { +ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name); diff --git a/mindspore/core/ops/zeros.cc b/mindspore/core/ops/zeros.cc index e72c4ce646..bd8f11f4cd 100644 --- a/mindspore/core/ops/zeros.cc +++ b/mindspore/core/ops/zeros.cc @@ -49,9 +49,17 @@ TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector &input_args, - const abstract::AbstractBasePtr &abs) { +AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto abs = std::make_shared(ZerosInferType(primitive, input_args), + ZerosInferShape(primitive, input_args)); + return abs; +} + +ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(prim); + auto abs = ZerosInfer(nullptr, prim, input_args); // check auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(abs->BuildShape())[kShape]; auto out_type = abs->BuildType(); @@ -59,14 +67,6 @@ ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(primitive); - auto abs = std::make_shared(ZerosInferType(primitive, input_args), - ZerosInferShape(primitive, input_args)); - return abs; -} REGISTER_PRIMITIVE_EVAL_IMPL(Zeros, prim::kPrimZeros, ZerosInfer, ZerosInferValue, false); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/utils/tensor_construct_utils.cc b/mindspore/core/utils/tensor_construct_utils.cc index cb8bfc2549..601e7a1223 100644 --- a/mindspore/core/utils/tensor_construct_utils.cc +++ b/mindspore/core/utils/tensor_construct_utils.cc @@ -17,7 +17,7 @@ #include #include namespace mindspore { -tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr type_ptr, const std::vector &shape) { +tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr &type_ptr, const std::vector &shape) { MS_EXCEPTION_IF_NULL(type_ptr); auto type_id = ExtractTypeId(type_ptr); tensor::TensorPtr tensor = std::make_shared(type_id, shape); @@ -30,7 +30,7 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr type_ptr return tensor; } -tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr type_ptr, const std::vector &shape) { +tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr &type_ptr, const std::vector &shape) { MS_EXCEPTION_IF_NULL(type_ptr); auto type_id = ExtractTypeId(type_ptr); tensor::TensorPtr tensor = std::make_shared(type_id, shape); @@ -43,7 +43,7 @@ tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr type_ptr, return tensor; } -tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr type_ptr, const std::vector &shape, +tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr &type_ptr, const std::vector &shape, void *data) { MS_EXCEPTION_IF_NULL(type_ptr); auto type_id = ExtractTypeId(type_ptr); @@ -51,7 +51,7 @@ tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr type_ptr, con return tensor; } -TypeId TensorConstructUtils::ExtractTypeId(const TypePtr type_ptr) { +TypeId TensorConstructUtils::ExtractTypeId(const TypePtr &type_ptr) { MS_EXCEPTION_IF_NULL(type_ptr); TypeId type_id; if (type_ptr->isa()) { diff --git a/mindspore/core/utils/tensor_construct_utils.h b/mindspore/core/utils/tensor_construct_utils.h index fc8b3bebed..9a63f04240 100644 --- a/mindspore/core/utils/tensor_construct_utils.h +++ b/mindspore/core/utils/tensor_construct_utils.h @@ -28,12 +28,12 @@ void SetTensorData(void *data, T num, size_t data_length) { } class TensorConstructUtils { public: - static tensor::TensorPtr CreateZerosTensor(const TypePtr type, const std::vector &shape); - static tensor::TensorPtr CreateOnesTensor(const TypePtr type, const std::vector &shape); - static tensor::TensorPtr CreateTensor(const TypePtr type, const std::vector &shape, void *data); + static tensor::TensorPtr CreateZerosTensor(const TypePtr &type, const std::vector &shape); + static tensor::TensorPtr CreateOnesTensor(const TypePtr &type, const std::vector &shape); + static tensor::TensorPtr CreateTensor(const TypePtr &type, const std::vector &shape, void *data); private: - static TypeId ExtractTypeId(const TypePtr type); + static TypeId ExtractTypeId(const TypePtr &type); }; } // namespace mindspore #endif // MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_ diff --git a/tests/ut/cpp/abstract/abstract_test.cc b/tests/ut/cpp/abstract/abstract_test.cc index 380c94b250..bbcfa66011 100644 --- a/tests/ut/cpp/abstract/abstract_test.cc +++ b/tests/ut/cpp/abstract/abstract_test.cc @@ -87,9 +87,9 @@ TEST_F(TestAbstract, TestParseDataClass) { AbstractBasePtrList args_list = {abs_scalar, abstract_x, abstract_y}; auto eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord); - ASSERT_TRUE(nullptr != eval_impl.infer_shape_dtype_impl_); + ASSERT_TRUE(nullptr != eval_impl.infer_shape_impl_); - AbstractBasePtr new_cls = eval_impl.infer_shape_dtype_impl_(nullptr, prim::kPrimMakeRecord, args_list); + AbstractBasePtr new_cls = eval_impl.infer_shape_impl_(nullptr, prim::kPrimMakeRecord, args_list); ASSERT_TRUE(nullptr != new_cls); }