You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

kernel_graph.h 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_SESSION_KERNEL_GRAPH_H
  17. #define MINDSPORE_CCSRC_SESSION_KERNEL_GRAPH_H
  18. #include <vector>
  19. #include <unordered_map>
  20. #include <memory>
  21. #include <utility>
  22. #include <string>
  23. #include <queue>
  24. #include <stack>
  25. #include <map>
  26. #include <unordered_set>
  27. #include "ir/func_graph.h"
  28. #include "ir/anf.h"
  29. #include "utils/graph_utils.h"
  30. namespace mindspore {
  31. namespace session {
  32. using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>;
  33. class KernelGraph : public FuncGraph {
  34. public:
  35. KernelGraph() : graph_id_(0) {
  36. inputs_ = std::make_shared<std::vector<AnfNodePtr>>();
  37. execution_order_ = {};
  38. executable_ = true;
  39. }
  40. ~KernelGraph() override = default;
  41. MS_DECLARE_PARENT(KernelGraph, FuncGraph);
  42. const std::vector<AnfNodePtr> &inputs() const;
  43. std::vector<AnfNodePtr> *MutableInputs() const { return inputs_.get(); }
  44. std::vector<AnfNodePtr> outputs() const;
  45. CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs) override;
  46. CNodePtr NewCNode(const CNodePtr &cnode);
  47. ParameterPtr NewParameter(const ParameterPtr &parameter = nullptr);
  48. ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr);
  49. std::vector<AnfNodePtr> SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node);
  50. void set_execution_order(const std::vector<CNodePtr> &order) { execution_order_ = order; }
  51. const std::vector<CNodePtr> &execution_order() const { return execution_order_; }
  52. void SetExecOrderByDefault();
  53. uint32_t graph_id() const { return graph_id_; }
  54. void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; }
  55. // and a new front to backend anf relation to maop
  56. void FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf);
  57. // replace old backend anf with new backend anf
  58. void FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf);
  59. // get backend anf by front anf
  60. AnfNodePtr GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf);
  61. // check backend node whether exist in map
  62. bool BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf);
  63. // get value node by tensor
  64. ValueNodePtr GetValueNodeByTensor(const tensor::TensorPtr &tensor);
  65. // add value node tensor relation map
  66. void TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node);
  67. // get all value nodes of graph
  68. std::unordered_set<ValueNodePtr> graph_value_nodes() { return graph_value_nodes_; }
  69. // add value node to graph
  70. void AddValueNodeToGraph(const ValueNodePtr &value_node);
  71. // ref output is in map
  72. bool IsInRefOutputMap(const AnfWithOutIndex &pair) const;
  73. // get ref correspond pairs
  74. AnfWithOutIndex GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const;
  75. // add ref correspond pairs
  76. void AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair);
  77. // get map
  78. std::map<AnfWithOutIndex, AnfWithOutIndex> GetRefMap() const { return ref_out_in_map_; }
  79. // checkout whether loop exist in graph
  80. void CheckLoop();
  81. // check whether graph is executable
  82. bool executable() const { return executable_; }
  83. // set executable of graph
  84. void set_executable(bool executable) { executable_ = executable; }
  85. // set invalid inputs for control sink
  86. std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
  87. std::vector<bool> ValidInputs() { return valid_inputs_; }
  88. private:
  89. // remove value node form graph
  90. bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
  91. // update node edge list
  92. void UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes);
  93. // add node depend edge by data edge or control depend
  94. void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num);
  95. // handle control depend
  96. std::vector<AnfNodePtr> GetOutputNodes(const AnfNodePtr &node);
  97. bool HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
  98. std::unordered_set<AnfNodePtr> *visited_nodes);
  99. void UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends);
  100. std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
  101. std::vector<CNodePtr> execution_order_;
  102. uint32_t graph_id_;
  103. // record map bettween front anf and backend anf,use two map implement bidirectional map
  104. std::unordered_map<AnfNodePtr, AnfNodePtr> front_backend_anf_map_;
  105. std::unordered_map<AnfNodePtr, AnfNodePtr> backend_front_anf_map_;
  106. // there may be a tensor from ME backend ,a value ndoe will be create according the tensor,map record
  107. std::unordered_map<tensor::TensorPtr, ValueNodePtr> tensor_to_value_node_map_;
  108. // include all value nodes
  109. std::unordered_set<ValueNodePtr> graph_value_nodes_;
  110. std::unordered_map<AnfNodePtr, size_t> node_input_num_;
  111. std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_input_edges_;
  112. // record map between ref final output anf with index and ref origin input with index
  113. std::map<AnfWithOutIndex, AnfWithOutIndex> ref_out_in_map_;
  114. std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_;
  115. // graph needn't execute
  116. bool executable_;
  117. // valid inputs
  118. std::vector<bool> valid_inputs_;
  119. };
  120. } // namespace session
  121. using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
  122. } // namespace mindspore
  123. #endif // MINDSPORE_CCSRC_SESSION_KERNEL_GRAPH_H