| @@ -923,14 +923,16 @@ AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrLis | |||
| auto ret = prim_eval_implement_map.find(prim); | |||
| if (ret != prim_eval_implement_map.end()) { | |||
| // fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr | |||
| MS_EXCEPTION_IF_NULL(ret->second.infer_shape_dtype_impl_); | |||
| auto infer_spec_list = RectifyAbstract(prim, args_spec_list); | |||
| return ret->second.impl_(nullptr, prim, infer_spec_list); | |||
| return ret->second.infer_shape_dtype_impl_(nullptr, prim, infer_spec_list); | |||
| } else { | |||
| // if the infer function has been not founded in the front infer map find it in the backend infer map instead | |||
| auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap(); | |||
| auto ret_backend = prim_backend_eval_impl_map.find(prim); | |||
| if (ret_backend != prim_backend_eval_impl_map.end()) { | |||
| return ret_backend->second.impl_(nullptr, prim, args_spec_list); | |||
| MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_dtype_impl_); | |||
| return ret_backend->second.infer_shape_dtype_impl_(nullptr, prim, args_spec_list); | |||
| } | |||
| } | |||
| MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name() | |||
| @@ -639,26 +639,26 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt | |||
| return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, InferImplBroadcastGradientArgs, | |||
| nullptr, false); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr); | |||
| REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, | |||
| InferImplBroadcastGradientArgs, nullptr); | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -59,7 +59,9 @@ AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| #define REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl) \ | |||
| static auto helper_##name = \ | |||
| abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, false); | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_ | |||
| @@ -530,28 +530,17 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic | |||
| } | |||
| } // end anonymous namespace | |||
| EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { | |||
| EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &engine, const AbstractBasePtr &abs_base, | |||
| const AbstractBasePtrList &args) { | |||
| auto prim_py = dyn_cast<PrimitivePy>(prim_); | |||
| if (prim_py == nullptr) { | |||
| MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive."; | |||
| } | |||
| // Call checking method '__check__' for subclass of 'PrimitiveWithCheck' | |||
| // Call checking method 'infer_value' for python primitive | |||
| MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString(); | |||
| auto py_args = PreparePyInputs(prim_py, args); | |||
| prim_py->RunCheck(py_args); | |||
| prim_->BeginRecordAddAttr(); | |||
| AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); | |||
| prim_->EndRecordAddAttr(); | |||
| auto added_attrs = prim_->evaluate_added_attrs(); | |||
| if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) { | |||
| return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs)); | |||
| } | |||
| // Call method 'infer_value' for primitive with this method for constant propagation | |||
| py::tuple py_vals(py_args.size()); | |||
| auto added_attrs = prim_->evaluate_added_attrs(); | |||
| for (size_t i = 0; i < py_args.size(); ++i) { | |||
| py_vals[i] = py_args[i][ATTR_VALUE]; | |||
| } | |||
| @@ -559,7 +548,6 @@ EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &en | |||
| if (py::isinstance<py::none>(py_ret)) { | |||
| return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs)); | |||
| } | |||
| // Convert pyobject to Value, then to AbstractValue | |||
| ValuePtr converted_ret = nullptr; | |||
| TypePtr dtype = abs_base->BuildType(); | |||
| @@ -577,6 +565,28 @@ EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &en | |||
| return std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs)); | |||
| } | |||
| EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { | |||
| auto prim_py = dyn_cast<PrimitivePy>(prim_); | |||
| if (prim_py == nullptr) { | |||
| MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive."; | |||
| } | |||
| // Call checking method '__check__' for subclass of 'PrimitiveWithCheck' | |||
| MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString(); | |||
| auto py_args = PreparePyInputs(prim_py, args); | |||
| prim_py->RunCheck(py_args); | |||
| prim_->BeginRecordAddAttr(); | |||
| AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args); | |||
| prim_->EndRecordAddAttr(); | |||
| auto added_attrs = prim_->evaluate_added_attrs(); | |||
| if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) { | |||
| return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs)); | |||
| } | |||
| // Call method 'infer_value' for primitive with this method for constant propagation | |||
| return RunPyInferValue(engine, abs_base, args); | |||
| } | |||
| EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { | |||
| if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) { | |||
| auto ret_abstract = AbstractEval(args); | |||
| @@ -589,11 +599,19 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c | |||
| if (prim_->prim_type() == PrimType::kPrimTypePyInferCheck) { | |||
| return EvalPyCheckPrim(engine, args); | |||
| } | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool need_infer_value = eval_impl_.in_white_list_ || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode); | |||
| prim_->BeginRecordAddAttr(); | |||
| AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); | |||
| AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args); | |||
| prim_->EndRecordAddAttr(); | |||
| auto added_attrs = prim_->evaluate_added_attrs(); | |||
| if (need_infer_value) { | |||
| if (eval_impl_.infer_value_func_ != nullptr) { | |||
| auto value = eval_impl_.infer_value_func_(prim_, args, abs_base); | |||
| abs_base->set_value(value); | |||
| } | |||
| } | |||
| auto eval_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs)); | |||
| return eval_result; | |||
| } | |||
| @@ -617,7 +635,6 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs | |||
| auto added_attrs = prim_py_->evaluate_added_attrs(); | |||
| MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output); | |||
| auto res_spec = PyInferRes2Abstract(prim_py_, output); | |||
| MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << "."; | |||
| auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs)); | |||
| (*evaluator_cache_map_)[args] = infer_result; | |||
| @@ -689,7 +706,7 @@ ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const { | |||
| // Primitive implementation | |||
| // static function start | |||
| namespace { | |||
| EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveEvalImpl eval_impl) { | |||
| EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveImplReg eval_impl) { | |||
| EvaluatorPtr prim_evaluator = std::make_shared<StandardPrimEvaluator>(primitive, eval_impl); | |||
| return prim_evaluator; | |||
| } | |||
| @@ -1279,7 +1296,7 @@ void InitPrimEvaluatorConstructors() { | |||
| PrimEvaluatorMap &constructor = PrimEvaluatorConstructors; | |||
| for (const auto &iter : GetPrimitiveToEvalImplMap()) { | |||
| constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second.impl_); | |||
| constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second); | |||
| } | |||
| for (const auto &iter : GetUniformPrimitiveToImplMap()) { | |||
| @@ -32,7 +32,7 @@ namespace mindspore { | |||
| namespace abstract { | |||
| class StandardPrimEvaluator : public TrivialPrimEvaluator { | |||
| public: | |||
| StandardPrimEvaluator(const PrimitivePtr primitive, StandardPrimitiveEvalImpl eval_impl) | |||
| StandardPrimEvaluator(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &eval_impl) | |||
| : TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {} | |||
| ~StandardPrimEvaluator() override = default; | |||
| MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator); | |||
| @@ -43,9 +43,10 @@ class StandardPrimEvaluator : public TrivialPrimEvaluator { | |||
| private: | |||
| EvalResultPtr EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args); | |||
| EvalResultPtr RunPyInferValue(const AnalysisEnginePtr &engine, const AbstractBasePtr &abs_base, | |||
| const AbstractBasePtrList &args); | |||
| PrimitivePtr prim_; | |||
| const StandardPrimitiveEvalImpl eval_impl_; | |||
| const StandardPrimitiveImplReg eval_impl_; | |||
| }; | |||
| using StandardPrimEvaluatorPtr = std::shared_ptr<StandardPrimEvaluator>; | |||
| @@ -358,8 +358,8 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr | |||
| } | |||
| // find prim infer function in the prim function map return a standard evaluator | |||
| StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim); | |||
| if (eval_impl != nullptr) { | |||
| auto eval_impl = GetPrimitiveInferImpl(prim); | |||
| if (eval_impl.infer_shape_dtype_impl_ != nullptr) { | |||
| return std::make_shared<StandardPrimEvaluator>(prim, eval_impl); | |||
| } | |||
| @@ -213,13 +213,13 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() { | |||
| return prim_backend_eval_implement_map; | |||
| } | |||
| StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) { | |||
| StandardPrimitiveImplReg GetPrimitiveInferImpl(const PrimitivePtr &primitive) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto iter = GetPrimitiveToEvalImplMap().find(primitive); | |||
| if (iter == GetPrimitiveToEvalImplMap().end()) { | |||
| return nullptr; | |||
| return {nullptr, nullptr, false}; | |||
| } | |||
| return iter->second.impl_; | |||
| return iter->second; | |||
| } | |||
| void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg) { | |||
| @@ -19,21 +19,24 @@ | |||
| #define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ir/primitive.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "base/core_ops.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "ir/anf.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &); | |||
| using InferShapeAndTypeImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &); | |||
| using InferValueEvalImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &, const AbstractBasePtr &); | |||
| struct StandardPrimitiveImplReg { | |||
| StandardPrimitiveEvalImpl impl_; // Implement function of Primitive | |||
| InferValueEvalImpl infer_value_func_; // infer value of primitive | |||
| // true means this primitive can be executed by vm backend else will be constant folded by frontend | |||
| InferShapeAndTypeImpl infer_shape_dtype_impl_; // Implement function of Primitive | |||
| InferValueEvalImpl infer_value_func_; // infer value of primitive | |||
| // in_white_list_ is true means this primitive can be executed by vm backend | |||
| // else will be optimized by frontend | |||
| bool in_white_list_; | |||
| }; | |||
| @@ -44,7 +47,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap(); | |||
| PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap(); | |||
| StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive); | |||
| StandardPrimitiveImplReg GetPrimitiveInferImpl(const PrimitivePtr &primitive); | |||
| std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode); | |||
| @@ -52,17 +55,22 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard | |||
| class RegisterStandardPrimitiveEvalHelper { | |||
| public: | |||
| RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl, | |||
| RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeAndTypeImpl &infer_impl, | |||
| const InferValueEvalImpl &infer_value_impl, const bool is_wight_list = true) { | |||
| const StandardPrimitiveImplReg impl_reg{impl, infer_value_impl, is_wight_list}; | |||
| const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_wight_list}; | |||
| RegisterStandardPrimitiveImpl(primitive, impl_reg); | |||
| } | |||
| ~RegisterStandardPrimitiveEvalHelper() = default; | |||
| }; | |||
| #define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl, infer_value_impl, is_wight_list) \ | |||
| static auto helper_##name = \ | |||
| abstract::RegisterStandardPrimitiveEvalHelper(primitive, impl, infer_value_impl, is_wight_list) | |||
| #define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl, is_wight_list) \ | |||
| static auto helper_##name = \ | |||
| abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, is_wight_list); \ | |||
| std::shared_ptr<ops::PrimitiveC> GetDefaultPrimC##name() { \ | |||
| auto out = std::make_shared<name>(); \ | |||
| return out; \ | |||
| } \ | |||
| ops::OpPrimCRegisterHelper primc_gen_##name(#name, GetDefaultPrimC##name); | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ | |||
| @@ -70,6 +70,5 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv | |||
| return abs; | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, false); | |||
| REGISTER_PRIMITIVE_C(kNameGatherD, GatherD); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -32,7 +32,7 @@ AbstractBasePtr PrimitiveC::Infer(const AbstractBasePtrList &abstract_list) { | |||
| if (iter == infer_map.end()) { | |||
| MS_EXCEPTION(NotExistsError) << "Cannot find the " << this->name() << "infer function in the infer map!"; | |||
| } | |||
| auto infer_function = iter->second.impl_; | |||
| auto infer_function = iter->second.infer_shape_dtype_impl_; | |||
| return infer_function(nullptr, shared_from_base<Primitive>(), abstract_list); | |||
| } | |||
| @@ -51,6 +51,5 @@ AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr | |||
| return std::make_shared<abstract::AbstractTensor>(kInt32, ScalarSummaryInferShape(primitive, input_args)); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer, nullptr, true); | |||
| REGISTER_PRIMITIVE_C(kNameScalarSummary, ScalarSummary); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -81,6 +81,5 @@ AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const Primitiv | |||
| SoftMaxInferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer, nullptr, true); | |||
| REGISTER_PRIMITIVE_C(kNameSoftmax, Softmax); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -51,6 +51,5 @@ AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr | |||
| return std::make_shared<abstract::AbstractTensor>(kInt32, TensorSummaryInferShape(primitive, input_args)); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer, nullptr, true); | |||
| REGISTER_PRIMITIVE_C(kNameTensorSummary, TensorSummary); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -66,10 +66,8 @@ AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto abs = std::make_shared<abstract::AbstractTensor>(ZerosInferType(primitive, input_args), | |||
| ZerosInferShape(primitive, input_args)); | |||
| abs->set_value(ZerosInferValue(primitive, input_args, abs)); | |||
| return abs; | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Zeros, prim::kPrimZeros, ZerosInfer, ZerosInferValue, false); | |||
| REGISTER_PRIMITIVE_C(kNameZeros, Zeros); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -16,6 +16,7 @@ | |||
| #ifndef MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_ | |||
| #define MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "ir/tensor.h" | |||
| namespace mindspore { | |||
| template <typename T> | |||
| @@ -23,10 +24,7 @@ void SetTensorData(void *data, T num, size_t data_length) { | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| auto tensor_data = reinterpret_cast<T *>(data); | |||
| MS_EXCEPTION_IF_NULL(tensor_data); | |||
| for (size_t index = 0; index < data_length; ++index) { | |||
| *tensor_data = num; | |||
| ++tensor_data; | |||
| } | |||
| std::fill(tensor_data, tensor_data + data_length, num); | |||
| } | |||
| class TensorConstructUtils { | |||
| public: | |||
| @@ -86,10 +86,10 @@ TEST_F(TestAbstract, TestParseDataClass) { | |||
| AbstractBasePtrList args_list = {abs_scalar, abstract_x, abstract_y}; | |||
| StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord); | |||
| ASSERT_TRUE(nullptr != eval_impl); | |||
| auto eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord); | |||
| ASSERT_TRUE(nullptr != eval_impl.infer_shape_dtype_impl_); | |||
| AbstractBasePtr new_cls = eval_impl(nullptr, prim::kPrimMakeRecord, args_list); | |||
| AbstractBasePtr new_cls = eval_impl.infer_shape_dtype_impl_(nullptr, prim::kPrimMakeRecord, args_list); | |||
| ASSERT_TRUE(nullptr != new_cls); | |||
| } | |||
| @@ -47,8 +47,8 @@ AbstractBasePtr InferImplScalarAddStub(const AnalysisEnginePtr &engine, const Pr | |||
| } | |||
| EvaluatorPtr InitPrimitiveScalarAddEvaluatorStub() { | |||
| EvaluatorPtr PrimitiveScalarAddEvaluator = | |||
| std::make_shared<StandardPrimEvaluator>(prim::kPrimScalarAdd, InferImplScalarAddStub); | |||
| EvaluatorPtr PrimitiveScalarAddEvaluator = std::make_shared<StandardPrimEvaluator>( | |||
| prim::kPrimScalarAdd, StandardPrimitiveImplReg{InferImplScalarAddStub, nullptr, true}); | |||
| return PrimitiveScalarAddEvaluator; | |||
| } | |||
| @@ -63,8 +63,8 @@ AbstractBasePtr InferImplReturnStub(const AnalysisEnginePtr &engine, const Primi | |||
| } | |||
| EvaluatorPtr InitPrimitiveReturnEvaluatorStub() { | |||
| EvaluatorPtr PrimitiveReturnEvaluator = | |||
| std::make_shared<StandardPrimEvaluator>(prim::kPrimReturn, InferImplReturnStub); | |||
| EvaluatorPtr PrimitiveReturnEvaluator = std::make_shared<StandardPrimEvaluator>( | |||
| prim::kPrimReturn, StandardPrimitiveImplReg{InferImplReturnStub, nullptr, true}); | |||
| return PrimitiveReturnEvaluator; | |||
| } | |||
| @@ -396,7 +396,6 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) { | |||
| ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeInt64); | |||
| } | |||
| class TestEvalOnePrim : public UT::Common { | |||
| public: | |||
| TestEvalOnePrim() : getPyFun("gtest_input.pipeline.infer.infer_test", true), engine_(nullptr) {} | |||
| @@ -435,7 +434,7 @@ class TestGraphEval : public UT::Common { | |||
| UT::PyFuncGraphFetcher getPyFun; | |||
| }; | |||
| void TestGraphEval::SetUp() { engine_ = SetupAnalysisEngine(); } | |||
| void TestGraphEval::SetUp() { engine_ = SetupAnalysisEngine(); } | |||
| void TestGraphEval::TearDown() { | |||
| // destroy resource | |||
| @@ -25,6 +25,14 @@ | |||
| #include "common/common_test.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class TestAttr : public ops::PrimitiveC { | |||
| public: | |||
| TestAttr() : PrimitiveC("") {} | |||
| }; | |||
| class TestDynamicInput : public ops::PrimitiveC { | |||
| public: | |||
| TestDynamicInput() : PrimitiveC("") {} | |||
| }; | |||
| constexpr auto kAttrConvertTestName = "attr_convert_test"; | |||
| constexpr auto kDynamicInputTestName = "dynamic_input_test"; | |||
| inline const PrimitivePtr kPrimAttrConvertTest = std::make_shared<Primitive>(kAttrConvertTestName); | |||