You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

executor.h 5.9 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
  17. #define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
  18. #include <vector>
  19. #include <string>
  20. #include <utility>
  21. #include <memory>
  22. #include <list>
  23. #include <queue>
  24. #include <map>
  25. #include <thread>
  26. #include <mutex>
  27. #include <condition_variable>
  28. #include "backend/session/session_basic.h"
  29. #include "ir/anf.h"
  30. #include "ir/tensor.h"
  31. #include "utils/any.h"
  32. #include "utils/contract.h"
  33. #include "utils/comm_manager.h"
  34. namespace mindspore {
  35. namespace session {
  // Tag stored in Task::type_ to tell the task kinds in this header apart.
  // The numeric values are consecutive from 0 in declaration order; they are
  // written out explicitly so the mapping is visible at a glance.
  enum TaskType {
    kUnKnown = 0,          // default value of Task::type_
    kExit = 1,             // set by ExitTask
    kCompileNodes = 2,     // set by CompileNodesTask
    kCompileGraph = 3,     // set by CompileGraphTask
    kBuildGraph = 4,       // set by BuildGraphTask
    kBuildOp = 5,          // set by BuildOpTask
    kRunGraph = 6,         // set by RunGraphTask
    kRunOp = 7,            // set by RunOpTask
    kCreateCommGroup = 8,  // set by CreateCommGroupTask
    kDestroyCommGroup = 9  // set by DestroyCommGroupTask
  };
  48. class Task {
  49. public:
  50. Task() = default;
  51. virtual ~Task() = default;
  52. SessionPtr session_{nullptr};
  53. TaskType type_{kUnKnown};
  54. bool sync_run_{false};
  55. virtual void Run() {}
  56. };
  57. class CompileNodesTask : public Task {
  58. public:
  59. CompileNodesTask() { type_ = kCompileNodes; }
  60. ~CompileNodesTask() override = default;
  61. void Run() override;
  62. GraphSegmentPtr segment_;
  63. AnfNodePtrList output_nodes_;
  64. GraphId graph_id_{0};
  65. };
  66. class CompileGraphTask : public Task {
  67. public:
  68. CompileGraphTask() { type_ = kCompileGraph; }
  69. ~CompileGraphTask() override = default;
  70. void Run() override;
  71. FuncGraphPtr func_graph_{nullptr};
  72. GraphId graph_id_{0};
  73. };
  74. class BuildGraphTask : public Task {
  75. public:
  76. BuildGraphTask() { type_ = kBuildGraph; }
  77. ~BuildGraphTask() override = default;
  78. void Run() override;
  79. GraphId graph_id_{0};
  80. };
  81. class RunGraphTask : public Task {
  82. public:
  83. RunGraphTask() { type_ = kRunGraph; }
  84. ~RunGraphTask() override = default;
  85. void Run() override;
  86. std::vector<tensor::TensorPtr> input_tensors_;
  87. std::vector<tensor::TensorPtr> input_need_wait_tensors_;
  88. std::vector<tensor::TensorPtr> input_need_lock_tensors_;
  89. VectorRef outputs_;
  90. GraphId graph_id_{0};
  91. std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node_;
  92. };
  93. class BuildOpTask : public Task {
  94. public:
  95. BuildOpTask() { type_ = kBuildOp; }
  96. ~BuildOpTask() override = default;
  97. void Run() override;
  98. OpRunInfo *op_run_info_{nullptr};
  99. GraphInfo graph_info_;
  100. std::vector<tensor::TensorPtr> input_tensors_;
  101. std::vector<int64_t> tensors_mask_;
  102. };
  103. class RunOpTask : public Task {
  104. public:
  105. RunOpTask() { type_ = kRunOp; }
  106. ~RunOpTask() override = default;
  107. void Run() override;
  108. OpRunInfo *op_run_info_{nullptr};
  109. GraphInfo graph_info_;
  110. std::vector<tensor::TensorPtr> input_tensors_;
  111. VectorRef outputs_;
  112. };
  113. class CreateCommGroupTask : public Task {
  114. public:
  115. CreateCommGroupTask() { type_ = kCreateCommGroup; }
  116. ~CreateCommGroupTask() override = default;
  117. void Run() override;
  118. std::string group_name_;
  119. std::vector<uint32_t> ranks_;
  120. bool result_{false};
  121. };
  122. class DestroyCommGroupTask : public Task {
  123. public:
  124. DestroyCommGroupTask() { type_ = kDestroyCommGroup; }
  125. ~DestroyCommGroupTask() override = default;
  126. void Run() override;
  127. std::string group_name_;
  128. bool result_{false};
  129. };
  130. class ExitTask : public Task {
  131. public:
  132. ExitTask() { type_ = kExit; }
  133. ~ExitTask() override = default;
  134. };
  135. class Executor {
  136. public:
  137. Executor(const std::string &device_name, uint32_t device_id);
  138. ~Executor();
  139. void WorkerLoop();
  140. void WorkerJoin();
  141. GraphId CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, const AnfNodePtrList &outputs);
  142. GraphId CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph);
  143. void BuildGraph(const SessionPtr &session, GraphId graphId);
  144. void RunGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
  145. VectorRef *outputs);
  146. void RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
  147. VectorRef *outputs);
  148. void BuildOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
  149. const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int64_t> &tensors_mask);
  150. void RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
  151. const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs);
  152. void OnRunGraphFinished();
  153. bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks);
  154. bool DestroyCommGroup(const std::string &group_name);
  155. private:
  156. void SyncRunTask(const std::shared_ptr<Task> &task);
  157. void UpdateOutputTensors(VectorRef *outputs,
  158. const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node);
  159. std::vector<std::shared_ptr<RunGraphTask>> GetNewReadyTasks();
  160. bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task);
  161. void CheckException();
  162. void OnWorkerExit();
  163. uint32_t device_id_;
  164. std::string device_name_;
  165. std::mutex task_mutex_;
  166. std::mutex pending_task_mutex_;
  167. std::condition_variable task_cond_var_;
  168. std::condition_variable sync_cond_var_;
  169. std::queue<std::shared_ptr<Task>> ready_tasks_;
  170. std::list<std::shared_ptr<RunGraphTask>> pending_tasks_;
  171. std::vector<std::shared_ptr<Task>> done_tasks_;
  172. std::shared_ptr<std::thread> worker_;
  173. };
  174. } // namespace session
  175. } // namespace mindspore
  176. #endif // MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H