You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_manager.h 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_GRAPH_MANAGER_GRAPH_MANAGER_H_
  17. #define GE_GRAPH_MANAGER_GRAPH_MANAGER_H_
#include <atomic>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <vector>
#include "common/blocking_queue.h"
#include "common/ge_inner_error_codes.h"
#include "common/helper/model_cache_helper.h"
#include "external/graph/types.h"
#include "ge/ge_api_types.h"
#include "graph/build/graph_builder.h"
#include "graph/execute/graph_execute.h"
#include "graph/ge_local_context.h"
#include "graph/load/graph_loader.h"
#include "graph/manager/graph_manager_utils.h"
#include "graph/manager/util/variable_accelerate_ctrl.h"
#include "graph/optimize/graph_optimize.h"
#include "graph/partition/graph_partition.h"
#include "graph/preprocess/graph_preprocess.h"
#include "model/ge_model.h"
  40. namespace ge {
  41. class GraphManager {
  42. public:
  43. GraphManager();
  44. ~GraphManager() = default;
  45. ///
  46. /// @ingroup ge_graph
  47. /// @brief graph manager init
  48. /// @param [in] options user config params
  49. /// @return Status result of function
  50. ///
  51. Status Initialize(const std::map<string, string> &options);
  52. ///
  53. /// @ingroup ge_graph
  54. /// @brief graph manager finalize
  55. /// @return Status result of function
  56. ///
  57. Status Finalize();
  58. ///
  59. /// @ingroup ge_graph
  60. /// @brief add specific graph
  61. /// @param [in] graph_id graph id
  62. /// @param [out] Graph output graph
  63. /// @return Status result of function
  64. ///
  65. Status AddGraph(const GraphId &graph_id, const Graph &graph, const std::map<std::string, std::string> &options);
  66. ///
  67. /// @ingroup ge_graph
  68. /// @brief remove specific graph
  69. /// @param [in] graph_id graph id
  70. /// @return Status result of function
  71. ///
  72. Status RemoveGraph(const GraphId &graph_id);
  73. ///
  74. /// @ingroup ge_graph
  75. /// @brief run specific graph
  76. /// @param [in] graph_id graph id
  77. /// @param [in] inputs input data
  78. /// @param [out] outputs output data
  79. /// @return Status result of function
  80. ///
  81. Status RunGraph(const GraphId &graph_id, const std::vector<GeTensor> &inputs, std::vector<GeTensor> &outputs,
  82. uint64_t session_id = INVALID_SESSION_ID);
  83. ///
  84. /// @ingroup ge_graph
  85. /// @brief build specific graph
  86. /// @param [in] graph_id graph id
  87. /// @param [in] inputs input data
  88. /// @param [out] models build result
  89. /// @return Status result of function
  90. ///
  91. ge::Status BuildGraph(const GraphId &graph_id, const std::vector<GeTensor> &inputs, GeRootModelPtr &models);
  92. ///
  93. /// @ingroup ge_graph
  94. /// @brief Save extra attribute to Model
  95. /// @param [in] model: Model attribues will save to.
  96. /// @param [in] type: type of OpDesc.
  97. /// @param [in] attrs: attributes of OpDesc
  98. /// @param [in] inputs: input tensor
  99. /// @param [in] outputs: output tensor
  100. /// @return: Status
  101. ///
  102. Status SaveParams(ge::GeModel &model, const std::string &type, const std::map<string, GeAttrValue> &attrs,
  103. const std::vector<GeTensor> &inputs, const std::vector<GeTensor> &outputs);
  104. ///
  105. /// @ingroup ge_graph
  106. /// @brief get variable value from the session with specific session id
  107. /// @param [in] sessionId session id
  108. /// @param [in] name op name
  109. /// @param [out] val out value tensor
  110. /// @return Status result of function
  111. ///
  112. Status GetVariable(const std::string &name, Tensor &val);
  113. ///
  114. /// @ingroup ge_graph
  115. /// @brief run graph async on session with specific session id
  116. /// @param [in] graph_id graph id
  117. /// @param [in] inputs input data
  118. /// @param [out] callback: callback while run graph async finish
  119. /// @return Status result of function
  120. ///
  121. Status RunGraphAsync(const GraphId &graph_id, const std::vector<ge::InputTensorInfo> &inputs, uint64_t session_id,
  122. RunAsyncCallback callback);
  123. ///
  124. /// @ingroup ge_graph
  125. /// @brief me register the callback function to get the result of summary or checkpoin
  126. /// @param [in] key: summary or checkpoint
  127. /// @param [in] callbak: The real callback object of me
  128. /// @return Status result of function
  129. ///
  130. Status RegisterCallBackFunc(
  131. const std::string &key, const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback);
  132. const bool GetTrainFlag() const { return options_.train_graph_flag; }
  133. bool IsGraphNeedRebuild(uint32_t graph_id);
  134. Status GenerateInfershapeGraph(GraphId &graph_id);
  135. const std::map<std::string, std::string> *GetGraphOptions(uint32_t graph_id);
  136. void SetOptionsRunGraphFlag(bool run_graph_flag);
  137. private:
  138. struct PreRunArgs {
  139. GraphId graph_id;
  140. std::vector<ge::InputTensorInfo> input_tensor;
  141. uint64_t session_id;
  142. GEThreadLocalContext context;
  143. RunAsyncCallback callback;
  144. };
  145. struct RunArgs {
  146. GraphNodePtr graph_node;
  147. GraphId graph_id;
  148. std::vector<ge::InputTensorInfo> input_tensor;
  149. GeRootModelPtr ge_root_model;
  150. GEThreadLocalContext context;
  151. RunAsyncCallback callback;
  152. };
  153. Status GetGraphNode(const GraphId &graph_id, GraphNodePtr &out);
  154. std::shared_ptr<GraphModelListener> GetModelListener() const { return graph_run_listener_; }
  155. static Status ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, const SubGraphInfoPtr &sub_graph_info_ptr,
  156. uint64_t session_id, const GEThreadLocalContext &ge_context);
  157. Status PreRun(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, GeRootModelPtr &ge_root_model,
  158. uint64_t session_id = INVALID_SESSION_ID);
  159. Status OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, uint64_t session_id);
  160. Status Build(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, GeRootModelPtr &ge_root_model,
  161. uint64_t session_id);
  162. Status StartForRunGraph(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs,
  163. GeRootModelPtr &ge_root_model, uint64_t session_id = INVALID_SESSION_ID);
  164. Status InnerRunGraph(GraphNodePtr &graph_node, const GraphId &graph_id, const std::vector<GeTensor> &inputs,
  165. std::vector<GeTensor> &outputs);
  166. Status ParseOptions(const std::map<std::string, std::string> &options);
  167. static void ParseOption(const std::map<std::string, std::string> &options, const std::string &key,
  168. std::string &option);
  169. static Status ParseOption(const std::map<std::string, std::string> &options, const std::string &key, bool &option);
  170. static Status ParseOption(const std::map<std::string, std::string> &options, const std::string &key, int &option);
  171. static Status ParseOption(const std::map<std::string, std::string> &options, const std::string &key,
  172. std::map<std::string, int> &option);
  173. static void Trim(std::string &str);
  174. static Status CheckEngineName(const std::string &engine_name, const std::string &key,
  175. const std::map<std::string, int> &option);
  176. static Status ParseParallelNum(const std::string &parallel_num, const std::string &key, int &num);
  177. static Status ParseTrainGraphFlag(bool &options, bool &option);
  178. static bool IsPerfLevelInvalid(int32_t perf_level);
  179. Status SummaryHandle(const GraphId &graph_id, std::vector<GeTensor> &outputs);
  180. Status CheckpointHandle(const GraphId &graph_id, const ComputeGraphPtr &compute_graph,
  181. const std::vector<GeTensor> &outputs);
  182. // call the callback function of ME to push summary result data to ME
  183. Status PushSummaryData2ME(const GraphId &graph_id, const std::map<std::string, ge::Tensor> &summary_data);
  184. // call the callback function of ME to push save result data to ME
  185. Status PushSaveData2ME(const GraphId &graph_id, const std::map<std::string, ge::Tensor> &save_data);
  186. bool IsCheckpointGraph(ComputeGraphPtr &compute_graph);
  187. bool CheckNetOutputForCheckpointGraph(NodePtr &node);
  188. bool CheckVariableForCheckpointGraph(NodePtr &node);
  189. bool CheckTransOpForCheckpointGraph(NodePtr &node);
  190. Status MergeSubGraph(ComputeGraphPtr &compute_graph, const ge::ComputeGraphPtr &original_compute_graph);
  191. Status SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_graph);
  192. void SetAttrForHcomBroadCastOp(ge::ComputeGraphPtr &compute_graph);
  193. bool IsBroadCastOpData(const ge::NodePtr &var_node);
  194. void AdjustBroadCastOpData(const ge::NodePtr &var_node);
  195. bool IsAssignOpData(const ge::NodePtr &var_node);
  196. void AdjustAssignOpData(const ge::NodePtr &var_node);
  197. bool ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, const map<string, std::set<int>> &confirm_ops,
  198. ge::NodePtr &use_node);
  199. bool ConfirmUseOpAndIndexByNode(const ge::NodePtr &var_node, const map<string, std::set<int>> &confirm_ops,
  200. ge::NodePtr &use_node);
  201. // graph context
  202. std::shared_ptr<GraphContext> GetGraphContext() const { return graph_context_; }
  203. Status RemoveIsolatedConst(ge::ComputeGraphPtr &compute_graph);
  204. Status RemoveIsolatedConstInThisGraph(ge::ComputeGraphPtr &compute_graph);
  205. Status OptimizeStage1(ComputeGraphPtr &compute_graph);
  206. Status OptimizeStage2(ComputeGraphPtr &compute_graph);
  207. Status OptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_graph);
  208. Status NewOptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_graph);
  209. Status LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node);
  210. Status CheckAndReleaseMemory(const GeModelPtr &ge_model, const GraphNodePtr &graph_node);
  211. bool CheckModelLoad(const GeRootModelPtr &ge_model, bool load_flag);
  212. Status LoadGraph(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node);
  213. bool IsGraphNeedBuild(const GraphNodePtr &graph_node);
  214. Status LoadFromCache(const GraphNodePtr &graph_node, const ModelCacheHelperPtr &cache_helper, GeModelPtr &ge_model);
  215. Status SaveCacheBeforeBuild(uint32_t graph_id, const ModelCacheHelperPtr &cache_helper);
  216. Status SaveCacheAfterBuild(uint32_t graph_id, ComputeGraphPtr graph, GeModelPtr &ge_model);
  217. void AddModelCacheHelperToMap(const GraphId &graph_id, uint64_t session_id, ComputeGraphPtr &compute_graph);
  218. Status IncreBuild(const GraphNodePtr &graph_node, GeModelPtr &ge_model);
  219. void RemoveModelCacheHelper(const GraphId &graph_id);
  220. static void PreRunThread(GraphManager *graph_manager);
  221. static void RunThread(GraphManager *graph_manager);
  222. static void StopQueue(GraphManager *graph_manager);
  223. static void ReturnError(GraphManager *graph_manager, RunAsyncCallback callback, Status ret, const string &log);
  224. void ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph);
  225. std::atomic_bool thread_run_flag_;
  226. BlockingQueue<PreRunArgs> prerun_args_q_{};
  227. BlockingQueue<RunArgs> run_args_q_{};
  228. std::thread prerun_thread_;
  229. std::thread run_thread_;
  230. std::map<GraphId, GraphNodePtr> graph_map_;
  231. std::map<GraphId, ModelCacheHelperPtr> cache_helper_map_;
  232. // for run graph synchronous return
  233. std::mutex sync_run_mutex_;
  234. std::condition_variable condition_;
  235. // run graph synchronization call back listener
  236. std::shared_ptr<GraphModelListener> graph_run_listener_;
  237. // summary and checkpoint callback function list for ME, key is summary or checkpoint
  238. std::map<std::string, std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)>> me_callback_map_;
  239. bool init_flag_;
  240. GraphManagerOptions options_;
  241. GraphPrepare graph_preparer_;
  242. GraphOptimize graph_optimize_;
  243. GraphPartitioner graph_partitioner_;
  244. GraphBuilder graph_builder_;
  245. GraphLoader graph_loader_;
  246. GraphExecutor graph_executor_;
  247. GraphContextPtr graph_context_ = nullptr;
  248. VarAccelerateCtrl var_acc_ctrl_;
  249. std::mutex run_mutex_;
  250. };
  251. }; // namespace ge
  252. #endif // GE_GRAPH_MANAGER_GRAPH_MANAGER_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：