You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel_graph.h 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H
  17. #define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H
  18. #include <vector>
  19. #include <unordered_map>
  20. #include <memory>
  21. #include <utility>
  22. #include <string>
  23. #include <queue>
  24. #include <map>
  25. #include <set>
  26. #include <unordered_set>
  27. #include "ir/func_graph.h"
  28. #include "ir/anf.h"
  29. #include "utils/graph_utils.h"
  30. #include "utils/contract.h"
  31. #include "runtime/device/kernel_info.h"
  32. namespace mindspore {
  33. namespace session {
  34. using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>;
  35. class KernelGraph : public FuncGraph {
  36. public:
  37. KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), null_output_(false), current_epoch_(0) {
  38. inputs_ = std::make_shared<std::vector<AnfNodePtr>>();
  39. execution_order_ = {};
  40. executable_ = true;
  41. summary_node_exist_ = false;
  42. stream_distinction_label_ = kInvalidDistincLabel;
  43. }
  44. ~KernelGraph() override;
  45. MS_DECLARE_PARENT(KernelGraph, FuncGraph);
  46. const std::vector<AnfNodePtr> &inputs() const;
  47. std::vector<AnfNodePtr> *MutableInputs() const { return inputs_.get(); }
  48. std::vector<AnfNodePtr> outputs() const;
  49. CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs) override;
  50. void CreateKernelInfoFromNewParameter(const CNodePtr &cnode);
  51. CNodePtr NewCNode(const CNodePtr &cnode);
  52. ParameterPtr NewParameter(const ParameterPtr &parameter = nullptr);
  53. ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract);
  54. ValueNodePtr NewValueNode(const ValuePtr &value);
  55. ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr);
  56. std::vector<AnfNodePtr> SplitTupleOutputNodeToNodeList(const AnfNodePtr &node);
  57. void set_execution_order(const std::vector<CNodePtr> &order) { execution_order_ = order; }
  58. const std::vector<CNodePtr> &execution_order() const { return execution_order_; }
  59. void SetExecOrderByDefault();
  60. uint32_t graph_id() const { return graph_id_; }
  61. void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; }
  62. // and a new front to backend anf relation to maop
  63. void FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf);
  64. // replace old backend anf with new backend anf
  65. void FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf);
  66. // get backend anf by front anf
  67. AnfNodePtr GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf);
  68. // check backend node whether exist in map
  69. bool BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf);
  70. // get value node by tensor
  71. ValueNodePtr GetValueNodeByTensor(const tensor::TensorPtr &tensor);
  72. // add value node tensor relation map
  73. void TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node);
  74. // get all value nodes of graph
  75. const std::unordered_set<ValueNodePtr> graph_value_nodes() const { return graph_value_nodes_; }
  76. // add value node to graph
  77. void AddValueNodeToGraph(const ValueNodePtr &value_node);
  78. // ref output is in map
  79. bool IsInRefOutputMap(const AnfWithOutIndex &pair) const;
  80. // get ref correspond pairs
  81. AnfWithOutIndex GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const;
  82. // add ref correspond pairs
  83. void AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair);
  84. // get map
  85. std::map<AnfWithOutIndex, AnfWithOutIndex> GetRefMap() const { return ref_out_in_map_; }
  86. // checkout whether loop exist in graph
  87. void CheckLoop();
  88. // check whether graph is executable
  89. bool executable() const { return executable_; }
  90. // set executable of graph
  91. void set_executable(bool executable) { executable_ = executable; }
  92. // set summary_node of graph
  93. void set_summary_node_exist(bool summary_node_exist) { summary_node_exist_ = summary_node_exist; }
  94. // check whether exist summary node in graph
  95. bool summary_node_exist() const { return summary_node_exist_; }
  96. // set invalid inputs for control sink
  97. std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
  98. std::vector<bool> valid_inputs() const { return valid_inputs_; }
  99. // replace node in graph
  100. void ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodePtr> new_anf_node);
  101. // set stream label of graph
  102. void set_stream_distinction_label(uint32_t stream_label) { stream_distinction_label_ = stream_label; }
  103. // get stream label of graph
  104. uint32_t stream_distinction_label() { return stream_distinction_label_; }
  105. // refresh execute kernel stream label
  106. void UpdateExecuteKernelStreamLabel();
  107. // calculate the leaf graph order of root graph
  108. std::vector<std::shared_ptr<KernelGraph>> GetLeafGraphOrder();
  109. // the child graph of current graph
  110. const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order() const { return child_graph_order_; }
  111. void set_child_graph_order(const std::vector<std::shared_ptr<KernelGraph>> &order) { child_graph_order_ = order; }
  112. // checkout whether current graph is leaf graph
  113. bool IsLeafGraph() const;
  114. // set input_tensors pointer of control parameter
  115. void set_input_ctrl_tensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &input_tensors_ptr) {
  116. input_ctrl_tensors_ = input_tensors_ptr;
  117. }
  118. // get input_tensors pointer of control parameter
  119. std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors() const { return input_ctrl_tensors_; }
  120. // get parent kernel graph
  121. std::shared_ptr<KernelGraph> parent_graph() const { return parent_graph_; }
  122. // set parent kernel graph
  123. void set_parent_graph(const std::shared_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; }
  124. // find anf node in graph
  125. std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const;
  126. // get real inputs
  127. const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs() const { return real_inputs_; }
  128. void SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg);
  129. // mark unreused args
  130. void AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph);
  131. const std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> &unreuse_args() const { return unreuse_args_; }
  132. // used to dump ir
  133. std::string ToString() const override;
  134. // update the real input if the node is a call
  135. void UpdateCallRealInput();
  136. void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; }
  137. CNodePtr get_start_label() { return start_label_; }
  138. void set_end_goto(const CNodePtr &end_goto) { end_goto_ = end_goto; }
  139. CNodePtr get_end_goto() { return end_goto_; }
  140. bool get_output_null() { return null_output_; }
  141. void set_output_null(bool is_output_null) { null_output_ = is_output_null; }
  142. void PrintGraphExecuteOrder() const;
  143. const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() const { return summary_nodes_; }
  144. void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; }
  145. void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node);
  146. void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx = -1,
  147. int dst_output_idx = -1);
  148. AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
  149. bool IsInternalOutput(const AnfNodePtr &node) const;
  150. void AddFinalOutputKernel(const AnfNodePtr &node);
  151. bool IsFinalOutputKernel(const AnfNodePtr &node) const;
  152. uint32_t current_epoch() const { return current_epoch_; }
  153. void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; }
  154. void UpdateChildGraphOrder();
  155. const std::vector<AnfNodePtr> &child_graph_result() const { return child_graph_result_; }
  156. void AddChildGraphResult(const AnfNodePtr &parameter) { child_graph_result_.push_back(parameter); }
  157. void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) {
  158. child_graph_result_ = child_graph_result;
  159. }
  160. private:
  161. // remove value node form graph
  162. bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
  163. void SetKernelInfoForNode(const AnfNodePtr &node) const;
  164. std::vector<AnfNodePtr> SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node);
  165. std::vector<AnfNodePtr> SplitTupleParameterToNodeList(const ParameterPtr &parameter);
  166. AnfNodePtr MakeValueNode(const AnfNodePtr &node);
  167. void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
  168. std::unordered_set<AnfNodePtr> *visited_nodes);
  169. // update node edge list
  170. void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes);
  171. // add node depend edge by data edge or control depend
  172. void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num);
  173. // handle control depend
  174. std::vector<AnfNodePtr> GetOutputNodes(const AnfNodePtr &node);
  175. bool HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
  176. std::unordered_set<AnfNodePtr> *visited_nodes);
  177. void UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends);
  178. std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
  179. std::vector<AnfNodePtr> child_graph_result_;
  180. std::vector<CNodePtr> execution_order_;
  181. uint32_t graph_id_;
  182. uint32_t stream_distinction_label_;
  183. // record map bettween front anf and backend anf,use two map implement bidirectional map
  184. std::unordered_map<AnfNodePtr, AnfNodePtr> front_backend_anf_map_;
  185. std::unordered_map<AnfNodePtr, AnfNodePtr> backend_front_anf_map_;
  186. // there may be a tensor from ME backend ,a value ndoe will be create according the tensor,map record
  187. std::unordered_map<tensor::TensorPtr, ValueNodePtr> tensor_to_value_node_map_;
  188. // include all value nodes
  189. std::unordered_set<ValueNodePtr> graph_value_nodes_;
  190. std::unordered_map<AnfNodePtr, size_t> node_input_num_;
  191. std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_input_edges_;
  192. // record map between ref final output anf with index and ref origin input with index
  193. std::map<AnfWithOutIndex, AnfWithOutIndex> ref_out_in_map_;
  194. std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_;
  195. std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes_;
  196. // graph needn't execute
  197. bool executable_;
  198. // exist summary node in graph
  199. bool summary_node_exist_;
  200. // valid inputs
  201. std::vector<bool> valid_inputs_;
  202. // new members for control sink process
  203. // all child grahs refers to partial node
  204. std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> node_to_child_graphs_;
  205. // child graph execute order in root graph
  206. std::vector<std::shared_ptr<KernelGraph>> child_graph_order_;
  207. // input_tensors of control parameter
  208. std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors_;
  209. // parameter graph
  210. std::shared_ptr<KernelGraph> parent_graph_;
  211. // record real parameters,inputs_ is the formal parameters
  212. std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_;
  213. std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> unreuse_args_;
  214. CNodePtr start_label_;
  215. CNodePtr end_goto_;
  216. bool null_output_;
  217. std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
  218. std::unordered_map<AnfNodePtr, std::unordered_map<int, AnfNodePtr>> internal_outputs_to_front_map_;
  219. std::set<AnfNodePtr> final_output_kernels_;
  220. uint32_t current_epoch_;
  221. };
  222. } // namespace session
  223. using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
  224. } // namespace mindspore
  225. #endif // MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H