pynative_execute.h
/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_
#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_

#include <utility>
#include <vector>
#include <string>
#include <memory>
#include <mutex>
#include <stack>
#include <set>
#include <map>

#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include "pybind_api/ir/base_ref_py.h"
#include "ir/anf.h"
#include "frontend/optimizer/ad/kpynative.h"
#include "frontend/operator/composite/composite.h"
#include "pipeline/jit/resource.h"
#include "pipeline/pynative/base.h"
#include "pipeline/pynative/pynative_cache.h"
#include "utils/ms_context.h"
namespace mindspore::pynative {
namespace py = pybind11;

using OpInfoWithTensorId = mindspore::HashMap<std::string, std::vector<std::string>>;
using TensorIdWithTensorObject = mindspore::HashMap<std::string, std::vector<tensor::TensorPtr>>;
using OpInfoWithMsFuncForwardTensors = mindspore::HashMap<std::string, std::vector<tensor::TensorPtr>>;
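
// Entry point invoked from the Python frontend to execute a single primitive op.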
py::object RealRunOp(const py::args &args);
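
// Bookkeeping for one constructed graph: the owning cell's id, the graph
// output node, its parameters (inputs and cell weights), and a map from
// Python object ids to the AnfNodes (plus tuple-index paths) built for them.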
struct GraphInfo {
  std::string cell_id;
  AnfNodePtr output;
  OrderedMap<std::string, ParameterPtr> params;  // hold input parameters and cell weights
  mindspore::HashMap<std::string, std::pair<AnfNodePtr, std::vector<int64_t>>> node_map;
  GraphInfo() = default;
  explicit GraphInfo(std::string id) : cell_id(std::move(id)) {}
};
using GraphInfoPtr = std::shared_ptr<GraphInfo>;
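
// TopCellInfo caches the state needed to (re)compile and run one outermost
// cell under grad: the pipeline resource, the forward graph, the df_builder
// graph, per-op tensor bookkeeping, and the flags that decide whether the
// graph must be recompiled on the next run.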
class TopCellInfo {
 public:
  TopCellInfo() = default;
  ~TopCellInfo() = default;
  TopCellInfo(bool topest, size_t grad_order, pipeline::ResourcePtr r, FuncGraphPtr fg, FuncGraphPtr df,
              std::string cellid, std::string already_run_cell_id)
      : is_topest_(topest),
        grad_order_(grad_order),
        resource_(std::move(r)),
        fg_(std::move(fg)),
        df_builder_(std::move(df)),
        cell_id_(std::move(cellid)),
        already_run_cell_id_(std::move(already_run_cell_id)) {}

  bool is_init_kpynative() const { return is_init_kpynative_; }
  void set_init_kpynative(bool init) { is_init_kpynative_ = init; }
  bool is_topest() const { return is_topest_; }
  size_t grad_order() const { return grad_order_; }
  void set_grad_order(size_t grad_order) { grad_order_ = grad_order; }
  bool is_dynamic() const { return is_dynamic_; }
  void set_is_dynamic(bool is_dynamic) { is_dynamic_ = is_dynamic; }
  bool vm_compiled() const { return vm_compiled_; }
  void set_vm_compiled(bool vm_compiled) { vm_compiled_ = vm_compiled; }
  bool ms_function_flag() const { return ms_function_flag_; }
  void set_ms_function_flag(bool ms_function_flag) { ms_function_flag_ = ms_function_flag; }
  bool need_compile_graph() const { return need_compile_graph_; }
  void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
  bool forward_already_run() const { return forward_already_run_; }
  void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
  pipeline::ResourcePtr resource() const { return resource_; }
  FuncGraphPtr df_builder() const { return df_builder_; }
  FuncGraphPtr fg() const { return fg_; }
  void set_fg(const FuncGraphPtr &fg) { fg_ = fg; }
  size_t op_num() const { return op_num_; }
  void set_op_num(size_t op_num) { op_num_ = op_num; }
  const std::string &cell_id() const { return cell_id_; }
  const std::string &already_run_cell_id() const { return already_run_cell_id_; }
  const std::string &input_args_id() const { return input_args_id_; }
  void set_input_args_id(const std::string &input_args_id) { input_args_id_ = input_args_id; }
  std::string &all_op_info() { return all_op_info_; }
  const std::string &grad_operation() const { return grad_operation_; }
  void set_grad_operation(const std::string &grad_operation) { grad_operation_ = grad_operation; }
  mindspore::HashSet<std::string> &sub_cell_list() { return sub_cell_list_; }
  bool IsSubCell(const std::string &cell_id) const;
  OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() { return graph_info_map_; }
  OpInfoWithTensorId &op_info_with_tensor_id() { return op_info_with_tensor_id_; }
  TensorIdWithTensorObject &tensor_id_with_tensor_object() { return tensor_id_with_tensor_object_; }
  ad::KPynativeCellPtr k_pynative_cell_ptr() const { return k_pynative_cell_ptr_; }
  void set_k_pynative_cell_ptr(const ad::KPynativeCellPtr &k_pynative_cell_ptr) {
    k_pynative_cell_ptr_ = k_pynative_cell_ptr;
  }
  const OpInfoWithMsFuncForwardTensors &op_info_with_ms_func_forward_tensors() const {
    return op_info_with_ms_func_forward_tensors_;
  }
  void set_op_info_with_ms_func_forward_tensors(const std::string &op_info,
                                                const std::vector<tensor::TensorPtr> &forward_tensors) {
    op_info_with_ms_func_forward_tensors_[op_info] = forward_tensors;
  }
  void ClearDeviceMemory();
  void Clear();

 private:
  bool is_topest_{false};
  bool is_dynamic_{false};
  bool vm_compiled_{false};
  bool ms_function_flag_{false};
  bool is_init_kpynative_{false};
  bool forward_already_run_{false};
  bool need_compile_graph_{false};
  size_t op_num_{0};
  size_t grad_order_{0};
  pipeline::ResourcePtr resource_{nullptr};
  FuncGraphPtr fg_{nullptr};
  FuncGraphPtr df_builder_{nullptr};
  ad::KPynativeCellPtr k_pynative_cell_ptr_{nullptr};
  std::string cell_id_;
  std::string already_run_cell_id_;
  std::string input_args_id_;
  std::string all_op_info_;
  std::string grad_operation_;
  OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;
  mindspore::HashSet<std::string> sub_cell_list_;
  OpInfoWithTensorId op_info_with_tensor_id_;
  TensorIdWithTensorObject tensor_id_with_tensor_object_;
  OpInfoWithMsFuncForwardTensors op_info_with_ms_func_forward_tensors_;
};
using TopCellInfoPtr = std::shared_ptr<TopCellInfo>;
class ForwardExecutor;
using ForwardExecutorPtr = std::shared_ptr<ForwardExecutor>;
using ForwardExecutorWeakPtr = std::weak_ptr<ForwardExecutor>;

class GradExecutor;
using GradExecutorPtr = std::shared_ptr<GradExecutor>;
using GradExecutorWeakPtr = std::weak_ptr<GradExecutor>;
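
// GradExecutor records the forward graph while ops execute eagerly and builds
// the corresponding bprop graph. The std::function members (InitGraph,
// LinkGraph, GradGraph, RunGraph) are forwarding wrappers around the private
// *Inner methods, invoked from the Python side at cell entry, cell exit,
// grad construction, and grad execution respectively.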
class GradExecutor {
 public:
  GradExecutor() = default;
  ~GradExecutor() = default;
  explicit GradExecutor(const ForwardExecutorPtr &forward_executor = nullptr)
      : forward_executor_(ForwardExecutorWeakPtr(forward_executor)) {}

  std::function<void(py::object *, const py::object &, const py::args &)> InitGraph = [this](auto &&PH1, auto &&PH2,
                                                                                             auto &&PH3) {
    NewGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), std::forward<decltype(PH3)>(PH3));
  };
  std::function<void(py::object *, const py::object &, const py::object &, const py::args &)> LinkGraph =
    [this](auto &&PH1, auto &&PH2, auto &&PH3, auto &&PH4) {
      EndGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2),
                    std::forward<decltype(PH3)>(PH3), std::forward<decltype(PH4)>(PH4));
    };
  std::function<void(py::object *, const prim::GradOperationPtr &, const py::object &, const py::object &,
                     const py::object &, const py::args &)>
    GradGraph = [this](auto &&PH1, auto &&PH2, auto &&PH3, auto &&PH4, auto &&PH5, auto &&PH6) {
      GradNetInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), std::forward<decltype(PH3)>(PH3),
                   std::forward<decltype(PH4)>(PH4), std::forward<decltype(PH5)>(PH5),
                   std::forward<decltype(PH6)>(PH6));
    };
  std::function<void(py::object *, const py::object &, const py::tuple &)> RunGraph = [this](auto &&PH1, auto &&PH2,
                                                                                             auto &&PH3) {
    RunGradGraph(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), std::forward<decltype(PH3)>(PH3));
  };

  FuncGraphPtr curr_g() const;
  TopCellInfoPtr top_cell() const;
  void CheckNeedCompileGraph();
  void PushHighOrderGraphStack(const TopCellInfoPtr &top_cell);
  size_t GetHighOrderStackSize() const { return high_order_stack_.size(); }
  TopCellInfoPtr GetTopCell(const string &already_run_cell_id);
  void EnableOpGraphCache(bool is_enable);
  bool need_renormalize() const { return need_renormalize_; }
  bool enable_op_cache() const { return enable_op_cache_; }
  void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); }
  bool grad_flag() const { return grad_flag_; }
  void set_grad_flag(bool flag) { grad_flag_ = flag; }
  void set_graph_phase(const std::string &graph_phase) { graph_phase_ = graph_phase; }
  bool in_cell_with_custom_bprop_() const { return custom_bprop_cell_count_ > 0; }
  AnfNodePtr GetInput(const py::object &obj, bool op_mask);
  std::string GetCellId(const py::object &obj, const py::args &args);
  void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out);
  bool need_construct_graph() const { return !cell_stack_.empty() && grad_flag_; }
  // Construct grad graph for ms_function
  bool eliminate_forward() const { return eliminate_forward_; }
  void set_eliminate_forward(bool eliminate_forward) { eliminate_forward_ = eliminate_forward; }
  py::object GradMsFunction(const py::object &out, const py::args &args);
  void GradMsFunctionInner(const std::string &phase, const py::object &out, const py::args &args,
                           const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph);
  void UpdateMsFunctionForwardTensors(const OpExecInfoPtr &op_exec_info, const ValuePtr &new_forward_value);
  void MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph,
                                const py::object &actual_out, const py::args &args, const ValuePtr &actual_out_v);
  void MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args, ValuePtrList *input_values,
                              CNodePtr *ms_function_cnode);
  void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const CNodePtr &cnode);
  void DoOpGrad(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const ValuePtr &op_out);
  // Update forward tensor info
  void UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out);
  void SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const;
  py::object CheckGraph(const py::object &cell, const py::args &args);
  void RunGradGraph(py::object *ret, const py::object &cell, const py::tuple &args);
  py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell, const py::args &args);
  void EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell);
  void ClearGrad(const py::object &cell, const py::args &args);
  void ClearRes();
  void ClearCellRes(const std::string &cell_id = "");

 private:
  ForwardExecutorPtr forward() const;
  // Higher derivative
  inline bool IsNestedGrad() const;
  void SwitchTopcell();
  void DoParameterReplace(const FuncGraphPtr &first_grad_fg, const py::tuple &forward_args,
                          std::vector<AnfNodePtr> *inputs, ValuePtrList *weights_args);
  void MakeNestedCnode(const py::object &cell, const py::tuple &forward_args, const pipeline::ResourcePtr &resource,
                       const py::object &out);
  void PushCellStack(const std::string &cell_id);
  void PopCellStack();
  TopCellInfoPtr PopHighOrderGraphStack();
  void HandleInputArgsForTopCell(const py::args &args, bool is_bprop_top);
  void InitResourceAndDfBuilder(const std::string &cell_id, const py::args &args);
  void MakeNewTopGraph(const string &cell_id, const py::args &args, bool is_topest);
  void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compiled);
  // Manage resources when running the grad process.
  bool IsBpropGraph(const std::string &cell_id);
  bool IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id) const;
  void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);
  void NewGraphInner(py::object *ret, const py::object &cell, const py::args &args);
  void EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args);
  void DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args);
  std::string GetAlreadyRunCellId(const std::string &cell_id);
  std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args);
  void GradNetInner(py::object *ret, const prim::GradOperationPtr &grad, const py::object &cell,
                    const py::object &weights, const py::object &grad_position, const py::args &args);
  FuncGraphPtr GetBpropGraph(const prim::GradOperationPtr &grad, const py::object &cell,
                             const std::vector<AnfNodePtr> &weights, const std::vector<size_t> &grad_position,
                             size_t arg_size, const py::args &args);
  std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
  void UpdateParamAbsByArgs(const py::list &args, const FuncGraphPtr &bprop_graph);
  std::vector<size_t> GetGradPositionArgs(const py::object &grad_position);
  // Manage resources for constructing the forward graph.
  const std::string &graph_phase() const { return graph_phase_; }
  AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id);
  AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
  AnfNodePtr CreateMakeTupleNode(const py::object &obj, const std::string &obj_id);
  AnfNodePtr CreateTupleGetItemNode(const std::string &obj_id);
  void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node,
                                      const std::vector<int64_t> &index_sequence, bool is_param = false);
  void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
                                  bool is_param = false);
  void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) const {
    auto &graph_info = top_cell()->graph_info_map()[g];
    MS_EXCEPTION_IF_NULL(graph_info);
    graph_info->params[id] = param;
  }
  void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
                                int64_t index = -1) const {
    auto &graph_info = top_cell()->graph_info_map()[g];
    MS_EXCEPTION_IF_NULL(graph_info);
    graph_info->node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
  }
  void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
                                const std::vector<int64_t> &index) const {
    auto &graph_info = top_cell()->graph_info_map()[g];
    MS_EXCEPTION_IF_NULL(graph_info);
    graph_info->node_map[id] = std::make_pair(node, index);
  }

 private:
  bool grad_flag_{false};
  bool enable_op_cache_{true};
  bool grad_is_running_{false};
  bool need_renormalize_{false};
  bool eliminate_forward_{true};
  int custom_bprop_cell_count_{0};
  size_t grad_order_{0};
  size_t top_cell_switch_counts_{0};
  // The graph phase is used to obtain the backend graph that is compiled by ms_function
  std::string graph_phase_;
  // The cell for the run-check graph, which will become the top cell
  std::string check_graph_cell_id_;
  std::string grad_operation_;
  TopCellInfoPtr top_cell_{nullptr};
  // Records forward cells; the bottom of the stack is the top cell
  std::stack<std::string> cell_stack_;
  // For high-order grad of bprop
  std::stack<std::pair<std::string, bool>> bprop_grad_stack_;
  std::vector<std::string> bprop_cell_list_;
  // For high grad order
  std::stack<TopCellInfoPtr> high_order_stack_;
  // Use a vector to keep insertion order
  std::vector<TopCellInfoPtr> top_cell_list_;
  // Records all top cells that have been run
  mindspore::HashMap<std::string, TopCellInfoPtr> already_run_top_cell_;
  ForwardExecutorWeakPtr forward_executor_;
};
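
// ForwardExecutor executes single ops eagerly: it infers each op's output
// abstract, applies mixed-precision and implicit-signature casts to the
// inputs, then dispatches the op with the configured backend policy
// (falling back to the VM where needed).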
class ForwardExecutor {
 public:
  ForwardExecutor() = default;
  ~ForwardExecutor() = default;

  std::function<void(py::object *, const OpExecInfoPtr &)> RunOpS = [this](auto &&PH1, auto &&PH2) {
    RunOpInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2));
  };

  void RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_info);
  OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
  void set_grad_executor(const GradExecutorPtr &grad_executor) { grad_executor_ = GradExecutorWeakPtr(grad_executor); }
  mindspore::HashMap<std::string, abstract::AbstractBasePtr> &node_abs_map() { return node_abs_map_; }
  void ClearRes();
  CNodePtr ConstructForwardGraph(const OpExecInfoPtr &op_exec_info);
  void set_lazy_build(bool lazy_build) { lazy_build_ = lazy_build; }

 private:
  GradExecutorPtr grad() const;
  MsBackendPolicy InitEnv(const OpExecInfoPtr &op_exec_info);
  py::tuple RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info);
  void RunMixedPrecisionCastOp(const OpExecInfoPtr &op_exec_info, py::object *ret);
  py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
  py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
  py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
                                    PynativeStatusCode *status);
  void SetNonCostantValueAbs(const AbstractBasePtr &abs, size_t i, const std::string &id);
  void GetInputsArgsSpec(const OpExecInfoPtr &op_exec_info, abstract::AbstractBasePtrList *args_spec_list);
  void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
                           bool *prim_cache_hit);
  void GetOpOutput(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
                   const CNodePtr &cnode, bool prim_cache_hit, py::object *ret);
  // Mixed precision and implicit type promotion
  void SetCastForInputs(const OpExecInfoPtr &op_exec_info);
  void SetTensorMixPrecisionCast(const OpExecInfoPtr &op_exec_info);
  void SetImplicitCast(const OpExecInfoPtr &op_exec_info);
  py::object DoParamMixPrecisionCast(bool *is_cast, const py::object &obj, const std::string &op_name, size_t index);
  py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple &tuple, const std::string &op_name,
                                          size_t index);
  py::object DoAutoCastTuple(const py::tuple &tuple, const TypeId &type_id, const std::string &op_name, size_t index);
  py::object DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, size_t index);
  void DoSignatureCast(const PrimitivePyPtr &prim, const mindspore::HashMap<SignatureEnumDType, TypeId> &dst_type,
                       const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info);
  void CheckIfNeedSyncForHeterogeneous(const std::string &cur_target);

 private:
  GradExecutorWeakPtr grad_executor_;
  PrimAbsCache prim_abs_list_;
  ImplicitCastCache implicit_cast_map_;
  mindspore::HashMap<std::string, abstract::AbstractBasePtr> node_abs_map_;
  bool lazy_build_{false};
  std::string last_target_{"Unknown"};
};
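
// PynativeExecutor is the process-wide singleton exposed to the Python layer.
// It lazily creates one ForwardExecutor and one GradExecutor and wires them
// together under instance_lock_. A minimal usage sketch from the C++ side
// (the cell/args/out py::objects here are hypothetical; in practice they
// arrive through the pybind11 bindings):
//
//   auto executor = PynativeExecutor::GetInstance();
//   executor->set_grad_flag(true);
//   executor->NewGraph(cell, args);       // start recording the cell
//   executor->EndGraph(cell, out, args);  // finish recording
//   py::object grads = executor->Run(cell, run_args);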
class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
 public:
  static std::shared_ptr<PynativeExecutor> GetInstance() {
    std::lock_guard<std::mutex> i_lock(instance_lock_);
    if (executor_ == nullptr) {
      executor_ = std::shared_ptr<PynativeExecutor>(new (std::nothrow) PynativeExecutor());
      forward_executor_ = std::make_shared<ForwardExecutor>();
      grad_executor_ = std::make_shared<GradExecutor>(forward_executor_);
      forward_executor_->set_grad_executor(grad_executor_);
    }
    return executor_;
  }
  ~PynativeExecutor() = default;
  PynativeExecutor(const PynativeExecutor &) = delete;
  PynativeExecutor &operator=(const PynativeExecutor &) = delete;

  GradExecutorPtr grad_executor() const;
  ForwardExecutorPtr forward_executor() const;
  bool grad_flag() const;
  void set_grad_flag(bool flag);
  void set_graph_phase(const std::string &graph_phase);
  void set_py_exe_path(const py::object &py_exe_path);
  void set_kernel_build_server_dir(const py::object &kernel_build_server_dir);
  void NewGraph(const py::object &cell, const py::args &args);
  void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
  void GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
               const py::object &grad_position, const py::args &args);
  py::object GradMsFunction(const py::object &out, const py::args &args);
  py::object CheckGraph(const py::object &cell, const py::args &args);
  py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell, const py::args &args);
  py::object Run(const py::object &cell, const py::tuple &args);
  // Used for graph clean-up; the Cell destructor calls this
  void ClearCell(const std::string &cell_id);
  void ClearGrad(const py::object &cell, const py::args &args);
  // Clean up when execution exits abnormally
  void ClearRes();
  // Sync stream
  void Sync();
  void SetLazyBuild(bool enable);
  void ExecuteAllTask();
  void EnterCell();
  void ExitCell();
  bool IsTopCell() const;

 private:
  PynativeExecutor() = default;
  static std::shared_ptr<PynativeExecutor> executor_;
  static std::mutex instance_lock_;
  static ForwardExecutorPtr forward_executor_;
  static GradExecutorPtr grad_executor_;
  uint32_t cell_depth_{0};
};
using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
}  // namespace mindspore::pynative
#endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_