From: @lianliguang Reviewed-by: @ginfung,@zh_qh,@ginfung Signed-off-by: @zh_qhpull/15311/head
| @@ -534,7 +534,7 @@ EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &en | |||||
| const AbstractBasePtrList &args) { | const AbstractBasePtrList &args) { | ||||
| auto prim_py = dyn_cast<PrimitivePy>(prim_); | auto prim_py = dyn_cast<PrimitivePy>(prim_); | ||||
| if (prim_py == nullptr) { | if (prim_py == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive."; | |||||
| MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive."; | |||||
| } | } | ||||
| // Call checking method 'infer_value' for python primitive | // Call checking method 'infer_value' for python primitive | ||||
| MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString(); | MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString(); | ||||
| @@ -568,7 +568,7 @@ EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &en | |||||
| EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { | EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { | ||||
| auto prim_py = dyn_cast<PrimitivePy>(prim_); | auto prim_py = dyn_cast<PrimitivePy>(prim_); | ||||
| if (prim_py == nullptr) { | if (prim_py == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive."; | |||||
| MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive."; | |||||
| } | } | ||||
| // Call checking method '__check__' for subclass of 'PrimitiveWithCheck' | // Call checking method '__check__' for subclass of 'PrimitiveWithCheck' | ||||
| MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString(); | MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString(); | ||||
| @@ -596,7 +596,7 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c | |||||
| } | } | ||||
| } | } | ||||
| if (prim_->prim_type() == PrimType::kPrimTypePyInferCheck) { | |||||
| if (prim_->prim_type() == PrimType::kPrimTypePyCheck) { | |||||
| return EvalPyCheckPrim(engine, args); | return EvalPyCheckPrim(engine, args); | ||||
| } | } | ||||
| auto context = MsContext::GetInstance(); | auto context = MsContext::GetInstance(); | ||||
| @@ -58,7 +58,7 @@ void ValidateOperation(const AnfNodePtr &node) { | |||||
| MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator."; | MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator."; | ||||
| return; | return; | ||||
| } | } | ||||
| if (prim->prim_type() == PrimType::kPrimTypePyInferCheck) { | |||||
| if (prim->prim_type() == PrimType::kPrimTypePyCheck) { | |||||
| MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python inference checking method."; | MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python inference checking method."; | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -469,9 +469,9 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { | |||||
| (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic()) | (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic()) | ||||
| .value("unknown", PrimType::kPrimTypeUnknown) | .value("unknown", PrimType::kPrimTypeUnknown) | ||||
| .value("builtin", PrimType::kPrimTypeBuiltIn) | .value("builtin", PrimType::kPrimTypeBuiltIn) | ||||
| .value("py_infer_shape", PrimType::kPrimTypePyInferShape) | |||||
| .value("py_infer_shape", PrimType::kPrimTypePyInfer) | |||||
| .value("user_custom", PrimType::kPrimTypeUserCustom) | .value("user_custom", PrimType::kPrimTypeUserCustom) | ||||
| .value("py_infer_check", PrimType::kPrimTypePyInferCheck); | |||||
| .value("py_infer_check", PrimType::kPrimTypePyCheck); | |||||
| (void)py::class_<PrimitivePyAdapter, std::shared_ptr<PrimitivePyAdapter>>(*m, "Primitive_") | (void)py::class_<PrimitivePyAdapter, std::shared_ptr<PrimitivePyAdapter>>(*m, "Primitive_") | ||||
| .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePyAdapter::parse_info_) | .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePyAdapter::parse_info_) | ||||
| .def(py::init<py::str &>()) | .def(py::init<py::str &>()) | ||||
| @@ -32,11 +32,10 @@ namespace mindspore { | |||||
| enum PrimType { | enum PrimType { | ||||
| kPrimTypeUnknown = 0, | kPrimTypeUnknown = 0, | ||||
| kPrimTypeBegin = kTypeUnknown, | kPrimTypeBegin = kTypeUnknown, | ||||
| kPrimTypeBuiltIn, // Built-in primitive operator | |||||
| kPrimTypePyInferShape, // Primitive operator defined by custom | |||||
| kPrimTypePyInferTensor, // Primitive operator defined by custom | |||||
| kPrimTypeBuiltIn, // Built-in primitive operator | |||||
| kPrimTypePyInfer, // Primitive operator defined by custom | |||||
| kPrimTypeUserCustom, | kPrimTypeUserCustom, | ||||
| kPrimTypePyInferCheck // Primitive operator with input args checking method | |||||
| kPrimTypePyCheck // Primitive operator with input args checking method | |||||
| }; | }; | ||||
| class Primitive : public Named { | class Primitive : public Named { | ||||
| @@ -100,8 +99,7 @@ class Primitive : public Named { | |||||
| void set_prim_type(const PrimType t) { prim_type_ = t; } | void set_prim_type(const PrimType t) { prim_type_ = t; } | ||||
| virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); } | virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); } | ||||
| void set_instance_name(const std::string &s) { instance_name_ = s; } | void set_instance_name(const std::string &s) { instance_name_ = s; } | ||||
| bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; } | |||||
| bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; } | |||||
| bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInfer || prim_type_ == kPrimTypeUserCustom; } | |||||
| bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; } | bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; } | ||||
| PrimType prim_type() const { return prim_type_; } | PrimType prim_type() const { return prim_type_; } | ||||
| @@ -382,13 +382,13 @@ TEST_F(TestOps, Conv2dAttrTest) { | |||||
| } | } | ||||
| TEST_F(TestOps, CustomOpAttrTest) { | TEST_F(TestOps, CustomOpAttrTest) { | ||||
| Primitive prim("CustomOp", true, kPrimTypePyInferShape); | |||||
| Primitive prim("CustomOp", true, kPrimTypePyInfer); | |||||
| prim.SetAttrs({ | prim.SetAttrs({ | ||||
| {"attr1", MakeValue(static_cast<int64_t>(3))}, | {"attr1", MakeValue(static_cast<int64_t>(3))}, | ||||
| {"attr2", MakeValue(static_cast<int64_t>(1))}, | {"attr2", MakeValue(static_cast<int64_t>(1))}, | ||||
| }); | }); | ||||
| ASSERT_EQ(prim.name(), std::string("CustomOp")); | ASSERT_EQ(prim.name(), std::string("CustomOp")); | ||||
| ASSERT_EQ(prim.prim_type(), kPrimTypePyInferShape); | |||||
| ASSERT_EQ(prim.prim_type(), kPrimTypePyInfer); | |||||
| auto attrs = prim.attrs(); | auto attrs = prim.attrs(); | ||||
| for (auto attr : attrs) { | for (auto attr : attrs) { | ||||