| @@ -924,16 +924,16 @@ AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrLis | |||
| auto ret = prim_eval_implement_map.find(prim); | |||
| if (ret != prim_eval_implement_map.end()) { | |||
| // fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr | |||
| MS_EXCEPTION_IF_NULL(ret->second.infer_shape_dtype_impl_); | |||
| MS_EXCEPTION_IF_NULL(ret->second.infer_shape_impl_); | |||
| auto infer_spec_list = RectifyAbstract(prim, args_spec_list); | |||
| return ret->second.infer_shape_dtype_impl_(nullptr, prim, infer_spec_list); | |||
| return ret->second.infer_shape_impl_(nullptr, prim, infer_spec_list); | |||
| } else { | |||
| // if the infer function has been not founded in the front infer map find it in the backend infer map instead | |||
| auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap(); | |||
| auto ret_backend = prim_backend_eval_impl_map.find(prim); | |||
| if (ret_backend != prim_backend_eval_impl_map.end()) { | |||
| MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_dtype_impl_); | |||
| return ret_backend->second.infer_shape_dtype_impl_(nullptr, prim, args_spec_list); | |||
| MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_impl_); | |||
| return ret_backend->second.infer_shape_impl_(nullptr, prim, args_spec_list); | |||
| } | |||
| } | |||
| MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name() | |||
| @@ -576,7 +576,7 @@ EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &en | |||
| prim_py->RunCheck(py_args); | |||
| prim_->BeginRecordAddAttr(); | |||
| AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args); | |||
| AbstractBasePtr abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args); | |||
| prim_->EndRecordAddAttr(); | |||
| auto added_attrs = prim_->evaluate_added_attrs(); | |||
| @@ -602,16 +602,21 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool need_infer_value = eval_impl_.in_white_list_ || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode); | |||
| AbstractBasePtr abs_base = nullptr; | |||
| ValuePtr value = nullptr; | |||
| prim_->BeginRecordAddAttr(); | |||
| AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args); | |||
| prim_->EndRecordAddAttr(); | |||
| auto added_attrs = prim_->evaluate_added_attrs(); | |||
| if (need_infer_value) { | |||
| if (eval_impl_.infer_value_func_ != nullptr) { | |||
| auto value = eval_impl_.infer_value_func_(prim_, args, abs_base); | |||
| abs_base->set_value(value); | |||
| if (need_infer_value && eval_impl_.infer_value_impl_ != nullptr) { | |||
| value = eval_impl_.infer_value_impl_(prim_, args); | |||
| if (value != nullptr) { | |||
| abs_base = value->ToAbstract(); | |||
| prim_->EndRecordAddAttr(); | |||
| auto added_attrs = prim_->evaluate_added_attrs(); | |||
| return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs)); | |||
| } | |||
| } | |||
| abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args); | |||
| prim_->EndRecordAddAttr(); | |||
| auto added_attrs = prim_->evaluate_added_attrs(); | |||
| auto eval_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs)); | |||
| return eval_result; | |||
| } | |||
| @@ -359,7 +359,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr | |||
| // find prim infer function in the prim function map return a standard evaluator | |||
| auto eval_impl = GetPrimitiveInferImpl(prim); | |||
| if (eval_impl.infer_shape_dtype_impl_ != nullptr) { | |||
| if (eval_impl.infer_shape_impl_ != nullptr) { | |||
| return std::make_shared<StandardPrimEvaluator>(prim, eval_impl); | |||
| } | |||
| @@ -28,13 +28,13 @@ | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| using InferShapeAndTypeImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &); | |||
| using InferValueEvalImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &, const AbstractBasePtr &); | |||
| using InferShapeImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &); | |||
| using InferValueImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &); | |||
| struct StandardPrimitiveImplReg { | |||
| InferShapeAndTypeImpl infer_shape_dtype_impl_; // Implement function of Primitive | |||
| InferValueEvalImpl infer_value_func_; // infer value of primitive | |||
| InferShapeImpl infer_shape_impl_; // infer shape and type for ops | |||
| InferValueImpl infer_value_impl_; // infer value for ops | |||
| // in_white_list_ is true means this primitive can be executed by vm backend | |||
| // else will be optimized by frontend | |||
| bool in_white_list_; | |||
| @@ -55,8 +55,8 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard | |||
| class RegisterStandardPrimitiveEvalHelper { | |||
| public: | |||
| RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeAndTypeImpl &infer_impl, | |||
| const InferValueEvalImpl &infer_value_impl, const bool is_wight_list = true) { | |||
| RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeImpl &infer_impl, | |||
| const InferValueImpl &infer_value_impl, const bool is_wight_list = true) { | |||
| const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_wight_list}; | |||
| RegisterStandardPrimitiveImpl(primitive, impl_reg); | |||
| } | |||
| @@ -27,8 +27,7 @@ | |||
| namespace mindspore { | |||
| namespace ops { | |||
| ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args, | |||
| const AbstractBasePtr &infer) { | |||
| ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto op_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInteger("dtype infer", input_args.size(), kEqual, 1, op_name); | |||
| @@ -41,7 +40,7 @@ ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<Abstra | |||
| AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| auto value = DTypeInferValue(primitive, input_args, nullptr); | |||
| auto value = DTypeInferValue(primitive, input_args); | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| auto type = value->cast<TypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| @@ -32,7 +32,7 @@ AbstractBasePtr PrimitiveC::Infer(const AbstractBasePtrList &abstract_list) { | |||
| if (iter == infer_map.end()) { | |||
| MS_EXCEPTION(NotExistsError) << "Cannot find the " << this->name() << "infer function in the infer map!"; | |||
| } | |||
| auto infer_function = iter->second.infer_shape_dtype_impl_; | |||
| auto infer_function = iter->second.infer_shape_impl_; | |||
| return infer_function(nullptr, shared_from_base<Primitive>(), abstract_list); | |||
| } | |||
| @@ -45,8 +45,7 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP | |||
| return abs; | |||
| } | |||
| ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args, | |||
| const AbstractBasePtr &infer) { | |||
| ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto op_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name); | |||
| @@ -49,9 +49,17 @@ TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP | |||
| kUInt16, kUInt32, kUInt64, kFloat16, kFloat32, kFloat64}; | |||
| return CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_types, prim_name); | |||
| } | |||
| ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args, | |||
| const abstract::AbstractBasePtr &abs) { | |||
| AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto abs = std::make_shared<abstract::AbstractTensor>(ZerosInferType(primitive, input_args), | |||
| ZerosInferShape(primitive, input_args)); | |||
| return abs; | |||
| } | |||
| ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto abs = ZerosInfer(nullptr, prim, input_args); | |||
| // check | |||
| auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(abs->BuildShape())[kShape]; | |||
| auto out_type = abs->BuildType(); | |||
| @@ -59,14 +67,6 @@ ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBas | |||
| return TensorConstructUtils::CreateZerosTensor(out_type, out_shape); | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto abs = std::make_shared<abstract::AbstractTensor>(ZerosInferType(primitive, input_args), | |||
| ZerosInferShape(primitive, input_args)); | |||
| return abs; | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Zeros, prim::kPrimZeros, ZerosInfer, ZerosInferValue, false); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -17,7 +17,7 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| namespace mindspore { | |||
| tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape) { | |||
| tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape) { | |||
| MS_EXCEPTION_IF_NULL(type_ptr); | |||
| auto type_id = ExtractTypeId(type_ptr); | |||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape); | |||
| @@ -30,7 +30,7 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr type_ptr | |||
| return tensor; | |||
| } | |||
| tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape) { | |||
| tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape) { | |||
| MS_EXCEPTION_IF_NULL(type_ptr); | |||
| auto type_id = ExtractTypeId(type_ptr); | |||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape); | |||
| @@ -43,7 +43,7 @@ tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr type_ptr, | |||
| return tensor; | |||
| } | |||
| tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape, | |||
| tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape, | |||
| void *data) { | |||
| MS_EXCEPTION_IF_NULL(type_ptr); | |||
| auto type_id = ExtractTypeId(type_ptr); | |||
| @@ -51,7 +51,7 @@ tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr type_ptr, con | |||
| return tensor; | |||
| } | |||
| TypeId TensorConstructUtils::ExtractTypeId(const TypePtr type_ptr) { | |||
| TypeId TensorConstructUtils::ExtractTypeId(const TypePtr &type_ptr) { | |||
| MS_EXCEPTION_IF_NULL(type_ptr); | |||
| TypeId type_id; | |||
| if (type_ptr->isa<TensorType>()) { | |||
| @@ -28,12 +28,12 @@ void SetTensorData(void *data, T num, size_t data_length) { | |||
| } | |||
| class TensorConstructUtils { | |||
| public: | |||
| static tensor::TensorPtr CreateZerosTensor(const TypePtr type, const std::vector<int64_t> &shape); | |||
| static tensor::TensorPtr CreateOnesTensor(const TypePtr type, const std::vector<int64_t> &shape); | |||
| static tensor::TensorPtr CreateTensor(const TypePtr type, const std::vector<int64_t> &shape, void *data); | |||
| static tensor::TensorPtr CreateZerosTensor(const TypePtr &type, const std::vector<int64_t> &shape); | |||
| static tensor::TensorPtr CreateOnesTensor(const TypePtr &type, const std::vector<int64_t> &shape); | |||
| static tensor::TensorPtr CreateTensor(const TypePtr &type, const std::vector<int64_t> &shape, void *data); | |||
| private: | |||
| static TypeId ExtractTypeId(const TypePtr type); | |||
| static TypeId ExtractTypeId(const TypePtr &type); | |||
| }; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_ | |||
| @@ -87,9 +87,9 @@ TEST_F(TestAbstract, TestParseDataClass) { | |||
| AbstractBasePtrList args_list = {abs_scalar, abstract_x, abstract_y}; | |||
| auto eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord); | |||
| ASSERT_TRUE(nullptr != eval_impl.infer_shape_dtype_impl_); | |||
| ASSERT_TRUE(nullptr != eval_impl.infer_shape_impl_); | |||
| AbstractBasePtr new_cls = eval_impl.infer_shape_dtype_impl_(nullptr, prim::kPrimMakeRecord, args_list); | |||
| AbstractBasePtr new_cls = eval_impl.infer_shape_impl_(nullptr, prim::kPrimMakeRecord, args_list); | |||
| ASSERT_TRUE(nullptr != new_cls); | |||
| } | |||