You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

executor.h 6.2 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
  17. #define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
  18. #include <condition_variable>
  19. #include <list>
  20. #include <map>
  21. #include <memory>
  22. #include <mutex>
  23. #include <atomic>
  24. #include <queue>
  25. #include <string>
  26. #include <thread>
  27. #include <utility>
  28. #include <vector>
  29. #include "backend/common/session/session_basic.h"
  30. #include "ir/anf.h"
  31. #include "ir/tensor.h"
  32. #include "utils/any.h"
  33. #include "include/common/utils/comm_manager.h"
  34. #include "include/common/utils/contract.h"
  35. namespace mindspore {
  36. namespace session {
  37. enum TaskType {
  38. kUnKnown,
  39. kExit,
  40. kCompileNodes,
  41. kCompileGraph,
  42. kBuildGraph,
  43. kRunGraph,
  44. kRunOp,
  45. kCreateCommGroup,
  46. kDestroyCommGroup,
  47. kRunOpsInGraph
  48. };
  49. class Task {
  50. public:
  51. Task() = default;
  52. virtual ~Task() = default;
  53. SessionPtr session_{nullptr};
  54. TaskType type_{kUnKnown};
  55. bool sync_run_{false};
  56. virtual void Run() {}
  57. };
  58. class CompileNodesTask : public Task {
  59. public:
  60. CompileNodesTask() { type_ = kCompileNodes; }
  61. ~CompileNodesTask() override = default;
  62. void Run() override;
  63. GraphSegmentPtr segment_;
  64. AnfNodePtrList output_nodes_;
  65. GraphId graph_id_{0};
  66. };
  67. class CompileGraphTask : public Task {
  68. public:
  69. CompileGraphTask() { type_ = kCompileGraph; }
  70. ~CompileGraphTask() override = default;
  71. void Run() override;
  72. FuncGraphPtr func_graph_{nullptr};
  73. GraphId graph_id_{0};
  74. };
  75. class BuildGraphTask : public Task {
  76. public:
  77. BuildGraphTask() { type_ = kBuildGraph; }
  78. ~BuildGraphTask() override = default;
  79. void Run() override;
  80. GraphId graph_id_{0};
  81. };
  82. class RunGraphTask : public Task {
  83. public:
  84. RunGraphTask() { type_ = kRunGraph; }
  85. ~RunGraphTask() override = default;
  86. void Run() override;
  87. std::vector<tensor::TensorPtr> input_tensors_;
  88. std::vector<tensor::TensorPtr> input_need_wait_tensors_;
  89. std::vector<tensor::TensorPtr> input_need_lock_tensors_;
  90. VectorRef outputs_;
  91. GraphId graph_id_{0};
  92. std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node_;
  93. KernelMapTensor node_to_tensor_;
  94. };
  95. class RunOpsInGraphTask : public Task {
  96. public:
  97. RunOpsInGraphTask() { type_ = kRunOpsInGraph; }
  98. ~RunOpsInGraphTask() override = default;
  99. void Run() override;
  100. std::vector<tensor::TensorPtr> input_tensors_;
  101. VectorRef outputs_;
  102. GraphId graph_id_{0};
  103. };
  104. class RunOpTask : public Task {
  105. public:
  106. RunOpTask() { type_ = kRunOp; }
  107. ~RunOpTask() override {
  108. op_run_info_ = nullptr;
  109. input_tensors_ = nullptr;
  110. }
  111. void Run() override;
  112. OpRunInfo *op_run_info_{nullptr};
  113. GraphInfo graph_info_;
  114. std::vector<tensor::TensorPtr> *input_tensors_{nullptr};
  115. VectorRef outputs_;
  116. std::vector<int64_t> tensors_mask_;
  117. };
  118. class CreateCommGroupTask : public Task {
  119. public:
  120. CreateCommGroupTask() { type_ = kCreateCommGroup; }
  121. ~CreateCommGroupTask() override = default;
  122. void Run() override;
  123. std::string group_name_;
  124. std::vector<uint32_t> ranks_;
  125. bool result_{false};
  126. };
  127. class DestroyCommGroupTask : public Task {
  128. public:
  129. DestroyCommGroupTask() { type_ = kDestroyCommGroup; }
  130. ~DestroyCommGroupTask() override = default;
  131. void Run() override;
  132. std::string group_name_;
  133. bool result_{false};
  134. };
  135. class ExitTask : public Task {
  136. public:
  137. ExitTask() { type_ = kExit; }
  138. ~ExitTask() override = default;
  139. };
  140. enum class ExecutorEvent { kClear, kRunGraphFinished, kException };
  141. class Executor {
  142. public:
  143. Executor(const std::string &device_name, uint32_t device_id);
  144. ~Executor();
  145. void WorkerLoop();
  146. void WorkerJoin();
  147. GraphId CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, const AnfNodePtrList &outputs);
  148. GraphId CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph);
  149. void BuildGraph(const SessionPtr &session, GraphId graphId);
  150. void RunGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
  151. VectorRef *outputs);
  152. void RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
  153. VectorRef *outputs);
  154. void RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
  155. std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
  156. const std::vector<int64_t> &tensors_mask);
  157. void RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
  158. VectorRef *outputs);
  159. bool CreateCommGroup(const std::string &group_name, const std::vector<uint32_t> &ranks);
  160. bool DestroyCommGroup(const std::string &group_name);
  161. void OnEvent(const ExecutorEvent &event);
  162. void ClearDoneTasks();
  163. private:
  164. void RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run = false);
  165. std::vector<std::shared_ptr<RunGraphTask>> GetReadyTasksFromPendingList();
  166. void OnWorkerExit();
  167. void OnClear();
  168. void OnRunGraphFinished();
  169. void OnException();
  170. uint32_t device_id_;
  171. std::string device_name_;
  172. std::mutex task_mutex_;
  173. std::mutex done_task_mutex_;
  174. std::mutex pending_task_mutex_;
  175. std::mutex reenter_mutex_;
  176. std::condition_variable task_cond_var_;
  177. std::condition_variable sync_cond_var_;
  178. std::condition_variable reenter_cond_var_;
  179. std::queue<std::shared_ptr<Task>> ready_tasks_;
  180. std::list<std::shared_ptr<RunGraphTask>> pending_tasks_;
  181. std::vector<std::shared_ptr<Task>> done_tasks_;
  182. std::shared_ptr<std::thread> worker_;
  183. bool sync_run_task_finished_{false};
  184. };
  185. } // namespace session
  186. } // namespace mindspore
  187. #endif // MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H