You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ascend_session.h 9.5 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H
  17. #define MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H
  18. #include <unordered_map>
  19. #include <string>
  20. #include <memory>
  21. #include <vector>
  22. #include <utility>
  23. #include <stack>
  24. #include <map>
  25. #include <tuple>
  26. #include <set>
  27. #include "session/session_basic.h"
  28. #include "session/kernel_graph.h"
  29. #include "kernel/kernel.h"
  30. #include "session/session_factory.h"
  31. #include "session/ascend_control_parser.h"
  32. namespace mindspore {
  33. namespace session {
  // Type tag for a child graph of a final graph; AscendSession keeps one tag per child graph in
  // graph_order_types_. The underlying type is int and every value is fixed explicitly, so the
  // numeric values must not be reordered.
  enum GraphType : int {
    COMMON_GRAPH = 0,
    CONDITION_GRAPH = 1,
    BRANCH_START = 2,
    BRANCH_END = 3,
  };
  35. class AscendSession : public SessionBasic {
  36. public:
  37. AscendSession() { final_graph_id_ = kInvalidGraphId; }
  38. ~AscendSession() override = default;
  39. void Init(uint32_t device_id) override {
  40. SessionBasic::Init(device_id);
  41. context_ = std::make_shared<Context>(kAscendDevice, device_id);
  42. }
  43. GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
  44. GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override;
  45. void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
  46. void BuildGraph(GraphId) override;
  47. void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
  48. const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) override;
  49. py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
  50. const std::vector<tensor::TensorPtr> &input_tensors) override;
  51. // set parameters of final graph
  52. GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &args) override;
  53. // set output of final graph
  54. void SetFinalGraphOutput(const BaseRef &output) override;
  55. // insert switch and set the relative active ops
  56. void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g, const AnfNodePtr &condition_output) override;
  57. // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter
  58. void SetChildGraphInput(GraphId g, const VectorRef &args) override;
  59. // get graph id in child graphs by ME front anf node pointer
  60. GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override;
  61. // get graph id of final graph
  62. GraphId GetFinalRunGraph() const override { return final_graph_id_; }
  63. // insert active to graph
  64. void SetActive(GraphId, GraphId) override;
  65. // compile child graph when session have multiple child graphs
  66. void CompileChildGraph(const KernelGraphPtr &child_graph);
  67. void RecurseGetSummaryNodes(KernelGraph *graph, std::map<std::string, std::pair<AnfNodePtr, int>> *summary);
  68. void GetSummaryNodes(KernelGraph *graph);
  69. private:
  70. void InitRuntimeResource();
  71. void SelectKernel(const KernelGraph &kernel_graph) const;
  72. void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  73. void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  74. void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  75. void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const;
  76. void AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const;
  77. void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  78. void MemoryAlloc(KernelGraph *kernel_graph) const;
  79. void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
  80. void RunOpMemoryClear(const KernelGraph *kernel_graph) const;
  81. void GenerateTaskInfo(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  82. void LoadTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  83. void ExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  84. void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  85. void ExportChildGraphs(const GraphId graph_id);
  86. void LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  87. // below functions are used for run op
  88. void RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const;
  89. void RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  90. size_t SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index);
  91. size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index);
  92. size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index);
  93. void SetFinalGraphOutput(const AnfNodePtr &node);
  94. void SetFinalGraphOutput(const ValuePtr &value);
  95. void SetFinalGraphOutput(const VectorRef &vec_output);
  96. void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims);
  97. // split graphs with recurse from root graph
  98. void SplitGraphs(NotNull<KernelGraphPtr> root_graph);
  99. void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
  100. void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
  101. void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph);
  102. std::vector<AnfNodePtr> ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
  103. const std::vector<CNodePtr> &list);
  104. void RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
  105. void RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
  106. AnfNodePtr BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph, const std::vector<CNodePtr> &child_graph_list);
  107. // merge execution order list of child graphs
  108. void MergeGraphExecOrder();
  109. // insert assion op to sync data bettween different graphs
  110. void InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to);
  111. // insert mutiple assigns to graph
  112. void InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to);
  113. // insert active op to graph
  114. void InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream);
  115. // get execute index of graph
  116. size_t ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph);
  117. // handle condition graph from vm
  118. void InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id);
  119. // insert depend to graph, used to attch control nodes to graph
  120. void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node);
  121. // insert depend to graph, used to attch control nodes to graph
  122. void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node);
  123. // Get graph by graph id ,if not exist return null ptr
  124. KernelGraphPtr GetGraph(GraphId graph_id);
  125. // set child graph parameter if front arg is a anf
  126. void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx);
  127. // set child graph parameter if front arg is a tensor
  128. void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, size_t input_idx);
  129. // update the execution order of all child graphs
  130. void UpdateGraphOrder(GraphId to_graph);
  131. // handle switch when merge
  132. void MergeSwitchCompile();
  133. // get graph order vector by graph id
  134. std::vector<GraphId> &GetGraphOrder(GraphId final_graph_id);
  135. // get graph order type vector by graph id
  136. std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id);
  137. // copy output of if and else
  138. void CopyOutputOfIf(GraphId false_graph_id);
  139. // check if graph cache exist
  140. bool GraphCacheExist(const GraphInfo &graph_info) const;
  141. // insert all assign to child graph
  142. void InsertAllAssigns();
  143. // create fake output of final graph
  144. AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output);
  145. // sync intial tensors' data to device
  146. void SyncInitialTenosrToDevice();
  147. void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);
  148. // member variables
  149. // key is final_graph_id,value is child graph execute order of final graph
  150. std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_;
  151. // key is final_graph_id,value is the graph types of child graphs
  152. std::unordered_map<GraphId, std::vector<GraphType>> graph_order_types_;
  153. // record condition graph of while
  154. std::unordered_map<GraphId, GraphId> while_condition_graphs_;
  155. // record all conditions
  156. std::unordered_map<GraphId, std::pair<GraphId, GraphId>> switches_;
  157. std::unordered_map<GraphId, AnfNodePtr> condition_output_;
  158. // share parameters
  159. std::vector<std::tuple<AnfNodePtr, GraphId, size_t>> assigns_;
  160. // initial tensors, these tensor will sync data to device before run graph
  161. std::map<std::pair<GraphId, size_t>, tensor::TensorPtr> initial_tenosrs_;
  162. // final_graph_id is used in every root graph has it's own session situation
  163. GraphId final_graph_id_;
  164. };
  165. MS_REG_SESSION(kAscendDevice, AscendSession);
  166. } // namespace session
  167. } // namespace mindspore
  168. #endif // MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H