| @@ -924,16 +924,16 @@ AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrLis | |||||
| auto ret = prim_eval_implement_map.find(prim); | auto ret = prim_eval_implement_map.find(prim); | ||||
| if (ret != prim_eval_implement_map.end()) { | if (ret != prim_eval_implement_map.end()) { | ||||
| // fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr | // fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr | ||||
| MS_EXCEPTION_IF_NULL(ret->second.infer_shape_dtype_impl_); | |||||
| MS_EXCEPTION_IF_NULL(ret->second.infer_shape_impl_); | |||||
| auto infer_spec_list = RectifyAbstract(prim, args_spec_list); | auto infer_spec_list = RectifyAbstract(prim, args_spec_list); | ||||
| return ret->second.infer_shape_dtype_impl_(nullptr, prim, infer_spec_list); | |||||
| return ret->second.infer_shape_impl_(nullptr, prim, infer_spec_list); | |||||
| } else { | } else { | ||||
| // if the infer function has been not founded in the front infer map find it in the backend infer map instead | // if the infer function has been not founded in the front infer map find it in the backend infer map instead | ||||
| auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap(); | auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap(); | ||||
| auto ret_backend = prim_backend_eval_impl_map.find(prim); | auto ret_backend = prim_backend_eval_impl_map.find(prim); | ||||
| if (ret_backend != prim_backend_eval_impl_map.end()) { | if (ret_backend != prim_backend_eval_impl_map.end()) { | ||||
| MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_dtype_impl_); | |||||
| return ret_backend->second.infer_shape_dtype_impl_(nullptr, prim, args_spec_list); | |||||
| MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_impl_); | |||||
| return ret_backend->second.infer_shape_impl_(nullptr, prim, args_spec_list); | |||||
| } | } | ||||
| } | } | ||||
| MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name() | MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name() | ||||
| @@ -576,7 +576,7 @@ EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &en | |||||
| prim_py->RunCheck(py_args); | prim_py->RunCheck(py_args); | ||||
| prim_->BeginRecordAddAttr(); | prim_->BeginRecordAddAttr(); | ||||
| AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args); | |||||
| AbstractBasePtr abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args); | |||||
| prim_->EndRecordAddAttr(); | prim_->EndRecordAddAttr(); | ||||
| auto added_attrs = prim_->evaluate_added_attrs(); | auto added_attrs = prim_->evaluate_added_attrs(); | ||||
| @@ -602,16 +602,21 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c | |||||
| auto context = MsContext::GetInstance(); | auto context = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(context); | MS_EXCEPTION_IF_NULL(context); | ||||
| bool need_infer_value = eval_impl_.in_white_list_ || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode); | bool need_infer_value = eval_impl_.in_white_list_ || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode); | ||||
| AbstractBasePtr abs_base = nullptr; | |||||
| ValuePtr value = nullptr; | |||||
| prim_->BeginRecordAddAttr(); | prim_->BeginRecordAddAttr(); | ||||
| AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args); | |||||
| prim_->EndRecordAddAttr(); | |||||
| auto added_attrs = prim_->evaluate_added_attrs(); | |||||
| if (need_infer_value) { | |||||
| if (eval_impl_.infer_value_func_ != nullptr) { | |||||
| auto value = eval_impl_.infer_value_func_(prim_, args, abs_base); | |||||
| abs_base->set_value(value); | |||||
| if (need_infer_value && eval_impl_.infer_value_impl_ != nullptr) { | |||||
| value = eval_impl_.infer_value_impl_(prim_, args); | |||||
| if (value != nullptr) { | |||||
| abs_base = value->ToAbstract(); | |||||
| prim_->EndRecordAddAttr(); | |||||
| auto added_attrs = prim_->evaluate_added_attrs(); | |||||
| return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs)); | |||||
| } | } | ||||
| } | } | ||||
| abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args); | |||||
| prim_->EndRecordAddAttr(); | |||||
| auto added_attrs = prim_->evaluate_added_attrs(); | |||||
| auto eval_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs)); | auto eval_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs)); | ||||
| return eval_result; | return eval_result; | ||||
| } | } | ||||
| @@ -359,7 +359,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr | |||||
| // find prim infer function in the prim function map return a standard evaluator | // find prim infer function in the prim function map return a standard evaluator | ||||
| auto eval_impl = GetPrimitiveInferImpl(prim); | auto eval_impl = GetPrimitiveInferImpl(prim); | ||||
| if (eval_impl.infer_shape_dtype_impl_ != nullptr) { | |||||
| if (eval_impl.infer_shape_impl_ != nullptr) { | |||||
| return std::make_shared<StandardPrimEvaluator>(prim, eval_impl); | return std::make_shared<StandardPrimEvaluator>(prim, eval_impl); | ||||
| } | } | ||||
| @@ -28,13 +28,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace abstract { | namespace abstract { | ||||
| using InferShapeAndTypeImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, | |||||
| const AbstractBasePtrList &); | |||||
| using InferValueEvalImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &, const AbstractBasePtr &); | |||||
| using InferShapeImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, | |||||
| const AbstractBasePtrList &); | |||||
| using InferValueImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &); | |||||
| struct StandardPrimitiveImplReg { | struct StandardPrimitiveImplReg { | ||||
| InferShapeAndTypeImpl infer_shape_dtype_impl_; // Implement function of Primitive | |||||
| InferValueEvalImpl infer_value_func_; // infer value of primitive | |||||
| InferShapeImpl infer_shape_impl_; // infer shape and type for ops | |||||
| InferValueImpl infer_value_impl_; // infer value for ops | |||||
| // in_white_list_ is true means this primitive can be executed by vm backend | // in_white_list_ is true means this primitive can be executed by vm backend | ||||
| // else will be optimized by frontend | // else will be optimized by frontend | ||||
| bool in_white_list_; | bool in_white_list_; | ||||
| @@ -55,8 +55,8 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard | |||||
| class RegisterStandardPrimitiveEvalHelper { | class RegisterStandardPrimitiveEvalHelper { | ||||
| public: | public: | ||||
| RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeAndTypeImpl &infer_impl, | |||||
| const InferValueEvalImpl &infer_value_impl, const bool is_wight_list = true) { | |||||
| RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeImpl &infer_impl, | |||||
| const InferValueImpl &infer_value_impl, const bool is_wight_list = true) { | |||||
| const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_wight_list}; | const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_wight_list}; | ||||
| RegisterStandardPrimitiveImpl(primitive, impl_reg); | RegisterStandardPrimitiveImpl(primitive, impl_reg); | ||||
| } | } | ||||
| @@ -27,8 +27,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args, | |||||
| const AbstractBasePtr &infer) { | |||||
| ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto op_name = primitive->name(); | auto op_name = primitive->name(); | ||||
| CheckAndConvertUtils::CheckInteger("dtype infer", input_args.size(), kEqual, 1, op_name); | CheckAndConvertUtils::CheckInteger("dtype infer", input_args.size(), kEqual, 1, op_name); | ||||
| @@ -41,7 +40,7 @@ ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<Abstra | |||||
| AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| auto value = DTypeInferValue(primitive, input_args, nullptr); | |||||
| auto value = DTypeInferValue(primitive, input_args); | |||||
| MS_EXCEPTION_IF_NULL(value); | MS_EXCEPTION_IF_NULL(value); | ||||
| auto type = value->cast<TypePtr>(); | auto type = value->cast<TypePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(type); | MS_EXCEPTION_IF_NULL(type); | ||||
| @@ -32,7 +32,7 @@ AbstractBasePtr PrimitiveC::Infer(const AbstractBasePtrList &abstract_list) { | |||||
| if (iter == infer_map.end()) { | if (iter == infer_map.end()) { | ||||
| MS_EXCEPTION(NotExistsError) << "Cannot find the " << this->name() << "infer function in the infer map!"; | MS_EXCEPTION(NotExistsError) << "Cannot find the " << this->name() << "infer function in the infer map!"; | ||||
| } | } | ||||
| auto infer_function = iter->second.infer_shape_dtype_impl_; | |||||
| auto infer_function = iter->second.infer_shape_impl_; | |||||
| return infer_function(nullptr, shared_from_base<Primitive>(), abstract_list); | return infer_function(nullptr, shared_from_base<Primitive>(), abstract_list); | ||||
| } | } | ||||
| @@ -45,8 +45,7 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP | |||||
| return abs; | return abs; | ||||
| } | } | ||||
| ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args, | |||||
| const AbstractBasePtr &infer) { | |||||
| ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto op_name = primitive->name(); | auto op_name = primitive->name(); | ||||
| CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name); | CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name); | ||||
| @@ -49,9 +49,17 @@ TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP | |||||
| kUInt16, kUInt32, kUInt64, kFloat16, kFloat32, kFloat64}; | kUInt16, kUInt32, kUInt64, kFloat16, kFloat32, kFloat64}; | ||||
| return CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_types, prim_name); | return CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_types, prim_name); | ||||
| } | } | ||||
| ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args, | |||||
| const abstract::AbstractBasePtr &abs) { | |||||
| AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto abs = std::make_shared<abstract::AbstractTensor>(ZerosInferType(primitive, input_args), | |||||
| ZerosInferShape(primitive, input_args)); | |||||
| return abs; | |||||
| } | |||||
| ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| auto abs = ZerosInfer(nullptr, prim, input_args); | |||||
| // check | // check | ||||
| auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(abs->BuildShape())[kShape]; | auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(abs->BuildShape())[kShape]; | ||||
| auto out_type = abs->BuildType(); | auto out_type = abs->BuildType(); | ||||
| @@ -59,14 +67,6 @@ ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBas | |||||
| return TensorConstructUtils::CreateZerosTensor(out_type, out_shape); | return TensorConstructUtils::CreateZerosTensor(out_type, out_shape); | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto abs = std::make_shared<abstract::AbstractTensor>(ZerosInferType(primitive, input_args), | |||||
| ZerosInferShape(primitive, input_args)); | |||||
| return abs; | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Zeros, prim::kPrimZeros, ZerosInfer, ZerosInferValue, false); | REGISTER_PRIMITIVE_EVAL_IMPL(Zeros, prim::kPrimZeros, ZerosInfer, ZerosInferValue, false); | ||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,7 +17,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| namespace mindspore { | namespace mindspore { | ||||
| tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape) { | |||||
| tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape) { | |||||
| MS_EXCEPTION_IF_NULL(type_ptr); | MS_EXCEPTION_IF_NULL(type_ptr); | ||||
| auto type_id = ExtractTypeId(type_ptr); | auto type_id = ExtractTypeId(type_ptr); | ||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape); | tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape); | ||||
| @@ -30,7 +30,7 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr type_ptr | |||||
| return tensor; | return tensor; | ||||
| } | } | ||||
| tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape) { | |||||
| tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape) { | |||||
| MS_EXCEPTION_IF_NULL(type_ptr); | MS_EXCEPTION_IF_NULL(type_ptr); | ||||
| auto type_id = ExtractTypeId(type_ptr); | auto type_id = ExtractTypeId(type_ptr); | ||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape); | tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape); | ||||
| @@ -43,7 +43,7 @@ tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr type_ptr, | |||||
| return tensor; | return tensor; | ||||
| } | } | ||||
| tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape, | |||||
| tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape, | |||||
| void *data) { | void *data) { | ||||
| MS_EXCEPTION_IF_NULL(type_ptr); | MS_EXCEPTION_IF_NULL(type_ptr); | ||||
| auto type_id = ExtractTypeId(type_ptr); | auto type_id = ExtractTypeId(type_ptr); | ||||
| @@ -51,7 +51,7 @@ tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr type_ptr, con | |||||
| return tensor; | return tensor; | ||||
| } | } | ||||
| TypeId TensorConstructUtils::ExtractTypeId(const TypePtr type_ptr) { | |||||
| TypeId TensorConstructUtils::ExtractTypeId(const TypePtr &type_ptr) { | |||||
| MS_EXCEPTION_IF_NULL(type_ptr); | MS_EXCEPTION_IF_NULL(type_ptr); | ||||
| TypeId type_id; | TypeId type_id; | ||||
| if (type_ptr->isa<TensorType>()) { | if (type_ptr->isa<TensorType>()) { | ||||
| @@ -28,12 +28,12 @@ void SetTensorData(void *data, T num, size_t data_length) { | |||||
| } | } | ||||
| class TensorConstructUtils { | class TensorConstructUtils { | ||||
| public: | public: | ||||
| static tensor::TensorPtr CreateZerosTensor(const TypePtr type, const std::vector<int64_t> &shape); | |||||
| static tensor::TensorPtr CreateOnesTensor(const TypePtr type, const std::vector<int64_t> &shape); | |||||
| static tensor::TensorPtr CreateTensor(const TypePtr type, const std::vector<int64_t> &shape, void *data); | |||||
| static tensor::TensorPtr CreateZerosTensor(const TypePtr &type, const std::vector<int64_t> &shape); | |||||
| static tensor::TensorPtr CreateOnesTensor(const TypePtr &type, const std::vector<int64_t> &shape); | |||||
| static tensor::TensorPtr CreateTensor(const TypePtr &type, const std::vector<int64_t> &shape, void *data); | |||||
| private: | private: | ||||
| static TypeId ExtractTypeId(const TypePtr type); | |||||
| static TypeId ExtractTypeId(const TypePtr &type); | |||||
| }; | }; | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_ | #endif // MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_ | ||||
| @@ -87,9 +87,9 @@ TEST_F(TestAbstract, TestParseDataClass) { | |||||
| AbstractBasePtrList args_list = {abs_scalar, abstract_x, abstract_y}; | AbstractBasePtrList args_list = {abs_scalar, abstract_x, abstract_y}; | ||||
| auto eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord); | auto eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord); | ||||
| ASSERT_TRUE(nullptr != eval_impl.infer_shape_dtype_impl_); | |||||
| ASSERT_TRUE(nullptr != eval_impl.infer_shape_impl_); | |||||
| AbstractBasePtr new_cls = eval_impl.infer_shape_dtype_impl_(nullptr, prim::kPrimMakeRecord, args_list); | |||||
| AbstractBasePtr new_cls = eval_impl.infer_shape_impl_(nullptr, prim::kPrimMakeRecord, args_list); | |||||
| ASSERT_TRUE(nullptr != new_cls); | ASSERT_TRUE(nullptr != new_cls); | ||||
| } | } | ||||