Merge pull request !3088 from lianliguang/primitive-decouplingtags/v0.6.0-beta
| @@ -523,14 +523,8 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs | |||||
| return iter->second; | return iter->second; | ||||
| } | } | ||||
| auto py_args = PreparePyInputs(prim_py_, args); | auto py_args = PreparePyInputs(prim_py_, args); | ||||
| auto pyobj = prim_py_->GetPyObj(); | |||||
| if (pyobj == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty"; | |||||
| } | |||||
| auto infer_fuc = pyobj.attr("__infer__"); | |||||
| prim_py_->BeginRecordAddAttr(); | prim_py_->BeginRecordAddAttr(); | ||||
| py::dict output = infer_fuc(*py_args); | |||||
| py::dict output = prim_py_->RunInfer(py_args); | |||||
| prim_py_->EndRecordAddAttr(); | prim_py_->EndRecordAddAttr(); | ||||
| auto added_attrs = prim_py_->evaluate_added_attrs(); | auto added_attrs = prim_py_->evaluate_added_attrs(); | ||||
| MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output); | MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output); | ||||
| @@ -654,17 +654,7 @@ static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, c | |||||
| } | } | ||||
| } | } | ||||
| if (!is_attr_same) { | if (!is_attr_same) { | ||||
| if (prim->isa<PrimitivePy>()) { | |||||
| PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>(); | |||||
| auto clone_fn = prim_py->GetPyObj().attr("_clone"); | |||||
| py::object new_obj = clone_fn(); | |||||
| auto cloned_prim = new_obj.cast<PrimitivePyPtr>(); | |||||
| for (auto &item : *attrs) { | |||||
| cloned_prim->AddAttr(item.first, item.second); | |||||
| } | |||||
| return cloned_prim; | |||||
| } | |||||
| auto cloned_prim = std::make_shared<Primitive>(*prim); | |||||
| auto cloned_prim = prim->Clone(); | |||||
| for (auto &item : *attrs) { | for (auto &item : *attrs) { | ||||
| cloned_prim->AddAttr(item.first, item.second); | cloned_prim->AddAttr(item.first, item.second); | ||||
| } | } | ||||
| @@ -280,8 +280,8 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn | |||||
| AbstractBasePtrList args_spec_list; | AbstractBasePtrList args_spec_list; | ||||
| for (size_t i = 0; i < size; i++) { | for (size_t i = 0; i < size; i++) { | ||||
| ValuePtr input_value = PyAttrValue(py_args[i]); | ValuePtr input_value = PyAttrValue(py_args[i]); | ||||
| args_spec_list.emplace_back(abstract::FromValueInside( | |||||
| input_value, !py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa<tensor::Tensor>())); | |||||
| args_spec_list.emplace_back( | |||||
| abstract::FromValueInside(input_value, !prim->ObjHasAttr("const_value") && input_value->isa<tensor::Tensor>())); | |||||
| } | } | ||||
| AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract(); | AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract(); | ||||
| op_exec_info->abstract = infer_res; | op_exec_info->abstract = infer_res; | ||||
| @@ -296,8 +296,7 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args) | |||||
| MS_EXCEPTION_IF_NULL(op_exec_info); | MS_EXCEPTION_IF_NULL(op_exec_info); | ||||
| op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]); | op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]); | ||||
| auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]); | auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]); | ||||
| auto pyobj = prim->GetPyObj(); | |||||
| if (pyobj == nullptr) { | |||||
| if (!prim->HasPyObj()) { | |||||
| MS_LOG(EXCEPTION) << "pyobj is empty"; | MS_LOG(EXCEPTION) << "pyobj is empty"; | ||||
| } | } | ||||
| @@ -708,7 +707,7 @@ py::tuple RunOpInner(const py::args &args) { | |||||
| value_ret[0] = output["value"]; | value_ret[0] = output["value"]; | ||||
| return value_ret; | return value_ret; | ||||
| } | } | ||||
| if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) { | |||||
| if (op_exec_info->py_primitive->ObjHasAttr("const_value")) { | |||||
| py::tuple value_ret(1); | py::tuple value_ret(1); | ||||
| value_ret[0] = ""; | value_ret[0] = ""; | ||||
| return value_ret; | return value_ret; | ||||
| @@ -100,6 +100,7 @@ class Primitive : public Named { | |||||
| return !(iter == attrs_.cend()); | return !(iter == attrs_.cend()); | ||||
| } | } | ||||
| void set_prim_type(const PrimType t) { prim_type_ = t; } | void set_prim_type(const PrimType t) { prim_type_ = t; } | ||||
| virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); } | |||||
| void set_instance_name(const std::string s) { instance_name_ = s; } | void set_instance_name(const std::string s) { instance_name_ = s; } | ||||
| bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; } | bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; } | ||||
| bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; } | bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; } | ||||
| @@ -196,6 +196,21 @@ bool PrimitivePy::HasComputeFunction() const { | |||||
| return true; | return true; | ||||
| } | } | ||||
| PrimitivePtr PrimitivePy::Clone() { | |||||
| auto clone_fn = python_obj_.attr("_clone"); | |||||
| py::object new_obj = clone_fn(); | |||||
| auto cloned_prim = new_obj.cast<PrimitivePyPtr>(); | |||||
| return cloned_prim; | |||||
| } | |||||
| py::dict PrimitivePy::RunInfer(const py::tuple &args) { | |||||
| if (!HasPyObj()) { | |||||
| MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty"; | |||||
| } | |||||
| auto infer_fuc = python_obj_.attr("__infer__"); | |||||
| return infer_fuc(*args); | |||||
| } | |||||
| REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { | REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { | ||||
| (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic()) | (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic()) | ||||
| .value("unknown", PrimType::kPrimTypeUnknown) | .value("unknown", PrimType::kPrimTypeUnknown) | ||||
| @@ -61,6 +61,10 @@ class PrimitivePy : public Primitive { | |||||
| bool HasComputeFunction() const; | bool HasComputeFunction() const; | ||||
| const bool parse_info_ = true; | const bool parse_info_ = true; | ||||
| const py::object &GetPyObj() const { return python_obj_; } | const py::object &GetPyObj() const { return python_obj_; } | ||||
| py::dict RunInfer(const py::tuple &args); | |||||
| bool ObjHasAttr(const char *attr_name) { return py::hasattr(python_obj_, attr_name); } | |||||
| bool HasPyObj() { return python_obj_ != nullptr; } | |||||
| PrimitivePtr Clone() override; | |||||
| bool is_tuple_input_ = false; | bool is_tuple_input_ = false; | ||||
| private: | private: | ||||