You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_scheduler.h 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. /**
  2. * Copyright 2021-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
  17. #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
  18. #include <vector>
  19. #include <string>
  20. #include <memory>
  21. #include <utility>
  22. #include <map>
  23. #include <set>
  24. #include <algorithm>
  25. #include <fstream>
  26. #include "utils/hash_map.h"
  27. #include "utils/hash_set.h"
  28. #include "runtime/graph_scheduler/control_node_scheduler.h"
  29. #include "runtime/graph_scheduler/actor/actor_set.h"
  30. #include "runtime/graph_scheduler/graph_compiler.h"
  31. #include "runtime/graph_scheduler/actor/actor_dump.h"
  32. #include "thread/actor_threadpool.h"
  33. #ifdef ENABLE_RPC_ACTOR
  34. #include "runtime/graph_scheduler/rpc_node_scheduler.h"
  35. #endif
namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelGraph;
using mindspore::session::KernelWithIndex;
// The second element of the pair represents the output node and output index of the abstract actor corresponding to
// the graph output node.
using GraphOutputPair = std::pair<AbstractActor *, KernelWithIndex>;
  44. class GraphScheduler {
  45. public:
  46. static GraphScheduler &GetInstance() noexcept {
  47. static GraphScheduler instance;
  48. return instance;
  49. }
  50. // 1. Thread pool creating.
  51. // 2. The global actors creating and scheduling.
  52. void Initialize();
  53. // Clear the members.
  54. void Clear();
  55. void Clear(const ActorInfo &actor_info, const std::vector<KernelGraphPtr> &graphs,
  56. const std::vector<AnfNodePtr> &root_graph_parameters,
  57. const ControlNodeParserPtr &parser = nullptr) noexcept;
  58. // The control flow actors will generate some data in the loop body execution, so need clear on the end of execution.
  59. void ClearActorData(const ActorSet *actor_set);
  60. // Transform graph to actor DAG, contains build and link.
  61. ActorSet *Transform(const GraphCompilerInfo &graph_compiler_info);
  62. // Schedule actors in the actor runtime. Single machine scheduling is supported currently, and distributed scheduling
  63. // will be supported in the future.
  64. void Schedule(const ActorSet *actor_set);
  65. // The processing entry of actors running. The fourth parameter is used only in the step execution strategy.
  66. void Run(ActorSet *constactor_set, const std::vector<DeviceContext *> &device_contexts,
  67. const std::vector<std::vector<TensorPtr>> &input_tensors,
  68. const std::vector<TensorPtr> &input_tensors_with_value_node = {},
  69. GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);
  70. // Fetch the actor set by actor info.
  71. ActorSet *Fetch(const ActorInfo &actor_info) const;
  72. private:
  73. GraphScheduler() = default;
  74. ~GraphScheduler() = default;
  75. DISABLE_COPY_AND_ASSIGN(GraphScheduler);
  76. // Set using the multi thread or single thread to execute the actor set by the execution time compared.
  77. void SetActorExecutionStrategy(ActorSet *const actor_set, GraphExecutionStrategy strategy,
  78. double execution_time) const;
  79. // The Global actors contain memory manager actor, recorder actor and debug actor.
  80. void BuildAndScheduleGlobalActor();
  81. // Transform the nodes of graph to actors.
  82. ActorSetPtr Build(const GraphCompilerInfo &graph_compiler_info);
  83. // Link actors to DAG through the edge connection of graph and graph execution strategy.
  84. void Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info);
  85. // Optimize the actor DAG. For example, erase invalid data arrow, etc.
  86. void Optimize(ActorSet *const actor_set);
  87. // The processing of actors build.
  88. std::vector<DataSourceActorPtr> BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
  89. const HostTensorQueuePtr &host_queue);
  90. std::vector<KernelActorPtr> BuildKernelActor(const GraphCompilerInfo &graph_compiler_info);
  91. std::vector<CustomActorPtr> BuildCustomActor(const GraphCompilerInfo &graph_compiler_info);
  92. std::vector<SuperKernelActorPtr> BuildSuperKernelActor(const GraphCompilerInfo &graph_compiler_info);
  93. LoopCountActorPtr BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info);
  94. OutputActorPtr BuildOutputActor(const GraphCompilerInfo &graph_compiler_info);
  95. DataPrepareActorPtr BuildDataPrepareActor(const GraphCompilerInfo &graph_compiler_info,
  96. const std::vector<DataSourceActorPtr> &data_source_actors,
  97. const HostTensorQueuePtr &host_queue);
  98. std::vector<AbstractActorPtr> BuildNoInputKernelActor(const ActorSet *actor_set, GraphExecutionStrategy strategy);
  99. // Generate rpc actor object inherited from kernel actor.
  100. KernelActorPtr GenerateRpcActor(const CNodePtr &kernel, const DeviceContext *device_context,
  101. GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
  102. const std::set<size_t> &modifiable_ref_output_indexes);
  103. // Cache the information of graph output node to actor between “build” and “link”, for linking between the tail of
  104. // previous graph and the head of next graph.
  105. void CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info);
  106. // The processing of actors linking.
  107. // 1. The processing of linking data arrows.
  108. void LinkDataArrowInSinkMode(const KernelGraphPtr &graph, const GraphCompilerInfo &graph_compiler_info,
  109. std::vector<AbstractActor *> *const auto_monad_actors);
  110. void LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph, const GraphCompilerInfo &graph_compiler_info,
  111. std::vector<AbstractActor *> *const auto_monad_actors,
  112. std::vector<CNodePtr> *const communication_nodes);
  113. // The gather of linking data arrows of kernel, it will call following functions by the different from actor type.
  114. void LinkDataArrow(AbstractActor *const to_actor, const GraphCompilerInfo &graph_compiler_info,
  115. const KernelGraphPtr &graph, const KernelWithIndex &from_kernel_with_output_idx,
  116. const KernelWithIndex &to_kernel_with_input_idx);
  117. void LinkDataArrowForBaseActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  118. const KernelWithIndex &from_kernel_with_output_idx,
  119. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  120. // Link data arrows for internal parameter, convert internal parameter to actor by internal parameter cache to link.
  121. void LinkDataArrowForInternalParameter(AbstractActor *const from_actor, AbstractActor *const to_actor,
  122. const KernelWithIndex &from_kernel_with_output_idx,
  123. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  124. void LinkDataArrowForDeviceTensorStore(AbstractActor *const from_actor, AbstractActor *const to_actor,
  125. const KernelWithIndex &from_kernel_with_output_idx,
  126. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  127. void LinkDataArrowForHostDSActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  128. const KernelWithIndex &from_kernel_with_output_idx,
  129. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  130. void LinkDataArrowForKernelActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  131. const KernelWithIndex &from_kernel_with_output_idx,
  132. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  133. // Link data arrows in the copy actor scene, insert the copy actor between from_actor and to_actor.
  134. void LinkDataArrowForCopyActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  135. const KernelWithIndex &from_kernel_with_output_idx,
  136. const KernelWithIndex &to_kernel_with_input_idx);
  137. // 2. The processing of linking control arrows.
  138. void LinkControlArrowByAutoMonad(AbstractActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph,
  139. const ControlNodeParserPtr &parser = nullptr);
  140. // The skipped node doesn't run, so need link the control arrow between the inputs and user of skipped node.
  141. void LinkControlArrowBySkippedNode(AbstractActor *to_actor, const AnfNodePtr &skipped_node);
  142. // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
  143. void LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph);
  144. // The gather of linking the global control arrows, it will call following functions:
  145. void LinkGlobalControlArrow(ActorSet *const actor_set, const GroupNameToCommuNodes &communication_node_groups,
  146. const std::vector<AbstractActor *> &auto_monad_actors,
  147. const GraphCompilerInfo &graph_compiler_info);
  148. void LinkControlArrowForCustomActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
  149. // Link the control arrows by the communication nodes in the kernel graph to ensure communication nodes running order.
  150. void LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
  151. const GraphCompilerInfo &graph_compiler_info);
  152. void LinkDeviceTensorStoreForAutoMonadActor(const std::vector<AbstractActor *> &auto_monad_actors);
  153. void LinkControlArrowForDataPrepareActor(DataPrepareActor *data_prepare_actor, const ActorSet *actor_set,
  154. const ControlNodeParserPtr &parser);
  155. void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
  156. const ControlNodeParserPtr &parser);
  157. void LinkControlArrowForOutputActor(OutputActor *output_actor, const ActorSet *actor_set);
  158. // 3. The processing of linking output result arrows.
  159. void LinkOutputResultArrowForOutputActor(OutputActor *to_actor, const GraphCompilerInfo &graph_compiler_info);
  160. void AddDeviceTensorStore(const AnfNode *anf_node, const DeviceTensorPtr &device_tensor);
  161. // Add the arrow between from actor and to actor.
  162. void AddDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor, const AnfNodePtr &from_kernel,
  163. size_t from_output_index, size_t to_input_index);
  164. void AddResultArrow(AbstractActor *const from_actor, OutputActor *const to_actor, const AnfNodePtr &from_kernel,
  165. size_t from_output_index, size_t output_position);
  166. void AddControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor);
  167. // Check whether the actor set is valid.
  168. void CheckActorValid(const ActorSet *actor_set) const;
  169. // Persist device tensors of graph's some nodes(such as weights and value nodes).
  170. void PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info);
  171. // When the parameters of root graph are not in backend kernel graphs, need persist device tensor by this function.
  172. void PersistDeviceTensorForRootGraphControlNode(const GraphCompilerInfo &graph_compiler_info);
  173. // Display the actor information of corresponding kernel graph.
  174. void DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const;
  175. void DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const;
  176. // The global maps, only be cleared in the deconstruction.
  177. mindspore::HashMap<ActorInfo, ActorSetPtr> actors_;
  178. // The local maps and vectors, will be cleared at the end of each graph transform:
  179. // 1.The second element of pair represents the output index of op actor corresponding to the graph output front node.
  180. std::map<KernelWithIndex, GraphOutputPair, session::KernelWithIndexCmp> graph_output_to_actor_;
  181. // 2.Beaceuse the copy actors are built in the link, so need record the all copy actors in the link process to push
  182. // into the actor set after link.
  183. std::vector<CopyActorPtr> copy_actors_;
  184. // In the control flow, used to build and link control actor.
  185. ControlNodeScheduler control_node_scheduler_;
  186. #ifdef ENABLE_RPC_ACTOR
  187. // Used to build and link for rpc actors.
  188. std::unique_ptr<RpcNodeScheduler> rpc_node_scheduler_{nullptr};
  189. #endif
  190. // The id of global actor.
  191. AID memory_manager_aid_;
  192. const AID *recorder_aid_{nullptr};
  193. const AID *debug_aid_{nullptr};
  194. bool init_{false};
  195. };
  196. } // namespace runtime
  197. } // namespace mindspore
  198. #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_