| @@ -307,7 +307,7 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | |||
| if (inputs.size() == 1 || !feature_map_input_indexs.empty()) { | |||
| kernel_info->SetFeatureMapFlag(true); | |||
| } | |||
| if (AnfAlgo::IsRealCNodeKernel(cnode)) { | |||
| if (AnfAlgo::IsRealKernel(cnode)) { | |||
| AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode); | |||
| AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode); | |||
| } | |||
| @@ -363,19 +363,21 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat | |||
| MS_LOG(INFO) << "RunOpInVM end"; | |||
| return std::move(result); | |||
| } | |||
| auto func = op_exec_info->py_primitive->GetComputeFunction(); | |||
| if (py::isinstance<py::none>(func)) { | |||
| MS_LOG(ERROR) << "VM failed to get func"; | |||
| auto primitive = op_exec_info->py_primitive; | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto result = primitive->RunPyComputeFunction(op_exec_info->op_inputs); | |||
| if (py::isinstance<py::none>(result)) { | |||
| MS_LOG(ERROR) << "VM got the result none, please check whether it is failed to get func"; | |||
| *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR; | |||
| py::tuple err_ret(0); | |||
| return std::move(err_ret); | |||
| } | |||
| // execute op | |||
| py::tuple result = py::make_tuple(func(*op_exec_info->op_inputs)); | |||
| py::tuple tuple_result = py::make_tuple(result); | |||
| *status = PYNATIVE_SUCCESS; | |||
| MS_LOG(INFO) << "RunOpInVM end"; | |||
| return std::move(result); | |||
| return std::move(tuple_result); | |||
| } | |||
| bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim, | |||
| @@ -15,6 +15,9 @@ | |||
| */ | |||
| #include "utils/primitive_utils.h" | |||
| #include <memory> | |||
| #include "pipeline/jit/parse/python_adapter.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "common/utils.h" | |||
| @@ -43,4 +46,25 @@ py::function GetComputeFunction(std::string name) { | |||
| py::object fn = mod.attr(common::SafeCStr(name)); | |||
| return fn; | |||
| } | |||
| py::tuple ConvertDatatoPyTuple(const VectorRef &args) { | |||
| auto py_args = py::tuple(args.size()); | |||
| size_t i = 0; | |||
| for (auto &arg : args) { | |||
| py_args[i] = BaseRefToPyData(arg); | |||
| MS_LOG(DEBUG) << "arg:" << i << ":" << arg.ToString(); | |||
| i++; | |||
| } | |||
| return py_args; | |||
| } | |||
| BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args) { | |||
| auto func = GetComputeFunction(prim->name()); | |||
| if (py::isinstance<py::none>(func)) { | |||
| MS_LOG(EXCEPTION) << prim->name() << " 's compute function run failed, please check whether it is not implemented"; | |||
| } | |||
| auto py_args = ConvertDatatoPyTuple(args); | |||
| py::object obj = func(*py_args); | |||
| return std::make_shared<PyObjectRef>(obj); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -19,6 +19,7 @@ | |||
| #include <string> | |||
| #include "pybind11/pybind11.h" | |||
| #include "utils/base_ref.h" | |||
| namespace py = pybind11; | |||
| @@ -28,6 +29,10 @@ py::function GetBpropFunctionByObj(py::object obj); | |||
| py::function GetBpropFunction(std::string name); | |||
| py::function GetComputeFunction(std::string name); | |||
| BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args); | |||
| py::tuple ConvertDatatoPyTuple(const VectorRef &args); | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_ | |||
| @@ -440,25 +440,13 @@ VectorRef VM::RunGraph(const FuncGraphPtr &g, const VectorRef &args) { | |||
| } | |||
| BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) { | |||
| PrimitivePyPtr operation = dyn_cast<PrimitivePy>(prim); | |||
| MS_LOG(DEBUG) << "operation start " << prim->name(); | |||
| auto func = operation != nullptr ? operation->GetComputeFunction() : GetComputeFunction(prim->name()); | |||
| if (py::isinstance<py::none>(func)) { | |||
| MS_LOG(EXCEPTION) << prim->name() << " 's compute function is not implemented"; | |||
| } | |||
| py::tuple py_args = py::tuple(args.size()); | |||
| MS_LOG(DEBUG) << "input for operation:"; | |||
| size_t i = 0; | |||
| for (auto &arg : args) { | |||
| py_args[i] = BaseRefToPyData(arg); | |||
| MS_LOG(DEBUG) << "arg: " << i << ":"; | |||
| i++; | |||
| } | |||
| py::object obj = func(*py_args); | |||
| MS_LOG(DEBUG) << "result:" << py::str(obj); | |||
| return obj; | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto result = prim->RunComputeFunction(args); | |||
| if (result.is_null()) { | |||
| return RunComputeFunction(prim, args); | |||
| } | |||
| return result; | |||
| } | |||
| } // namespace compile | |||
| @@ -83,6 +83,7 @@ class Primitive : public Named { | |||
| void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } | |||
| void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } | |||
| virtual BaseRef RunComputeFunction(const VectorRef &args) const { return nullptr; } | |||
| ValuePtr GetAttr(const std::string &attrName) const { | |||
| auto iter = attrs_.find(attrName); | |||
| @@ -79,13 +79,7 @@ py::function PrimitivePy::GetBpropFunction() { | |||
| } | |||
| BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { | |||
| auto py_args = py::tuple(args.size()); | |||
| size_t i = 0; | |||
| for (auto &arg : args) { | |||
| py_args[i] = BaseRefToPyData(arg); | |||
| MS_LOG(DEBUG) << "arg:" << i << ":"; | |||
| i++; | |||
| } | |||
| auto py_args = ConvertDatatoPyTuple(args); | |||
| py::object obj; | |||
| bool is_bprop = this->HasAttr(kBpropAttrName); | |||
| if (is_bprop) { | |||
| @@ -123,7 +117,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { | |||
| return std::make_shared<PyObjectRef>(obj); | |||
| } | |||
| py::function PrimitivePy::GetComputeFunction() { | |||
| py::function PrimitivePy::GetComputeFunction() const { | |||
| static const char *const compute_func_name = "vm_impl"; | |||
| if (py::hasattr(python_obj_, compute_func_name)) { | |||
| @@ -176,6 +170,32 @@ void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) { | |||
| this->set_hook(primitive_py->hook()); | |||
| } | |||
| BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const { | |||
| auto py_args = ConvertDatatoPyTuple(args); | |||
| auto result = this->RunPyComputeFunction(py_args); | |||
| if (py::isinstance<py::none>(result)) { | |||
| return std::make_shared<BaseRef>(nullptr); | |||
| } | |||
| return std::make_shared<PyObjectRef>(result); | |||
| } | |||
| py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const { | |||
| auto func = this->GetComputeFunction(); | |||
| if (py::isinstance<py::none>(func)) { | |||
| return py::none(); | |||
| } | |||
| auto result = func(*py_args); | |||
| return result; | |||
| } | |||
| bool PrimitivePy::HasComputeFunction() const { | |||
| auto func = GetComputeFunction(); | |||
| if (py::isinstance<py::none>(func)) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { | |||
| (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic()) | |||
| .value("unknown", PrimType::kPrimTypeUnknown) | |||
| @@ -41,7 +41,6 @@ class PrimitivePy : public Primitive { | |||
| ~PrimitivePy() override = default; | |||
| MS_DECLARE_PARENT(PrimitivePy, Primitive); | |||
| py::function GetBpropFunction(); | |||
| py::function GetComputeFunction(); | |||
| void set_signatures( | |||
| std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> | |||
| @@ -57,11 +56,15 @@ class PrimitivePy : public Primitive { | |||
| void set_hook(const py::function &hook) { hook_ = hook; } | |||
| py::function hook() const { return hook_; } | |||
| BaseRef RunHookFunction(const VectorRef &args) const override; | |||
| BaseRef RunComputeFunction(const VectorRef &args) const override; | |||
| py::object RunPyComputeFunction(const py::tuple &py_args) const; | |||
| bool HasComputeFunction() const; | |||
| const bool parse_info_ = true; | |||
| const py::object &GetPyObj() const { return python_obj_; } | |||
| bool is_tuple_input_ = false; | |||
| private: | |||
| py::function GetComputeFunction() const; | |||
| py::object python_obj_; | |||
| py::function hook_; | |||
| std::vector<Signature> signatures_; | |||
| @@ -454,8 +454,7 @@ TEST_F(TestOps, GetConv2DPrimPyTest) { | |||
| ASSERT_TRUE(conv2d_ptr); | |||
| if (nullptr != conv2d_ptr) { | |||
| MS_LOG(INFO) << "Get PrimitivePyPtr: " << conv2d_ptr->name(); | |||
| auto func = conv2d_ptr->GetComputeFunction(); | |||
| if (py::isinstance<py::none>(func)) { | |||
| if(!conv2d_ptr->HasComputeFunction()){ | |||
| MS_LOG(EXCEPTION) << "" << conv2d_ptr->name() << "'s compute function is not implemented"; | |||
| } | |||
| @@ -294,8 +294,7 @@ TEST_F(TestStepParallel, CreatOpInstance) { | |||
| ASSERT_TRUE(allreduce_ptr); | |||
| if (nullptr != allreduce_ptr) { | |||
| MS_LOG(INFO) << "Get PrimitivePyPtr: " << allreduce_ptr->name(); | |||
| auto func = allreduce_ptr->GetComputeFunction(); | |||
| if (py::isinstance<py::none>(func)) { | |||
| if (!allreduce_ptr->HasComputeFunction()) { | |||
| MS_LOG(EXCEPTION) << "" << allreduce_ptr->name() << "'s compute function is not implemented"; | |||
| } | |||
| @@ -57,11 +57,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) { | |||
| std::vector<BaseRef> todos(splits.size()); | |||
| auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos), | |||
| [](const BaseRef& seg) -> bool { return utils::isa<VectorRef>(seg); }); | |||
| [](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); }); | |||
| todos.resize(std::distance(todos.begin(), it)); | |||
| ASSERT_EQ(todos.size(), 1); | |||
| AnfNodePtrList anf_list; | |||
| AnfNodePtrList anf_list; | |||
| for (auto &item : utils::cast<VectorRef>(todos[0])) { | |||
| anf_list.push_back(utils::cast<AnfNodePtr>(item)); | |||
| } | |||
| @@ -81,11 +81,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) { | |||
| std::vector<BaseRef> todos(splits.size()); | |||
| auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos), | |||
| [](const BaseRef& seg) -> bool { return utils::isa<VectorRef>(seg); }); | |||
| [](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); }); | |||
| todos.resize(std::distance(todos.begin(), it)); | |||
| ASSERT_EQ(todos.size(), 1); | |||
| AnfNodePtrList anf_list; | |||
| AnfNodePtrList anf_list; | |||
| for (auto &item : utils::cast<VectorRef>(todos[0])) { | |||
| anf_list.push_back(utils::cast<AnfNodePtr>(item)); | |||
| } | |||
| @@ -105,11 +105,11 @@ TEST_F(TestCompileSegmentRunner, test_if) { | |||
| std::vector<BaseRef> todos(splits.size()); | |||
| auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos), | |||
| [](const BaseRef& seg) -> bool { return utils::isa<VectorRef>(seg); }); | |||
| [](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); }); | |||
| todos.resize(std::distance(todos.begin(), it)); | |||
| ASSERT_EQ(todos.size(), 1); | |||
| AnfNodePtrList anf_list; | |||
| AnfNodePtrList anf_list; | |||
| for (auto &item : utils::cast<VectorRef>(todos[0])) { | |||
| anf_list.push_back(utils::cast<AnfNodePtr>(item)); | |||
| } | |||
| @@ -122,13 +122,13 @@ TEST_F(TestCompileSegmentRunner, test_if) { | |||
| TEST_F(TestCompileSegmentRunner, test_RunOperation1) { | |||
| VectorRef args({1}); | |||
| auto res = RunOperation(prim::kPrimIdentity, args); | |||
| auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimIdentity->name()), py::none()), args); | |||
| ASSERT_EQ(py::cast<int>(BaseRefToPyData(res)), 1); | |||
| } | |||
| TEST_F(TestCompileSegmentRunner, test_RunOperation2) { | |||
| VectorRef args({1, 2}); | |||
| auto res = RunOperation(prim::kPrimScalarGt, args); | |||
| auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimScalarGt->name()), py::none()), args); | |||
| ASSERT_EQ(py::cast<bool>(BaseRefToPyData(res)), false); | |||
| } | |||
| } // namespace compile | |||