/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_ #define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_ #include #include #include #include #include #include #include #include #include #include "pybind11/pybind11.h" #include "pybind11/numpy.h" #include "pybind_api/ir/base_ref_py.h" #include "pipeline/pynative/base.h" #include "utils/ms_context.h" #include "ir/anf.h" #include "pipeline/jit/resource.h" #include "frontend/operator/composite/composite.h" namespace mindspore { namespace pynative { namespace py = pybind11; using ResourcePtr = std::shared_ptr; using GradOperationPtr = std::shared_ptr; struct PrimAbsInfo { abstract::AbstractBasePtr abs; std::unordered_map attrs; }; using AbstractListMap = std::unordered_map; py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); py::tuple RunOp(const py::args &args); void ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args, py::list *const out_args_list); void ClearPyNativeSession(); struct GraphInfo { std::unordered_map>> param_map; std::unordered_map>> obj_node_map; AnfNodePtr output; std::vector objects; }; class PynativeExecutor : public std::enable_shared_from_this { public: static std::shared_ptr GetInstance() { std::lock_guard i_lock(instance_lock_); if (executor_ == nullptr) { executor_ = std::shared_ptr(new (std::nothrow) PynativeExecutor()); resource_ = std::make_shared(); } return executor_; } void NewGraph(const py::object &cell, const py::args &args); void NewGraphInner(const py::object &cell, const py::args &args); void EndGraph(const py::object &cell, const py::object &out, const py::args &args); void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args); void EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, const py::args &args); std::vector GetWeightsArgs(const py::object &weights); abstract::AbstractBasePtrList GetArgsSpec(const py::args &args); void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); void Clear(const std::string &flag = ""); void Clean(); void ClearRes(); bool grad_flag() { return grad_flag_; } void set_grad_flag(bool flag) { grad_flag_ = flag; } AnfNodePtr GetInput(const py::object &obj, bool op_mask); AnfNodePtr GetObjNode(const py::object &obj); AnfNodePtr GetParamNode(const py::object &obj); std::string GetCellId(const py::object &obj, const py::args &args); FuncGraphPtr curr_g() { return curr_g_; } void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); } void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) { graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, std::vector{-1}); } void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) { graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, std::vector{index}); } void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector index) { graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index); } void set_param_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) { graph_info_map_[g].param_map[obj] = std::make_pair(node, std::vector{-1}); } void set_param_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) { graph_info_map_[g].param_map[obj] = std::make_pair(node, std::vector{index}); } void set_param_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector index) { graph_info_map_[g].param_map[obj] = std::make_pair(node, index); } AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, abstract::AbstractBasePtrList *args_spec_list); void MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out, const AnfNodePtr &cnode); ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info); void SaveOpForwardValue(const std::string &id, const ValuePtr &value, std::map *t_map); void SaveForwardResult(const CNodePtr &cnode, const py::object &out); void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out); py::object Run(const py::tuple &args, const py::object &phase); void Pushp(); void Popp(); FuncGraphPtr GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, const std::vector &weights, size_t arg_size); void SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector idx); void SetTupleParam(const py::object &obj, const AnfNodePtr ¶_node, std::vector idx); AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id); py::tuple RunOpInner(const py::args &args); py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info); ~PynativeExecutor(); private: PynativeExecutor(); static std::shared_ptr executor_; static std::mutex instance_lock_; static ResourcePtr resource_; static int graph_id_; bool grad_flag_; bool first_grad_step_; std::unordered_map graph_map_; std::unordered_map cell_graph_map_; std::unordered_map cell_resource_map_; std::unordered_map graph_info_map_; std::unordered_map op_forward_map_; std::unordered_map op_id_map_; std::unordered_map obj_to_forward_id_; std::unordered_map node_abs_map_; std::unordered_map df_builder_map_; // the stack that records the context of graph created, the bottom is the top graph std::stack graph_context_; FuncGraphPtr top_g_; FuncGraphPtr df_builder_; FuncGraphPtr curr_g_; std::unordered_map prim_abs_list_; std::set top_graph_cells_; }; using PynativeExecutorPtr = std::shared_ptr; } // namespace pynative } // namespace mindspore #endif // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_