You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

ascend_session.h 7.5 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H
  17. #define MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H
  18. #include <unordered_map>
  19. #include <string>
  20. #include <memory>
  21. #include <vector>
  22. #include <utility>
  23. #include <stack>
  24. #include <map>
  25. #include <tuple>
  26. #include <set>
  27. #include "backend/session/session_basic.h"
  28. #include "backend/session/kernel_graph.h"
  29. #include "backend/kernel_compiler/kernel.h"
  30. #include "backend/session/session_factory.h"
  31. #include "backend/session/ascend_control_parser.h"
  32. #include "runtime/context.h"
  33. namespace mindspore {
  34. namespace session {
  // Classification of a child graph inside a final graph; values are stored per child graph in
  // AscendSession::graph_order_types_. Underlying type and numeric values are part of the interface.
  enum GraphType : int {
    COMMON_GRAPH = 0,
    CONDITION_GRAPH = 1,
    BRANCH_START = 2,
    BRANCH_END = 3,
  };
  36. class AscendSession : public SessionBasic {
  37. public:
  38. AscendSession() { final_graph_id_ = kInvalidGraphId; }
  39. ~AscendSession() {
  40. if (rt_context_ != nullptr) {
  41. auto ret = rtCtxDestroy(rt_context_);
  42. if (ret != RT_ERROR_NONE) {
  43. MS_EXCEPTION(DeviceProcessError) << "Call rtCtxDestroy, ret[" << ret << "]";
  44. }
  45. rt_context_ = nullptr;
  46. }
  47. }
  48. void Init(uint32_t device_id) override {
  49. InitDevice(kAscendDevice, device_id);
  50. auto ret = rtCtxCreate(&rt_context_, 0, device_id);
  51. if (ret != RT_ERROR_NONE) {
  52. MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
  53. }
  54. ret = rtCtxSetCurrent(rt_context_);
  55. if (ret != RT_ERROR_NONE) {
  56. MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]";
  57. }
  58. }
  59. // get graph id in child graphs by ME front anf node pointer
  60. GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override;
  61. // get graph id of final graph
  62. GraphId GetFinalRunGraph() const override { return final_graph_id_; }
  63. protected:
  64. GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
  65. GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override;
  66. GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) override;
  67. void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
  68. void BuildGraphImpl(GraphId) override;
  69. void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
  70. const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) override;
  71. void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
  72. const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) override;
  73. private:
  74. // compile child graph when session have multiple child graphs
  75. void CompileChildGraph(const KernelGraphPtr &child_graph);
  76. void RecurseSetSummaryNodes(KernelGraph *graph, std::map<std::string, std::pair<AnfNodePtr, int>> *summary);
  77. void SetSummaryNodes(KernelGraph *graph) override;
  78. void InitRuntimeResource();
  79. void SelectKernel(const KernelGraph &kernel_graph) const;
  80. void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  81. void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  82. void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  83. void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const;
  84. void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  85. void MemoryAlloc(KernelGraph *kernel_graph) const;
  86. void RunOpMemoryAlloc(const ValuePtr &pre_output_value, const std::vector<tensor::TensorPtr> &input_tensors,
  87. KernelGraph *kernel_graph) const;
  88. void RunOpMemoryClear(const KernelGraph *kernel_graph) const;
  89. void Load(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  90. void Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const;
  91. void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  92. void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs);
  93. void LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  94. // below functions are used for run op
  95. void RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const;
  96. static void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
  97. static void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
  98. void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph);
  99. // merge execution order list of child graphs
  100. void MergeGraphExecOrder();
  101. // insert assion op to sync data bettween different graphs
  102. void InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to);
  103. // get graph order vector by graph id
  104. const std::vector<GraphId> &GetGraphOrder(GraphId final_graph_id) const;
  105. // get graph order type vector by graph id
  106. const std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id) const;
  107. // check if graph cache exist
  108. bool GraphCacheExist(const GraphInfo &graph_info) const;
  109. // insert all assign to child graph
  110. void InsertAllAssigns();
  111. // sync intial tensors' data to device
  112. void SyncInitialTenosrToDevice();
  113. void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);
  114. // create parameter to receive data from multiple branch output
  115. void CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
  116. void SelectKernel(NotNull<KernelGraphPtr> root_graph);
  117. void RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> const memo,
  118. size_t *const raise_precision_count, size_t *const reduce_precision_count) const;
  119. void IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
  120. void HardwareOptimize(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
  121. void AssignStaticMemory(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
  122. void UpdateRefOutputMap(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
  123. // key is final_graph_id,value is child graph execute order of final graph
  124. std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_;
  125. // key is final_graph_id,value is the graph types of child graphs
  126. std::unordered_map<GraphId, std::vector<GraphType>> graph_order_types_;
  127. // share parameters
  128. std::vector<std::tuple<AnfNodePtr, GraphId, size_t>> assigns_;
  129. // initial tensors, these tensor will sync data to device before run graph
  130. std::map<std::pair<GraphId, size_t>, tensor::TensorPtr> initial_tenosrs_;
  131. // final_graph_id is used in every root graph has it's own session situation
  132. GraphId final_graph_id_;
  133. // ascend runtime context
  134. rtContext_t rt_context_{nullptr};
  135. };
  136. MS_REG_SESSION(kAscendDevice, AscendSession);
  137. } // namespace session
  138. } // namespace mindspore
  139. #endif // MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H