You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

kernel_graph.h 26 kB

4 years ago
4 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H
  17. #define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H
  18. #include <vector>
  19. #include <memory>
  20. #include <utility>
  21. #include <string>
  22. #include <queue>
  23. #include <map>
  24. #include <set>
  25. #include <stack>
  26. #include <atomic>
  27. #include "utils/hash_map.h"
  28. #include "utils/hash_set.h"
  29. #include "ir/func_graph.h"
  30. #include "ir/anf.h"
  31. #include "ir/graph_utils.h"
  32. #include "utils/contract.h"
  33. #include "runtime/device/kernel_info.h"
  34. namespace mindspore {
  35. namespace session {
  36. using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>;
  37. using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
  38. struct KernelWithIndexCmp {
  39. bool operator()(const KernelWithIndex &key1, const KernelWithIndex &key2) const {
  40. if (key1.first != key2.first) {
  41. return key1.first < key2.first;
  42. }
  43. if (key1.second != key2.second) {
  44. return key1.second < key2.second;
  45. }
  46. return false;
  47. }
  48. };
  49. using KernelMapTensor = std::map<session::KernelWithIndex, BaseRef, session::KernelWithIndexCmp>;
  50. class KernelGraph : public FuncGraph {
  51. public:
  52. KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), current_epoch_(0), is_dynamic_shape_(false) {
  53. inputs_ = std::make_shared<std::vector<AnfNodePtr>>();
  54. execution_order_ = {};
  55. mem_reuse_exec_order_ = {};
  56. executable_ = true;
  57. summary_node_exist_ = false;
  58. stream_distinction_label_ = kInvalidDistincLabel;
  59. }
  60. KernelGraph(const KernelGraph &graph) : FuncGraph(graph) {
  61. inputs_ = graph.inputs_;
  62. child_graph_result_ = graph.child_graph_result_;
  63. execution_order_ = graph.execution_order_;
  64. mem_reuse_exec_order_ = graph.mem_reuse_exec_order_;
  65. graph_id_ = graph.graph_id_;
  66. stream_distinction_label_ = graph.stream_distinction_label_;
  67. front_backend_anf_map_ = graph.front_backend_anf_map_;
  68. backend_front_anf_map_ = graph.backend_front_anf_map_;
  69. tensor_to_value_node_map_ = graph.tensor_to_value_node_map_;
  70. graph_value_nodes_ = graph.graph_value_nodes_;
  71. node_input_num_ = graph.node_input_num_;
  72. node_input_edges_ = graph.node_input_edges_;
  73. ref_out_in_map_ = graph.ref_out_in_map_;
  74. node_output_edges_ = graph.node_output_edges_;
  75. summary_nodes_ = graph.summary_nodes_;
  76. updated_parameters_ = graph.updated_parameters_;
  77. executable_ = graph.executable_;
  78. summary_node_exist_ = graph.summary_node_exist_;
  79. valid_inputs_ = graph.valid_inputs_;
  80. child_graph_order_ = graph.child_graph_order_;
  81. device_loop_ctrl_tensors_ = graph.device_loop_ctrl_tensors_;
  82. device_loop_ctrl_params_ = graph.device_loop_ctrl_params_;
  83. parent_graph_ = graph.parent_graph_;
  84. start_label_ = graph.start_label_;
  85. end_goto_ = graph.end_goto_;
  86. internal_parameter_to_front_node_map_ = graph.internal_parameter_to_front_node_map_;
  87. graph_output_to_front_node_map_ = graph.graph_output_to_front_node_map_;
  88. front_node_to_graph_output_map_ = graph.front_node_to_graph_output_map_;
  89. front_to_internal_outputs_map_ = graph.front_to_internal_outputs_map_;
  90. internal_outputs_to_front_map_ = graph.internal_outputs_to_front_map_;
  91. internal_outputs_tensor_map_ = graph.internal_outputs_tensor_map_;
  92. current_epoch_ = graph.current_epoch_;
  93. tuple_parameter_to_make_tuple_map_ = graph.tuple_parameter_to_make_tuple_map_;
  94. visited_nodes_ = graph.visited_nodes_;
  95. edge_to_ = graph.edge_to_;
  96. loop_nodes_ = graph.loop_nodes_;
  97. input_nodes_ = graph.input_nodes_;
  98. pre_graphs_ = graph.pre_graphs_;
  99. post_graphs_ = graph.post_graphs_;
  100. allreduce_from_send_recv_pairs_ = graph.allreduce_from_send_recv_pairs_;
  101. allreduce_to_send_recv_pairs_ = graph.allreduce_to_send_recv_pairs_;
  102. size_t pre_graph_finished_count = graph.pre_graph_finished_count_;
  103. pre_graph_finished_count_ = pre_graph_finished_count;
  104. size_t post_graph_finished_count = graph.post_graph_finished_count_;
  105. post_graph_finished_count_ = post_graph_finished_count;
  106. first_step_ = graph.first_step_;
  107. has_optimizer_ = graph.has_optimizer_;
  108. is_dynamic_shape_ = graph.is_dynamic_shape_;
  109. }
  110. ~KernelGraph() override;
  111. MS_DECLARE_PARENT(KernelGraph, FuncGraph);
  112. const std::vector<AnfNodePtr> &inputs() const;
  113. std::vector<AnfNodePtr> *MutableInputs() const { return inputs_.get(); }
  114. void SetGraphInputs(const std::vector<AnfNodePtr> &inputs) {
  115. inputs_ = std::make_shared<std::vector<AnfNodePtr>>(inputs);
  116. }
  117. void ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter);
  118. std::vector<AnfNodePtr> outputs() const;
  119. CNodePtr NewCNode(std::vector<AnfNodePtr> &&inputs) override;
  120. CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs) override;
  121. CNodePtr NewCNodeWithInfos(const std::vector<AnfNodePtr> &inputs, const CNodePtr &ori_cnode = nullptr);
  122. void CreateKernelInfoFromNewParameter(const CNodePtr &cnode);
  123. CNodePtr NewCNode(const CNodePtr &cnode);
  124. void ResetAssignInputFeatureMapFlag(const CNodePtr &cnode) const;
  125. ParameterPtr NewParameter(const ParameterPtr &parameter = nullptr);
  126. ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract);
  127. ValueNodePtr NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value);
  128. ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr);
  129. ValueNodePtr NewValueNode(const tensor::TensorPtr &input_tensor);
  130. // trans tuple output to maketuple + no_tuple out
  131. AnfNodePtr TransTupleToMakeTuple(const AnfNodePtr &node);
  132. void set_execution_order(const std::vector<CNodePtr> &order) { execution_order_ = order; }
  133. void set_execution_order(std::vector<CNodePtr> &&order) { execution_order_ = std::move(order); }
  134. const std::vector<CNodePtr> &execution_order() const { return execution_order_; }
  135. // Set new exec_order for mem_reuse
  136. void set_mem_reuse_exec_order(const std::vector<CNodePtr> &order) { mem_reuse_exec_order_ = order; }
  137. const std::vector<CNodePtr> &mem_reuse_exec_order() const { return mem_reuse_exec_order_; }
  138. void SetExecOrderByDefault();
  139. uint32_t graph_id() const { return graph_id_; }
  140. void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; }
  141. uint32_t root_graph_id() const { return root_graph_id_; }
  142. void set_root_graph_id(uint32_t root_graph_id) { root_graph_id_ = root_graph_id; }
  143. // and a new front to backend anf relation to maop
  144. void FrontBackendMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf);
  145. // replace old backend anf with new backend anf
  146. void FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf);
  147. // get backend anf by front anf
  148. AnfNodePtr GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf);
  149. // get front anf by backend anf
  150. AnfNodePtr GetFrontAnfByBackendAnf(const AnfNodePtr &backend_anf);
  151. // check backend node whether exist in map
  152. bool BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf);
  153. // get value node by tensor
  154. ValueNodePtr GetValueNodeByTensor(const tensor::TensorPtr &tensor);
  155. // add value node tensor relation map
  156. void TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node);
  157. // get all value nodes of graph
  158. const mindspore::HashSet<ValueNodePtr> graph_value_nodes() const { return graph_value_nodes_; }
  159. // add value node to graph
  160. void AddValueNodeToGraph(const ValueNodePtr &value_node);
  161. // ref output is in map
  162. bool IsInRefOutputMap(const AnfWithOutIndex &pair) const;
  163. // get ref correspond pairs
  164. AnfWithOutIndex GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const;
  165. // add ref correspond pairs
  166. void AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair);
  167. // get map
  168. std::map<AnfWithOutIndex, AnfWithOutIndex> GetRefMap() const { return ref_out_in_map_; }
  169. // check whether graph is executable
  170. bool executable() const { return executable_; }
  171. // set executable of graph
  172. void set_executable(bool executable) { executable_ = executable; }
  173. #ifndef ENABLE_SECURITY
  174. // set summary_node of graph
  175. void set_summary_node_exist(bool summary_node_exist) { summary_node_exist_ = summary_node_exist; }
  176. #endif
  177. // check whether exist summary node in graph
  178. bool summary_node_exist() const { return summary_node_exist_; }
  179. // set invalid inputs for control sink
  180. std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
  181. std::vector<bool> valid_inputs() const { return valid_inputs_; }
  182. // replace node in graph
  183. void ReplaceNode(const AnfNodePtr &old_anf_node, const AnfNodePtr &new_anf_node);
  184. // set stream label of graph
  185. void set_stream_distinction_label(uint32_t stream_label) { stream_distinction_label_ = stream_label; }
  186. // get stream label of graph
  187. uint32_t stream_distinction_label() { return stream_distinction_label_; }
  188. // refresh execute kernel stream label
  189. void UpdateExecuteKernelStreamLabel();
  190. // calculate the leaf graph order of root graph
  191. std::vector<std::shared_ptr<KernelGraph>> GetLeafGraphOrder();
  192. // the child graph of current graph
  193. const std::vector<std::weak_ptr<KernelGraph>> &child_graph_order() const { return child_graph_order_; }
  194. void set_child_graph_order(const std::vector<std::weak_ptr<KernelGraph>> &order) { child_graph_order_ = order; }
  195. // checkout whether current graph is leaf graph
  196. bool IsLeafGraph() const;
  197. void set_device_loop_ctrl_tensors(const std::map<std::string, tensor::TensorPtr> &device_loop_ctrl_tensors) {
  198. device_loop_ctrl_tensors_ = device_loop_ctrl_tensors;
  199. }
  200. std::map<std::string, tensor::TensorPtr> device_loop_control_tensors() const { return device_loop_ctrl_tensors_; }
  201. void set_device_loop_ctrl_params(const std::map<std::string, mindspore::ParameterPtr> &device_loop_ctrl_params) {
  202. device_loop_ctrl_params_ = device_loop_ctrl_params;
  203. }
  204. const std::map<std::string, mindspore::ParameterPtr> device_loop_control_params() const {
  205. return device_loop_ctrl_params_;
  206. }
  207. // get parent kernel graph
  208. std::weak_ptr<KernelGraph> parent_graph() const { return parent_graph_; }
  209. // set parent kernel graph
  210. void set_parent_graph(const std::weak_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; }
  211. // find anf node in graph
  212. std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const;
  213. std::vector<CNodePtr> FindNodeByPrimitive(const std::vector<PrimitivePtr> &primitive_list) const;
  214. // used to dump ir
  215. std::string ToString() const override;
  216. void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; }
  217. CNodePtr get_start_label() { return start_label_; }
  218. void set_end_goto(const CNodePtr &end_goto) { end_goto_ = end_goto; }
  219. CNodePtr get_end_goto() { return end_goto_; }
  220. void PrintGraphExecuteOrder() const;
  221. const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() const { return summary_nodes_; }
  222. void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; }
  223. void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node, size_t output_idx, bool unique_target);
  224. void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, size_t src_output_idx,
  225. size_t dst_output_idx);
  226. void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node);
  227. AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
  228. bool IsInternalOutput(const AnfNodePtr &node, size_t output_idx) const;
  229. bool IsInternalOutput(const AnfNodePtr &node) const;
  230. bool IsUniqueTargetInternalOutput(const AnfNodePtr &node, size_t output_idx) const;
  231. void AddInternalOutputTensor(const AnfNodePtr &node, size_t output_idx, const tensor::TensorPtr &tensor);
  232. tensor::TensorPtr GetInternalOutputTensor(const AnfNodePtr &node, size_t output_idx);
  233. AnfWithOutIndex GetGraphOutputByFrontNode(const AnfWithOutIndex &front_node) const;
  234. // Cache the internal parameter and corresponding to front node into internal_parameter_to_front_node_map_.
  235. void CacheInternalParameterToFrontNode(const AnfNodePtr &parameter, const AnfWithOutIndex &front_node_with_index);
  236. AnfWithOutIndex GetFrontNodeByInternalParameter(const AnfNodePtr &parameter) const;
  237. // Get the funcgraph to which the kernel graph belongs.
  238. FuncGraphPtr GetFuncGraph();
  239. // Cache the backend graph output nodes and corresponding to front nodes with output index into
  240. // graph_output_to_front_node_map_.
  241. void CacheGraphOutputToFrontNodeWithIndex(const std::vector<AnfNodePtr> &backend_outputs,
  242. const std::vector<AnfNodePtr> &front_outputs);
  243. AnfWithOutIndex GetFrontNodeWithIndexByGraphOutput(const AnfWithOutIndex &backend_graph_output_with_index) const;
  244. uint32_t current_epoch() const { return current_epoch_; }
  245. void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; }
  246. void UpdateChildGraphOrder();
  247. const std::vector<AnfNodePtr> &child_graph_result() const { return child_graph_result_; }
  248. void AddChildGraphResult(const AnfNodePtr &parameter) { child_graph_result_.push_back(parameter); }
  249. bool IsChildGraphResult(const AnfNodePtr &node);
  250. void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) {
  251. child_graph_result_ = child_graph_result;
  252. }
  253. void InsertTupleParameterToMakeTupleMap(const AnfNodePtr &param, const AnfNodePtr &make_tuple) {
  254. if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) {
  255. return;
  256. }
  257. tuple_parameter_to_make_tuple_map_[param] = make_tuple;
  258. }
  259. AnfNodePtr FindTupleParameterToMakeTupleMap(const AnfNodePtr &param) {
  260. if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) {
  261. return tuple_parameter_to_make_tuple_map_[param];
  262. } else {
  263. return nullptr;
  264. }
  265. }
  266. void RemoveNodeFromGraph(const AnfNodePtr &node);
  267. void UpdateGraphDynamicAttr();
  268. void SetGraphDynamicAttr(bool is_dynamic_shape) { is_dynamic_shape_ = is_dynamic_shape; }
  269. bool is_dynamic_shape() const { return is_dynamic_shape_; }
  270. void UpdateGraphAquireGilAttr();
  271. void SetOptimizerFlag();
  272. void SetInputNodes();
  273. const std::vector<AnfNodePtr> &input_nodes() const { return input_nodes_; }
  274. void SetInputTensors(const std::vector<tensor::TensorPtr> &input_tensors) { input_tensors_ = input_tensors; }
  275. const std::vector<tensor::TensorPtr> &input_tensors() const { return input_tensors_; }
  276. void SetOutputNodeToTensor(const KernelMapTensor &node_to_tensor) { output_node_to_tensor_ = node_to_tensor; }
  277. tensor::TensorPtr GetNodeOutputTensor(const session::KernelWithIndex &output_index) const {
  278. auto iter = output_node_to_tensor_.find(output_index);
  279. if (iter != output_node_to_tensor_.end()) {
  280. return utils::cast<tensor::TensorPtr>(iter->second);
  281. }
  282. return nullptr;
  283. }
  284. bool has_optimizer() const { return has_optimizer_; }
  285. bool IsUpdatedParameter(const ParameterPtr &param) const {
  286. if (updated_parameters_.find(param) != updated_parameters_.end()) {
  287. return true;
  288. }
  289. return false;
  290. }
  291. // handle graph dependency
  292. void AddPreGraph(const std::shared_ptr<session::KernelGraph> &graph) {
  293. if (graph != nullptr) {
  294. pre_graphs_[graph->graph_id()] = graph;
  295. }
  296. }
  297. mindspore::HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> get_pre_graphs() const { return pre_graphs_; }
  298. void AddPostGraph(const std::shared_ptr<session::KernelGraph> &graph) {
  299. if (graph != nullptr) {
  300. post_graphs_[graph->graph_id()] = graph;
  301. }
  302. }
  303. bool IsPreGraphFinished() const { return pre_graphs_.size() == pre_graph_finished_count_; }
  304. bool IsPostGraphFinished() const {
  305. if (first_step_) {
  306. return true;
  307. }
  308. return post_graphs_.size() == post_graph_finished_count_;
  309. }
  310. bool HasPostGraph() const { return !post_graphs_.empty(); }
  311. void IncPreGraphFinishedCount() { pre_graph_finished_count_++; }
  312. void IncPostGraphFinishedCount() { post_graph_finished_count_++; }
  313. void ResetGraphRunningStatus() {
  314. first_step_ = false;
  315. post_graph_finished_count_ = 0;
  316. pre_graph_finished_count_ = 0;
  317. }
  318. void OnRunGraphFinished() {
  319. for (auto post_graph : post_graphs_) {
  320. auto post_graph_ptr = post_graph.second.lock();
  321. if (post_graph_ptr != nullptr) {
  322. post_graph_ptr->IncPreGraphFinishedCount();
  323. }
  324. }
  325. for (auto pre_graph : pre_graphs_) {
  326. auto pre_graph_ptr = pre_graph.second.lock();
  327. if (pre_graph_ptr != nullptr) {
  328. pre_graph_ptr->IncPostGraphFinishedCount();
  329. }
  330. }
  331. }
  332. // end of handle graph dependency
  333. // The interface of allreduce send/recv pairs map.
  334. void InsertFromSendRecvPair(const CNodePtr &allreduce, const std::pair<CNodePtr, CNodePtr> &send_recv_pair) {
  335. allreduce_from_send_recv_pairs_[allreduce] = send_recv_pair;
  336. }
  337. void InsertToSendRecvPair(const CNodePtr &allreduce, const std::pair<CNodePtr, CNodePtr> &send_recv_pair) {
  338. allreduce_to_send_recv_pairs_[allreduce] = send_recv_pair;
  339. }
  340. const mindspore::HashMap<CNodePtr, std::pair<CNodePtr, CNodePtr>> &allreduce_from_send_recv_pairs() const {
  341. return allreduce_from_send_recv_pairs_;
  342. }
  343. const mindspore::HashMap<CNodePtr, std::pair<CNodePtr, CNodePtr>> &allreduce_to_send_recv_pairs() const {
  344. return allreduce_to_send_recv_pairs_;
  345. }
  346. uint32_t label_num() const { return label_num_; }
  347. void set_label_num(uint32_t num) { label_num_ = num; }
  348. // The graphs has recursion.
  349. bool recursive_call() const { return has_recursive_call_; }
  350. // The graphs has subgraph multi-call.
  351. bool subgraph_multi_call() const { return has_subgraph_multicall_; }
  352. // set flag to indicate whether has recursion.
  353. void set_recursive_call(bool flag) { has_recursive_call_ = flag; }
  354. // set flag to indicate whether has multi-call.
  355. void set_subgraph_multi_call(bool flag) { has_subgraph_multicall_ = flag; }
  356. bool is_all_nop_node() const { return is_all_nop_node_; }
  357. void set_is_all_nop_node(bool is_all_nop_node) { is_all_nop_node_ = is_all_nop_node; }
  358. std::map<AnfWithOutIndex, AnfWithOutIndex> graph_output_map() { return graph_output_to_front_node_map_; }
  359. std::map<AnfWithOutIndex, AnfWithOutIndex> front_node_to_graph_output_map() {
  360. return front_node_to_graph_output_map_;
  361. }
  362. // The interface to set/get the graph GIL flag.
  363. void set_is_need_gil(bool flag) { is_need_gil_ = flag; }
  364. bool is_need_gil() { return is_need_gil_; }
  365. bool IsDatasetGraph() const;
  366. bool is_executing_sink() const { return is_executing_sink_; }
  367. void set_is_executing_sink(bool is_executing_sink) { is_executing_sink_ = is_executing_sink; }
  368. bool is_loop_count_sink() const { return is_loop_count_sink_; }
  369. void set_is_loop_count_sink(bool is_loop_count_sink) { is_loop_count_sink_ = is_loop_count_sink; }
  370. const mindspore::HashMap<AnfNodePtr, AnfNodePtr> &front_backend_anf_map() { return front_backend_anf_map_; }
  371. AnfWithOutIndex GetElementInTupleBackendFrontIndexMap(const AnfNodePtr &back_node) {
  372. auto iter = tuple_backend_front_anf_index_map_.find(back_node);
  373. if (iter == tuple_backend_front_anf_index_map_.end()) {
  374. return AnfWithOutIndex(nullptr, 0);
  375. }
  376. return iter->second;
  377. }
  378. private:
  379. // remove value node form graph
  380. bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
  381. void SetKernelInfoForNode(const AnfNodePtr &node) const;
  382. AnfNodePtr MakeValueNode(const AnfNodePtr &node) const;
  383. void EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
  384. mindspore::HashSet<AnfNodePtr> *visited_nodes, bool comm_first = true);
  385. // update node edge list
  386. void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes);
  387. // add node depend edge by data edge
  388. void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num);
  389. std::vector<AnfNodePtr> GetOutputNodes(const AnfNodePtr &node);
  390. AnfNodePtr TransValueNodeTuple(const AbstractBasePtr &abstract, const ValuePtr &value);
  391. AnfNodePtr TransParameterTuple(const AbstractBasePtr &abstract);
  392. AnfNodePtr TransCNodeTuple(const CNodePtr &node);
  393. AnfNodePtr CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx);
  394. std::vector<CNodePtr> SortStartLabelAndEndGoto();
  395. // checkout whether loop exist in graph
  396. void CheckLoop();
  397. uint32_t GetLoopNum(const std::map<AnfNodePtr, size_t> &none_zero_nodes);
  398. void GetLoopNodesByDFS(const AnfNodePtr &node, uint32_t *loop_num);
  399. void PostNewCNode(const CNodePtr &cnode);
  400. // members
  401. std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
  402. std::vector<AnfNodePtr> child_graph_result_;
  403. std::vector<CNodePtr> execution_order_;
  404. std::vector<CNodePtr> mem_reuse_exec_order_;
  405. uint32_t graph_id_;
  406. uint32_t stream_distinction_label_;
  407. uint32_t root_graph_id_{0};
  408. // record map bettween front anf and backend anf,use two map implement bidirectional map
  409. mindspore::HashMap<AnfNodePtr, AnfNodePtr> front_backend_anf_map_;
  410. mindspore::HashMap<AnfNodePtr, AnfNodePtr> backend_front_anf_map_;
  411. mindspore::HashMap<AnfNodePtr, AnfWithOutIndex> tuple_backend_front_anf_index_map_;
  412. // there may be a tensor from ME backend ,a value ndoe will be create according the tensor,map record
  413. mindspore::HashMap<tensor::TensorPtr, ValueNodePtr> tensor_to_value_node_map_;
  414. // include all value nodes
  415. mindspore::HashSet<ValueNodePtr> graph_value_nodes_;
  416. mindspore::HashMap<AnfNodePtr, size_t> node_input_num_;
  417. mindspore::HashMap<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_input_edges_;
  418. // record map between ref final output anf with index and ref origin input with index
  419. std::map<AnfWithOutIndex, AnfWithOutIndex> ref_out_in_map_;
  420. mindspore::HashMap<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_;
  421. std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes_;
  422. // parameters that will be updated when graph is executed
  423. mindspore::HashSet<ParameterPtr> updated_parameters_;
  424. // graph needn't execute
  425. bool executable_{false};
  426. // exist summary node in graph
  427. bool summary_node_exist_{false};
  428. // valid inputs
  429. std::vector<bool> valid_inputs_;
  430. // child graph execute order in parent graph
  431. std::vector<std::weak_ptr<KernelGraph>> child_graph_order_;
  432. // device loop control frontend tensors
  433. std::map<std::string, tensor::TensorPtr> device_loop_ctrl_tensors_;
  434. // device loop control backend nodes
  435. std::map<std::string, mindspore::ParameterPtr> device_loop_ctrl_params_;
  436. // parameter graph
  437. std::weak_ptr<KernelGraph> parent_graph_;
  438. CNodePtr start_label_;
  439. CNodePtr end_goto_;
  440. // Internal parameter is not the origin parameter of func graph, it is the output of previous kernel graph which is
  441. // related to the input of this kernel graph. The first of unordered map is the input of this kernel graph, the second
  442. // of unordered map is front node corresponding to the output of previous kernel graph.
  443. mindspore::HashMap<AnfNodePtr, AnfWithOutIndex> internal_parameter_to_front_node_map_;
  444. // The first of map is the backend graph output of this kernel graph, the second of map is front node corresponding to
  445. // the backend node with index.
  446. std::map<AnfWithOutIndex, AnfWithOutIndex> graph_output_to_front_node_map_;
  447. std::map<AnfWithOutIndex, AnfWithOutIndex> front_node_to_graph_output_map_;
  448. mindspore::HashMap<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
  449. mindspore::HashMap<AnfNodePtr, mindspore::HashMap<size_t, std::pair<AnfNodePtr, bool>>>
  450. internal_outputs_to_front_map_;
  451. mindspore::HashMap<AnfNodePtr, mindspore::HashMap<size_t, tensor::TensorPtr>> internal_outputs_tensor_map_;
  452. uint32_t current_epoch_;
  453. mindspore::HashMap<AnfNodePtr, AnfNodePtr> tuple_parameter_to_make_tuple_map_;
  454. std::set<AnfNodePtr> visited_nodes_;
  455. std::map<AnfNodePtr, AnfNodePtr> edge_to_;
  456. std::stack<AnfNodePtr> loop_nodes_;
  457. std::vector<AnfNodePtr> input_nodes_;
  458. std::vector<tensor::TensorPtr> input_tensors_;
  459. KernelMapTensor output_node_to_tensor_;
  460. mindspore::HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> pre_graphs_;
  461. mindspore::HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> post_graphs_;
  462. // The send/recv pairs inserted for allreduce, the key is allreduce kernel, the first of pair is send node, the second
  463. // of pair is recv node.
  464. mindspore::HashMap<CNodePtr, std::pair<CNodePtr, CNodePtr>> allreduce_from_send_recv_pairs_;
  465. mindspore::HashMap<CNodePtr, std::pair<CNodePtr, CNodePtr>> allreduce_to_send_recv_pairs_;
  466. std::atomic<size_t> pre_graph_finished_count_{0};
  467. std::atomic<size_t> post_graph_finished_count_{0};
  468. bool first_step_{true};
  469. bool has_optimizer_{false};
  470. bool is_dynamic_shape_{false};
  471. // Indicate the graphs has recursion or multi-call or not as the root graph.
  472. bool has_recursive_call_{false};
  473. bool has_subgraph_multicall_{false};
  474. // Number of labels. This is also the 'batch_num' for DavinciModel,
  475. // It should be 1 if no labels used for control flow.
  476. uint32_t label_num_ = 1;
  477. // If all the nodes of graph is the nop node.
  478. bool is_all_nop_node_{false};
  479. // Indicate whether the kernels in the graphs acquire Python GIL.
  480. bool is_need_gil_{false};
  481. // Indicate whether the kernel graph sink to the device executing.
  482. bool is_executing_sink_{false};
  483. // Indicate whether the kernel graph loop sink to the device executing.
  484. bool is_loop_count_sink_{false};
  485. };
  486. } // namespace session
  487. using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
  488. } // namespace mindspore
  489. #endif // MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H