Browse Source

!15385 remove the useless prim type

From: @lianliguang
Reviewed-by: @ginfung,@zh_qh,@ginfung
Signed-off-by: @zh_qh
pull/15311/head
mindspore-ci-bot Gitee 4 years ago
parent
commit
490d2e1efb
5 changed files with 12 additions and 14 deletions
  1. +3
    -3
      mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
  2. +1
    -1
      mindspore/ccsrc/pipeline/jit/validator.cc
  3. +2
    -2
      mindspore/ccsrc/pybind_api/ir/primitive_py.cc
  4. +4
    -6
      mindspore/core/ir/primitive.h
  5. +2
    -2
      tests/ut/cpp/operator/ops_test.cc

+ 3
- 3
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc View File

@@ -534,7 +534,7 @@ EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &en
const AbstractBasePtrList &args) {
auto prim_py = dyn_cast<PrimitivePy>(prim_);
if (prim_py == nullptr) {
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive.";
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
}
// Call checking method 'infer_value' for python primitive
MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
@@ -568,7 +568,7 @@ EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &en
EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
auto prim_py = dyn_cast<PrimitivePy>(prim_);
if (prim_py == nullptr) {
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive.";
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
}
// Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
@@ -596,7 +596,7 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
}
}

if (prim_->prim_type() == PrimType::kPrimTypePyInferCheck) {
if (prim_->prim_type() == PrimType::kPrimTypePyCheck) {
return EvalPyCheckPrim(engine, args);
}
auto context = MsContext::GetInstance();


+ 1
- 1
mindspore/ccsrc/pipeline/jit/validator.cc View File

@@ -58,7 +58,7 @@ void ValidateOperation(const AnfNodePtr &node) {
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
return;
}
if (prim->prim_type() == PrimType::kPrimTypePyInferCheck) {
if (prim->prim_type() == PrimType::kPrimTypePyCheck) {
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python inference checking method.";
return;
}


+ 2
- 2
mindspore/ccsrc/pybind_api/ir/primitive_py.cc View File

@@ -469,9 +469,9 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
.value("unknown", PrimType::kPrimTypeUnknown)
.value("builtin", PrimType::kPrimTypeBuiltIn)
.value("py_infer_shape", PrimType::kPrimTypePyInferShape)
.value("py_infer_shape", PrimType::kPrimTypePyInfer)
.value("user_custom", PrimType::kPrimTypeUserCustom)
.value("py_infer_check", PrimType::kPrimTypePyInferCheck);
.value("py_infer_check", PrimType::kPrimTypePyCheck);
(void)py::class_<PrimitivePyAdapter, std::shared_ptr<PrimitivePyAdapter>>(*m, "Primitive_")
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePyAdapter::parse_info_)
.def(py::init<py::str &>())


+ 4
- 6
mindspore/core/ir/primitive.h View File

@@ -32,11 +32,10 @@ namespace mindspore {
enum PrimType {
kPrimTypeUnknown = 0,
kPrimTypeBegin = kTypeUnknown,
kPrimTypeBuiltIn, // Built-in primitive operator
kPrimTypePyInferShape, // Primitive operator defined by custom
kPrimTypePyInferTensor, // Primitive operator defined by custom
kPrimTypeBuiltIn, // Built-in primitive operator
kPrimTypePyInfer, // Primitive operator defined by custom
kPrimTypeUserCustom,
kPrimTypePyInferCheck // Primitive operator with input args checking method
kPrimTypePyCheck // Primitive operator with input args checking method
};

class Primitive : public Named {
@@ -100,8 +99,7 @@ class Primitive : public Named {
void set_prim_type(const PrimType t) { prim_type_ = t; }
virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); }
void set_instance_name(const std::string &s) { instance_name_ = s; }
bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; }
bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; }
bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInfer || prim_type_ == kPrimTypeUserCustom; }
bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }

PrimType prim_type() const { return prim_type_; }


+ 2
- 2
tests/ut/cpp/operator/ops_test.cc View File

@@ -382,13 +382,13 @@ TEST_F(TestOps, Conv2dAttrTest) {
}

TEST_F(TestOps, CustomOpAttrTest) {
Primitive prim("CustomOp", true, kPrimTypePyInferShape);
Primitive prim("CustomOp", true, kPrimTypePyInfer);
prim.SetAttrs({
{"attr1", MakeValue(static_cast<int64_t>(3))},
{"attr2", MakeValue(static_cast<int64_t>(1))},
});
ASSERT_EQ(prim.name(), std::string("CustomOp"));
ASSERT_EQ(prim.prim_type(), kPrimTypePyInferShape);
ASSERT_EQ(prim.prim_type(), kPrimTypePyInfer);

auto attrs = prim.attrs();
for (auto attr : attrs) {


Loading…
Cancel
Save