From 4d7b7649e9ecfa58f4aedea2c6f2ff01f3405ca9 Mon Sep 17 00:00:00 2001 From: lianliguang Date: Mon, 19 Apr 2021 17:12:52 +0800 Subject: [PATCH] remove the useless PrimType --- mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc | 6 +++--- mindspore/ccsrc/pipeline/jit/validator.cc | 2 +- mindspore/ccsrc/pybind_api/ir/primitive_py.cc | 4 ++-- mindspore/core/ir/primitive.h | 10 ++++------ tests/ut/cpp/operator/ops_test.cc | 4 ++-- 5 files changed, 12 insertions(+), 14 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index f8b7699204..85e7b40c36 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -534,7 +534,7 @@ EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &en const AbstractBasePtrList &args) { auto prim_py = dyn_cast(prim_); if (prim_py == nullptr) { - MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive."; + MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive."; } // Call checking method 'infer_value' for python primitive MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString(); @@ -568,7 +568,7 @@ EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &en EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { auto prim_py = dyn_cast(prim_); if (prim_py == nullptr) { - MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive."; + MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive."; } // Call checking method '__check__' for subclass of 'PrimitiveWithCheck' MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString(); @@ -596,7 +596,7 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c } } - if (prim_->prim_type() == PrimType::kPrimTypePyInferCheck) { + if (prim_->prim_type() == PrimType::kPrimTypePyCheck) { return EvalPyCheckPrim(engine, args); } auto context = MsContext::GetInstance(); diff --git a/mindspore/ccsrc/pipeline/jit/validator.cc b/mindspore/ccsrc/pipeline/jit/validator.cc index 944dc9d1c0..45916fb557 100644 --- a/mindspore/ccsrc/pipeline/jit/validator.cc +++ b/mindspore/ccsrc/pipeline/jit/validator.cc @@ -58,7 +58,7 @@ void ValidateOperation(const AnfNodePtr &node) { MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator."; return; } - if (prim->prim_type() == PrimType::kPrimTypePyInferCheck) { + if (prim->prim_type() == PrimType::kPrimTypePyCheck) { MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python inference checking method."; return; } diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc index 4933c33c56..13c9889b9c 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc @@ -469,9 +469,9 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { (void)py::enum_(*m, "prim_type", py::arithmetic()) .value("unknown", PrimType::kPrimTypeUnknown) .value("builtin", PrimType::kPrimTypeBuiltIn) - .value("py_infer_shape", PrimType::kPrimTypePyInferShape) + .value("py_infer_shape", PrimType::kPrimTypePyInfer) .value("user_custom", PrimType::kPrimTypeUserCustom) - .value("py_infer_check", PrimType::kPrimTypePyInferCheck); + .value("py_infer_check", PrimType::kPrimTypePyCheck); (void)py::class_>(*m, "Primitive_") .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePyAdapter::parse_info_) .def(py::init()) diff --git a/mindspore/core/ir/primitive.h b/mindspore/core/ir/primitive.h index 9dbd72a7cb..eeac471977 100644 --- a/mindspore/core/ir/primitive.h +++ b/mindspore/core/ir/primitive.h @@ -32,11 +32,10 @@ namespace mindspore { enum PrimType { kPrimTypeUnknown = 0, kPrimTypeBegin = kTypeUnknown, - kPrimTypeBuiltIn, // Built-in primitive operator - kPrimTypePyInferShape, // Primitive operator defined by custom - kPrimTypePyInferTensor, // Primitive operator defined by custom + kPrimTypeBuiltIn, // Built-in primitive operator + kPrimTypePyInfer, // Primitive operator defined by custom kPrimTypeUserCustom, - kPrimTypePyInferCheck // Primitive operator with input args checking method + kPrimTypePyCheck // Primitive operator with input args checking method }; class Primitive : public Named { @@ -100,8 +99,7 @@ class Primitive : public Named { void set_prim_type(const PrimType t) { prim_type_ = t; } virtual PrimitivePtr Clone() { return std::make_shared(*this); } void set_instance_name(const std::string &s) { instance_name_ = s; } - bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; } - bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; } + bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInfer || prim_type_ == kPrimTypeUserCustom; } bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; } PrimType prim_type() const { return prim_type_; } diff --git a/tests/ut/cpp/operator/ops_test.cc b/tests/ut/cpp/operator/ops_test.cc index b3dcc756fd..f3dc363879 100644 --- a/tests/ut/cpp/operator/ops_test.cc +++ b/tests/ut/cpp/operator/ops_test.cc @@ -382,13 +382,13 @@ TEST_F(TestOps, Conv2dAttrTest) { } TEST_F(TestOps, CustomOpAttrTest) { - Primitive prim("CustomOp", true, kPrimTypePyInferShape); + Primitive prim("CustomOp", true, kPrimTypePyInfer); prim.SetAttrs({ {"attr1", MakeValue(static_cast(3))}, {"attr2", MakeValue(static_cast(1))}, }); ASSERT_EQ(prim.name(), std::string("CustomOp")); - ASSERT_EQ(prim.prim_type(), kPrimTypePyInferShape); + ASSERT_EQ(prim.prim_type(), kPrimTypePyInfer); auto attrs = prim.attrs(); for (auto attr : attrs) {