You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

executor.h 6.1 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
  17. #define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
  18. #include <condition_variable>
  19. #include <list>
  20. #include <map>
  21. #include <memory>
  22. #include <mutex>
  23. #include <atomic>
  24. #include <queue>
  25. #include <string>
  26. #include <thread>
  27. #include <utility>
  28. #include <vector>
  29. #include "backend/session/session_basic.h"
  30. #include "ir/anf.h"
  31. #include "ir/tensor.h"
  32. #include "utils/any.h"
  33. #include "utils/comm_manager.h"
  34. #include "utils/contract.h"
  35. namespace mindspore {
  36. namespace session {
  // Tag identifying the concrete kind of a Task. Every Task subclass in this
  // header stores its matching enumerator into Task::type_ from its
  // constructor. The numeric values are spelled out explicitly and are the
  // same ones the compiler would assign implicitly.
  enum TaskType {
    kUnKnown = 0,           // Default value of Task::type_.
    kExit = 1,              // Set by ExitTask.
    kCompileNodes = 2,      // Set by CompileNodesTask.
    kCompileGraph = 3,      // Set by CompileGraphTask.
    kBuildGraph = 4,        // Set by BuildGraphTask.
    kRunGraph = 5,          // Set by RunGraphTask.
    kRunOp = 6,             // Set by RunOpTask.
    kCreateCommGroup = 7,   // Set by CreateCommGroupTask.
    kDestroyCommGroup = 8,  // Set by DestroyCommGroupTask.
    kRunOpsInGraph = 9      // Set by RunOpsInGraphTask.
  };
  49. class Task {
  50. public:
  51. Task() = default;
  52. virtual ~Task() = default;
  53. SessionPtr session_{nullptr};
  54. TaskType type_{kUnKnown};
  55. bool sync_run_{false};
  56. virtual void Run() {}
  57. };
  58. class CompileNodesTask : public Task {
  59. public:
  60. CompileNodesTask() { type_ = kCompileNodes; }
  61. ~CompileNodesTask() override = default;
  62. void Run() override;
  63. GraphSegmentPtr segment_;
  64. AnfNodePtrList output_nodes_;
  65. GraphId graph_id_{0};
  66. };
  67. class CompileGraphTask : public Task {
  68. public:
  69. CompileGraphTask() { type_ = kCompileGraph; }
  70. ~CompileGraphTask() override = default;
  71. void Run() override;
  72. FuncGraphPtr func_graph_{nullptr};
  73. GraphId graph_id_{0};
  74. };
  75. class BuildGraphTask : public Task {
  76. public:
  77. BuildGraphTask() { type_ = kBuildGraph; }
  78. ~BuildGraphTask() override = default;
  79. void Run() override;
  80. GraphId graph_id_{0};
  81. };
  82. class RunGraphTask : public Task {
  83. public:
  84. RunGraphTask() { type_ = kRunGraph; }
  85. ~RunGraphTask() override = default;
  86. void Run() override;
  87. std::vector<tensor::TensorPtr> input_tensors_;
  88. std::vector<tensor::TensorPtr> input_need_wait_tensors_;
  89. std::vector<tensor::TensorPtr> input_need_lock_tensors_;
  90. VectorRef outputs_;
  91. GraphId graph_id_{0};
  92. std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node_;
  93. };
  94. class RunOpsInGraphTask : public Task {
  95. public:
  96. RunOpsInGraphTask() { type_ = kRunOpsInGraph; }
  97. ~RunOpsInGraphTask() override = default;
  98. void Run() override;
  99. std::vector<tensor::TensorPtr> input_tensors_;
  100. VectorRef outputs_;
  101. GraphId graph_id_{0};
  102. };
  103. class RunOpTask : public Task {
  104. public:
  105. RunOpTask() { type_ = kRunOp; }
  106. ~RunOpTask() override = default;
  107. void Run() override;
  108. OpRunInfo *op_run_info_{nullptr};
  109. GraphInfo graph_info_;
  110. std::vector<tensor::TensorPtr> *input_tensors_{nullptr};
  111. VectorRef outputs_;
  112. std::vector<int64_t> tensors_mask_;
  113. };
  114. class CreateCommGroupTask : public Task {
  115. public:
  116. CreateCommGroupTask() { type_ = kCreateCommGroup; }
  117. ~CreateCommGroupTask() override = default;
  118. void Run() override;
  119. std::string group_name_;
  120. std::vector<uint32_t> ranks_;
  121. bool result_{false};
  122. };
  123. class DestroyCommGroupTask : public Task {
  124. public:
  125. DestroyCommGroupTask() { type_ = kDestroyCommGroup; }
  126. ~DestroyCommGroupTask() override = default;
  127. void Run() override;
  128. std::string group_name_;
  129. bool result_{false};
  130. };
  131. class ExitTask : public Task {
  132. public:
  133. ExitTask() { type_ = kExit; }
  134. ~ExitTask() override = default;
  135. };
  // Events accepted by Executor::OnEvent(). The numeric values are spelled out
  // explicitly and match the implicit ones.
  enum class ExecutorEvent {
    kClear = 0,
    kRunGraphFinished = 1,
    kException = 2
  };
  137. class Executor {
  138. public:
  139. Executor(const std::string &device_name, uint32_t device_id);
  140. ~Executor();
  141. void WorkerLoop();
  142. void WorkerJoin();
  143. GraphId CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, const AnfNodePtrList &outputs);
  144. GraphId CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph);
  145. void BuildGraph(const SessionPtr &session, GraphId graphId);
  146. void RunGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
  147. VectorRef *outputs);
  148. void RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
  149. VectorRef *outputs);
  150. void RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
  151. std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
  152. const std::vector<int64_t> &tensors_mask);
  153. void RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
  154. VectorRef *outputs);
  155. bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks);
  156. bool DestroyCommGroup(const std::string &group_name);
  157. void OnEvent(const ExecutorEvent &event);
  158. private:
  159. void RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run = false);
  160. std::vector<std::shared_ptr<RunGraphTask>> GetReadyTasksFromPendingList();
  161. void OnWorkerExit();
  162. void OnClear();
  163. void OnRunGraphFinished();
  164. void OnException();
  165. void ClearDoneTasks();
  166. uint32_t device_id_;
  167. std::string device_name_;
  168. std::mutex task_mutex_;
  169. std::mutex done_task_mutex_;
  170. std::mutex pending_task_mutex_;
  171. std::mutex reenter_mutex_;
  172. std::condition_variable task_cond_var_;
  173. std::condition_variable sync_cond_var_;
  174. std::condition_variable reenter_cond_var_;
  175. std::queue<std::shared_ptr<Task>> ready_tasks_;
  176. std::list<std::shared_ptr<RunGraphTask>> pending_tasks_;
  177. std::vector<std::shared_ptr<Task>> done_tasks_;
  178. std::shared_ptr<std::thread> worker_;
  179. bool sync_run_task_finished_{false};
  180. };
  181. } // namespace session
  182. } // namespace mindspore
  183. #endif // MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H