You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel_graph.h 13 kB

adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
  16. #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H
  17. #define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H
  18. #include <vector>
  19. #include <unordered_map>
  20. #include <memory>
  21. #include <utility>
  22. #include <string>
  23. #include <queue>
  24. #include <map>
  25. #include <set>
  26. #include <stack>
  27. #include <unordered_set>
  28. #include "ir/func_graph.h"
  29. #include "ir/anf.h"
  30. #include "ir/graph_utils.h"
  31. #include "utils/contract.h"
  32. #include "runtime/device/kernel_info.h"
  33. namespace mindspore {
  34. namespace session {
  35. using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>;
  36. class KernelGraph : public FuncGraph {
  37. public:
  38. KernelGraph()
  39. : graph_id_(0),
  40. start_label_(nullptr),
  41. end_goto_(nullptr),
  42. null_output_(false),
  43. current_epoch_(0),
  44. is_dynamic_shape_(false) {
  45. inputs_ = std::make_shared<std::vector<AnfNodePtr>>();
  46. execution_order_ = {};
  47. executable_ = true;
  48. summary_node_exist_ = false;
  49. stream_distinction_label_ = kInvalidDistincLabel;
  50. }
  51. ~KernelGraph() override;
  52. MS_DECLARE_PARENT(KernelGraph, FuncGraph);
  53. const std::vector<AnfNodePtr> &inputs() const;
  54. std::vector<AnfNodePtr> *MutableInputs() const { return inputs_.get(); }
  55. void ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter);
  56. std::vector<AnfNodePtr> outputs() const;
  57. CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs) override;
  58. void CreateKernelInfoFromNewParameter(const CNodePtr &cnode);
  59. CNodePtr NewCNode(const CNodePtr &cnode);
  60. ParameterPtr NewParameter(const ParameterPtr &parameter = nullptr);
  61. ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract);
  62. ValueNodePtr NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value);
  63. ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr);
  64. // trans tuple output to maketuple + no_tuple out
  65. AnfNodePtr TransTupleToMakeTuple(const AnfNodePtr &node);
  66. void set_execution_order(const std::vector<CNodePtr> &order) { execution_order_ = order; }
  67. const std::vector<CNodePtr> &execution_order() const { return execution_order_; }
  68. void SetExecOrderByDefault();
  69. uint32_t graph_id() const { return graph_id_; }
  70. void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; }
  71. // and a new front to backend anf relation to maop
  72. void FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf);
  73. // replace old backend anf with new backend anf
  74. void FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf);
  75. // get backend anf by front anf
  76. AnfNodePtr GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf);
  77. // check backend node whether exist in map
  78. bool BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf);
  79. // get value node by tensor
  80. ValueNodePtr GetValueNodeByTensor(const tensor::TensorPtr &tensor);
  81. // add value node tensor relation map
  82. void TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node);
  83. // get all value nodes of graph
  84. const std::unordered_set<ValueNodePtr> graph_value_nodes() const { return graph_value_nodes_; }
  85. // add value node to graph
  86. void AddValueNodeToGraph(const ValueNodePtr &value_node);
  87. // ref output is in map
  88. bool IsInRefOutputMap(const AnfWithOutIndex &pair) const;
  89. // get ref correspond pairs
  90. AnfWithOutIndex GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const;
  91. // add ref correspond pairs
  92. void AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair);
  93. // get map
  94. std::map<AnfWithOutIndex, AnfWithOutIndex> GetRefMap() const { return ref_out_in_map_; }
  95. // check whether graph is executable
  96. bool executable() const { return executable_; }
  97. // set executable of graph
  98. void set_executable(bool executable) { executable_ = executable; }
  99. // set summary_node of graph
  100. void set_summary_node_exist(bool summary_node_exist) { summary_node_exist_ = summary_node_exist; }
  101. // check whether exist summary node in graph
  102. bool summary_node_exist() const { return summary_node_exist_; }
  103. // set invalid inputs for control sink
  104. std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
  105. std::vector<bool> valid_inputs() const { return valid_inputs_; }
  106. // replace node in graph
  107. void ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodePtr> new_anf_node);
  108. // set stream label of graph
  109. void set_stream_distinction_label(uint32_t stream_label) { stream_distinction_label_ = stream_label; }
  110. // get stream label of graph
  111. uint32_t stream_distinction_label() { return stream_distinction_label_; }
  112. // refresh execute kernel stream label
  113. void UpdateExecuteKernelStreamLabel();
  114. // calculate the leaf graph order of root graph
  115. std::vector<std::shared_ptr<KernelGraph>> GetLeafGraphOrder();
  116. // the child graph of current graph
  117. const std::vector<std::weak_ptr<KernelGraph>> &child_graph_order() const { return child_graph_order_; }
  118. void set_child_graph_order(const std::vector<std::weak_ptr<KernelGraph>> &order) { child_graph_order_ = order; }
  119. // checkout whether current graph is leaf graph
  120. bool IsLeafGraph() const;
  121. // set input_tensors pointer of control parameter
  122. void set_input_ctrl_tensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &input_tensors_ptr) {
  123. input_ctrl_tensors_ = input_tensors_ptr;
  124. }
  125. // get input_tensors pointer of control parameter
  126. std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors() const { return input_ctrl_tensors_; }
  127. // get parent kernel graph
  128. std::weak_ptr<KernelGraph> parent_graph() const { return parent_graph_; }
  129. // set parent kernel graph
  130. void set_parent_graph(const std::weak_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; }
  131. // find anf node in graph
  132. std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const;
  133. std::vector<CNodePtr> FindNodeByPrimitive(const std::vector<PrimitivePtr> &primitive_list) const;
  134. // used to dump ir
  135. std::string ToString() const override;
  136. void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; }
  137. CNodePtr get_start_label() { return start_label_; }
  138. void set_end_goto(const CNodePtr &end_goto) { end_goto_ = end_goto; }
  139. CNodePtr get_end_goto() { return end_goto_; }
  140. bool get_output_null() { return null_output_; }
  141. void set_output_null(bool is_output_null) { null_output_ = is_output_null; }
  142. void PrintGraphExecuteOrder() const;
  143. const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() const { return summary_nodes_; }
  144. void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; }
  145. void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node, int output_idx = 0,
  146. bool unique_target = false);
  147. void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx = -1,
  148. int dst_output_idx = -1);
  149. AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
  150. bool IsInternalOutput(const AnfNodePtr &node, int output_idx = -1) const;
  151. bool IsUniqueTargetInternalOutput(const AnfNodePtr &node, int output_idx) const;
  152. void AddInternalOutputTensor(const AnfNodePtr &node, int output_idx, const tensor::TensorPtr &tensor);
  153. tensor::TensorPtr GetInternalOutputTensor(const AnfNodePtr &node, int output_idx);
  154. uint32_t current_epoch() const { return current_epoch_; }
  155. void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; }
  156. void UpdateChildGraphOrder();
  157. const std::vector<AnfNodePtr> &child_graph_result() const { return child_graph_result_; }
  158. void AddChildGraphResult(const AnfNodePtr &parameter) { child_graph_result_.push_back(parameter); }
  159. void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) {
  160. child_graph_result_ = child_graph_result;
  161. }
  162. void InsertTupleParameterToMakeTupleMap(const AnfNodePtr &param, const AnfNodePtr &make_tuple) {
  163. if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) {
  164. return;
  165. }
  166. tuple_parameter_to_make_tuple_map_[param] = make_tuple;
  167. }
  168. AnfNodePtr FindTupleParameterToMakeTupleMap(const AnfNodePtr &param) {
  169. if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) {
  170. return tuple_parameter_to_make_tuple_map_[param];
  171. } else {
  172. return nullptr;
  173. }
  174. }
  175. void RemoveNodeFromGraph(const AnfNodePtr &node);
  176. void UpdateGraphDynamicAttr();
  177. bool is_dynamic_shape() const { return is_dynamic_shape_; }
  178. private:
  179. // remove value node form graph
  180. bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
  181. void SetKernelInfoForNode(const AnfNodePtr &node) const;
  182. AnfNodePtr MakeValueNode(const AnfNodePtr &node);
  183. void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
  184. std::unordered_set<AnfNodePtr> *visited_nodes);
  185. // update node edge list
  186. void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes);
  187. // add node depend edge by data edge or control depend
  188. void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num);
  189. void UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real_prior_nodes,
  190. const std::vector<AnfNodePtr> &real_depend_nodes);
  191. // handle control depend
  192. std::vector<AnfNodePtr> GetOutputNodes(const AnfNodePtr &node);
  193. bool HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
  194. std::unordered_set<AnfNodePtr> *visited_nodes);
  195. void UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends);
  196. AnfNodePtr TransValueNodeTuple(const AbstractBasePtr abstract, const ValuePtr &value);
  197. AnfNodePtr TransParameterTuple(const AbstractBasePtr &abstract);
  198. AnfNodePtr TransCNodeTuple(const CNodePtr &node);
  199. AnfNodePtr CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx);
  200. std::vector<CNodePtr> SortStartLabelAndEndGoto();
  201. // checkout whether loop exist in graph
  202. void CheckLoop();
  203. uint32_t GetLoopNum(std::map<AnfNodePtr, size_t> none_zero_nodes);
  204. void GetLoopNodesByDFS(AnfNodePtr node, uint32_t *loop_num);
  205. std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
  206. std::vector<AnfNodePtr> child_graph_result_;
  207. std::vector<CNodePtr> execution_order_;
  208. uint32_t graph_id_;
  209. uint32_t stream_distinction_label_;
  210. // record map bettween front anf and backend anf,use two map implement bidirectional map
  211. std::unordered_map<AnfNodePtr, AnfNodePtr> front_backend_anf_map_;
  212. std::unordered_map<AnfNodePtr, AnfNodePtr> backend_front_anf_map_;
  213. // there may be a tensor from ME backend ,a value ndoe will be create according the tensor,map record
  214. std::unordered_map<tensor::TensorPtr, ValueNodePtr> tensor_to_value_node_map_;
  215. // include all value nodes
  216. std::unordered_set<ValueNodePtr> graph_value_nodes_;
  217. std::unordered_map<AnfNodePtr, size_t> node_input_num_;
  218. std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_input_edges_;
  219. // record map between ref final output anf with index and ref origin input with index
  220. std::map<AnfWithOutIndex, AnfWithOutIndex> ref_out_in_map_;
  221. std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_;
  222. std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes_;
  223. // graph needn't execute
  224. bool executable_;
  225. // exist summary node in graph
  226. bool summary_node_exist_;
  227. // valid inputs
  228. std::vector<bool> valid_inputs_;
  229. // child graph execute order in root graph
  230. std::vector<std::weak_ptr<KernelGraph>> child_graph_order_;
  231. // input_tensors of control parameter
  232. std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors_;
  233. // parameter graph
  234. std::weak_ptr<KernelGraph> parent_graph_;
  235. CNodePtr start_label_;
  236. CNodePtr end_goto_;
  237. bool null_output_;
  238. std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
  239. std::unordered_map<AnfNodePtr, std::unordered_map<int, std::pair<AnfNodePtr, bool>>> internal_outputs_to_front_map_;
  240. std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_;
  241. uint32_t current_epoch_;
  242. std::unordered_map<AnfNodePtr, AnfNodePtr> tuple_parameter_to_make_tuple_map_;
  243. std::set<AnfNodePtr> visited_nodes_;
  244. std::map<AnfNodePtr, AnfNodePtr> edge_to_;
  245. std::stack<AnfNodePtr> loop_nodes_;
  246. bool is_dynamic_shape_;
  247. };
  248. } // namespace session
  249. using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
  250. } // namespace mindspore
  251. #endif // MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H