You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

graph_scheduler.h 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. /**
  2. * Copyright 2021-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
  17. #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
  18. #include <vector>
  19. #include <string>
  20. #include <memory>
  21. #include <utility>
  22. #include <map>
  23. #include <set>
  24. #include <algorithm>
  25. #include <fstream>
  26. #include "utils/hash_map.h"
  27. #include "utils/hash_set.h"
  28. #include "runtime/graph_scheduler/control_node_scheduler.h"
  29. #include "runtime/graph_scheduler/actor/actor_set.h"
  30. #include "runtime/graph_scheduler/graph_compiler.h"
  31. #include "runtime/graph_scheduler/actor/actor_dump.h"
  32. #include "thread/actor_threadpool.h"
  33. #ifdef ENABLE_RPC_ACTOR
  34. #include "runtime/graph_scheduler/rpc_node_scheduler.h"
  35. #endif
  36. #include "include/backend/visible.h"
  37. namespace mindspore {
  38. namespace runtime {
  39. using mindspore::device::DeviceContext;
  40. using mindspore::session::KernelGraph;
  41. using mindspore::session::KernelWithIndex;
  42. // The second element of pair represents the output node and output index of abstract actor corresponding to the graph
  43. // output node.
  44. using GraphOutputPair = std::pair<AbstractActor *, KernelWithIndex>;
  45. class BACKEND_EXPORT GraphScheduler {
  46. public:
  47. static GraphScheduler &GetInstance() noexcept;
  48. // 1. Thread pool creating.
  49. // 2. The global actors creating and scheduling.
  50. void Initialize();
  51. // Clear the members.
  52. void Clear();
  53. void Clear(const ActorInfo &actor_info, const std::vector<KernelGraphPtr> &graphs,
  54. const std::vector<AnfNodePtr> &root_graph_parameters,
  55. const ControlNodeParserPtr &parser = nullptr) noexcept;
  56. // The control flow actors will generate some data in the loop body execution, so need clear on the end of execution.
  57. void ClearActorData(const ActorSet *actor_set);
  58. // Transform graph to actor DAG, contains build and link.
  59. ActorSet *Transform(const GraphCompilerInfo &graph_compiler_info);
  60. // Schedule actors in the actor runtime. Single machine scheduling is supported currently, and distributed scheduling
  61. // will be supported in the future.
  62. void Schedule(const ActorSet *actor_set);
  63. // The processing entry of actors running. The fourth parameter is used only in the step execution strategy.
  64. void Run(ActorSet *constactor_set, const std::vector<DeviceContext *> &device_contexts,
  65. const std::vector<std::vector<TensorPtr>> &input_tensors,
  66. const std::vector<TensorPtr> &input_tensors_with_value_node = {},
  67. GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);
  68. // Fetch the actor set by actor info.
  69. ActorSet *Fetch(const ActorInfo &actor_info) const;
  70. private:
  71. GraphScheduler() = default;
  72. ~GraphScheduler() = default;
  73. DISABLE_COPY_AND_ASSIGN(GraphScheduler);
  74. // Set using the multi thread or single thread to execute the actor set by the execution time compared.
  75. void SetActorExecutionStrategy(ActorSet *const actor_set, GraphExecutionStrategy strategy,
  76. double execution_time) const;
  77. // The Global actors contain memory manager actor, recorder actor and debug actor.
  78. void BuildAndScheduleGlobalActor();
  79. // Transform the nodes of graph to actors.
  80. ActorSetPtr Build(const GraphCompilerInfo &graph_compiler_info);
  81. // Link actors to DAG through the edge connection of graph and graph execution strategy.
  82. void Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info);
  83. // Optimize the actor DAG. For example, erase invalid data arrow, etc.
  84. void Optimize(ActorSet *const actor_set);
  85. // The processing of actors build.
  86. std::vector<DataSourceActorPtr> BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
  87. const HostTensorQueuePtr &host_queue);
  88. std::vector<KernelActorPtr> BuildKernelActor(const GraphCompilerInfo &graph_compiler_info);
  89. std::vector<CustomActorPtr> BuildCustomActor(const GraphCompilerInfo &graph_compiler_info);
  90. std::vector<SuperKernelActorPtr> BuildSuperKernelActor(const GraphCompilerInfo &graph_compiler_info);
  91. LoopCountActorPtr BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info);
  92. OutputActorPtr BuildOutputActor(const GraphCompilerInfo &graph_compiler_info);
  93. DataPrepareActorPtr BuildDataPrepareActor(const GraphCompilerInfo &graph_compiler_info,
  94. const std::vector<DataSourceActorPtr> &data_source_actors,
  95. const HostTensorQueuePtr &host_queue);
  96. std::vector<AbstractActorPtr> BuildNoInputKernelActor(const ActorSet *actor_set, GraphExecutionStrategy strategy);
  97. // Generate rpc actor object inherited from kernel actor.
  98. KernelActorPtr GenerateRpcActor(const CNodePtr &kernel, const DeviceContext *device_context,
  99. GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
  100. const std::set<size_t> &modifiable_ref_output_indexes);
  101. // Cache the information of graph output node to actor between “build” and “link”, for linking between the tail of
  102. // previous graph and the head of next graph.
  103. void CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info);
  104. // The processing of actors linking.
  105. // 1. The processing of linking data arrows.
  106. void LinkDataArrowInSinkMode(const KernelGraphPtr &graph, const GraphCompilerInfo &graph_compiler_info,
  107. std::vector<AbstractActor *> *const auto_monad_actors);
  108. void LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph, const GraphCompilerInfo &graph_compiler_info,
  109. std::vector<AbstractActor *> *const auto_monad_actors,
  110. std::vector<CNodePtr> *const communication_nodes);
  111. // The gather of linking data arrows of kernel, it will call following functions by the different from actor type.
  112. void LinkDataArrow(AbstractActor *const to_actor, const GraphCompilerInfo &graph_compiler_info,
  113. const KernelGraphPtr &graph, const KernelWithIndex &from_kernel_with_output_idx,
  114. const KernelWithIndex &to_kernel_with_input_idx);
  115. void LinkDataArrowForBaseActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  116. const KernelWithIndex &from_kernel_with_output_idx,
  117. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  118. // Link data arrows for internal parameter, convert internal parameter to actor by internal parameter cache to link.
  119. void LinkDataArrowForInternalParameter(AbstractActor *const from_actor, AbstractActor *const to_actor,
  120. const KernelWithIndex &from_kernel_with_output_idx,
  121. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  122. void LinkDataArrowForDeviceTensorStore(AbstractActor *const from_actor, AbstractActor *const to_actor,
  123. const KernelWithIndex &from_kernel_with_output_idx,
  124. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  125. void LinkDataArrowForHostDSActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  126. const KernelWithIndex &from_kernel_with_output_idx,
  127. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  128. void LinkDataArrowForKernelActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  129. const KernelWithIndex &from_kernel_with_output_idx,
  130. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  131. // Link data arrows in the copy actor scene, insert the copy actor between from_actor and to_actor.
  132. void LinkDataArrowForCopyActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  133. const KernelWithIndex &from_kernel_with_output_idx,
  134. const KernelWithIndex &to_kernel_with_input_idx);
  135. // 2. The processing of linking control arrows.
  136. void LinkControlArrowByAutoMonad(AbstractActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph,
  137. const ControlNodeParserPtr &parser = nullptr);
  138. // The skipped node doesn't run, so need link the control arrow between the inputs and user of skipped node.
  139. void LinkControlArrowBySkippedNode(AbstractActor *to_actor, const AnfNodePtr &skipped_node);
  140. // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
  141. void LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph);
  142. // The gather of linking the global control arrows, it will call following functions:
  143. void LinkGlobalControlArrow(ActorSet *const actor_set, const GroupNameToCommuNodes &communication_node_groups,
  144. const std::vector<AbstractActor *> &auto_monad_actors,
  145. const GraphCompilerInfo &graph_compiler_info);
  146. void LinkControlArrowForCustomActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
  147. // Link the control arrows by the communication nodes in the kernel graph to ensure communication nodes running order.
  148. void LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
  149. const GraphCompilerInfo &graph_compiler_info);
  150. void LinkDeviceTensorStoreForAutoMonadActor(const std::vector<AbstractActor *> &auto_monad_actors);
  151. void LinkControlArrowForDataPrepareActor(DataPrepareActor *data_prepare_actor, const ActorSet *actor_set,
  152. const ControlNodeParserPtr &parser);
  153. void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
  154. const ControlNodeParserPtr &parser);
  155. void LinkControlArrowForOutputActor(OutputActor *output_actor, const ActorSet *actor_set);
  156. // 3. The processing of linking output result arrows.
  157. void LinkOutputResultArrowForOutputActor(OutputActor *to_actor, const GraphCompilerInfo &graph_compiler_info);
  158. void AddDeviceTensorStore(const AnfNode *anf_node, const DeviceTensorPtr &device_tensor);
  159. // Add the arrow between from actor and to actor.
  160. void AddDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor, const AnfNodePtr &from_kernel,
  161. size_t from_output_index, size_t to_input_index);
  162. void AddResultArrow(AbstractActor *const from_actor, OutputActor *const to_actor, const AnfNodePtr &from_kernel,
  163. size_t from_output_index, size_t output_position);
  164. void AddControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor);
  165. // Check whether the actor set is valid.
  166. void CheckActorValid(const ActorSet *actor_set) const;
  167. // Persist device tensors of graph's some nodes(such as weights and value nodes).
  168. void PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info);
  169. // When the parameters of root graph are not in backend kernel graphs, need persist device tensor by this function.
  170. void PersistDeviceTensorForRootGraphControlNode(const GraphCompilerInfo &graph_compiler_info);
  171. // Display the actor information of corresponding kernel graph.
  172. void DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const;
  173. void DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const;
  174. // The global maps, only be cleared in the deconstruction.
  175. mindspore::HashMap<ActorInfo, ActorSetPtr> actors_;
  176. // The local maps and vectors, will be cleared at the end of each graph transform:
  177. // 1.The second element of pair represents the output index of op actor corresponding to the graph output front node.
  178. std::map<KernelWithIndex, GraphOutputPair, session::KernelWithIndexCmp> graph_output_to_actor_;
  179. // 2.Beaceuse the copy actors are built in the link, so need record the all copy actors in the link process to push
  180. // into the actor set after link.
  181. std::vector<CopyActorPtr> copy_actors_;
  182. // In the control flow, used to build and link control actor.
  183. ControlNodeScheduler control_node_scheduler_;
  184. #ifdef ENABLE_RPC_ACTOR
  185. // Used to build and link for rpc actors.
  186. std::unique_ptr<RpcNodeScheduler> rpc_node_scheduler_{nullptr};
  187. #endif
  188. // The id of global actor.
  189. AID memory_manager_aid_;
  190. const AID *recorder_aid_{nullptr};
  191. const AID *debug_aid_{nullptr};
  192. bool init_{false};
  193. };
  194. } // namespace runtime
  195. } // namespace mindspore
  196. #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_