You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

graph_scheduler.h 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. /**
  2. * Copyright 2021-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
  17. #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
  18. #include <vector>
  19. #include <string>
  20. #include <memory>
  21. #include <utility>
  22. #include <map>
  23. #include <set>
  24. #include <algorithm>
  25. #include <fstream>
  26. #include "utils/hash_map.h"
  27. #include "utils/hash_set.h"
  28. #include "runtime/framework/control_node_scheduler.h"
  29. #include "runtime/framework/actor/actor_set.h"
  30. #include "runtime/framework/graph_compiler.h"
  31. #include "runtime/framework/actor/actor_dump.h"
  32. #include "thread/actor_threadpool.h"
  33. namespace mindspore {
  34. namespace runtime {
  35. using mindspore::device::DeviceContext;
  36. using mindspore::session::KernelGraph;
  37. using mindspore::session::KernelWithIndex;
  38. // The second element of pair represents the output node and output index of abstract actor corresponding to the graph
  39. // output node.
  40. using GraphOutputPair = std::pair<AbstractActor *, KernelWithIndex>;
  41. class GraphScheduler {
  42. public:
  43. static GraphScheduler &GetInstance() noexcept {
  44. static GraphScheduler instance;
  45. return instance;
  46. }
  47. // 1. Thread pool creating.
  48. // 2. The global actors creating and scheduling.
  49. void Initialize();
  50. // Clear the members.
  51. void Clear();
  52. void Clear(const ActorInfo &actor_info, const std::vector<KernelGraphPtr> &graphs) noexcept;
  53. // The control flow actors will generate some data in the loop body execution, so need clear on the end of execution.
  54. void ClearActorData(const ActorSet *actor_set);
  55. // Transform graph to actor DAG, contains build and link.
  56. ActorSet *Transform(const GraphCompilerInfo &graph_compiler_info);
  57. // Schedule actors in the actor runtime. Single machine scheduling is supported currently, and distributed scheduling
  58. // will be supported in the future.
  59. void Schedule(const ActorSet *actor_set);
  60. // The processing entry of actors running. The fourth parameter is used only in the step execution strategy.
  61. void Run(ActorSet *constactor_set, const std::vector<DeviceContext *> &device_contexts,
  62. const std::vector<std::vector<TensorPtr>> &input_tensors,
  63. const std::vector<TensorPtr> &input_tensors_with_value_node = {},
  64. GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);
  65. // Fetch the actor set by actor info.
  66. ActorSet *Fetch(const ActorInfo &actor_info) const;
  67. private:
  68. GraphScheduler() = default;
  69. ~GraphScheduler() = default;
  70. DISABLE_COPY_AND_ASSIGN(GraphScheduler);
  71. // Set using the multi thread or single thread to execute the actor set by the execution time compared.
  72. void SetActorExecutionStrategy(ActorSet *const actor_set, GraphExecutionStrategy strategy, double execution_time);
  73. // The Global actors contain memory manager actor, recorder actor and debug actor.
  74. void BuildAndScheduleGlobalActor();
  75. // Transform the nodes of graph to actors.
  76. ActorSetPtr Build(const GraphCompilerInfo &graph_compiler_info);
  77. // Link actors to DAG through the edge connection of graph and graph execution strategy.
  78. void Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info);
  79. // The processing of actors build.
  80. std::vector<DataSourceActorPtr> BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
  81. const HostTensorQueuePtr &host_queue);
  82. std::vector<KernelActorPtr> BuildKernelActor(const GraphCompilerInfo &graph_compiler_info);
  83. std::vector<CustomActorPtr> BuildCustomActor(const GraphCompilerInfo &graph_compiler_info);
  84. std::vector<SuperKernelActorPtr> BuildSuperKernelActor(const GraphCompilerInfo &graph_compiler_info);
  85. LoopCountActorPtr BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info);
  86. OutputActorPtr BuildOutputActor(const GraphCompilerInfo &graph_compiler_info);
  87. DataPrepareActorPtr BuildDataPrepareActor(const GraphCompilerInfo &graph_compiler_info,
  88. const std::vector<DataSourceActorPtr> &data_source_actors,
  89. const HostTensorQueuePtr &host_queue);
  90. std::vector<AbstractActorPtr> BuildNoInputKernelActor(const ActorSet *actor_set, GraphExecutionStrategy strategy);
  91. // Cache the information of graph output node to actor between “build” and “link”, for linking between the tail of
  92. // previous graph and the head of next graph.
  93. void CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info);
  94. // The processing of actors linking.
  95. // 1. The processing of linking data arrows.
  96. void LinkDataArrowInSinkMode(const KernelGraphPtr &graph, const GraphCompilerInfo &graph_compiler_info,
  97. std::vector<AbstractActor *> *const auto_monad_actors);
  98. void LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph, const GraphCompilerInfo &graph_compiler_info,
  99. std::vector<AbstractActor *> *const auto_monad_actors,
  100. std::vector<CNodePtr> *const communication_nodes);
  101. // The gather of linking data arrows of kernel, it will call following functions by the different from actor type.
  102. void LinkDataArrow(AbstractActor *const to_actor, const GraphCompilerInfo &graph_compiler_info,
  103. const KernelGraphPtr &graph, const KernelWithIndex &from_kernel_with_output_idx,
  104. const KernelWithIndex &to_kernel_with_input_idx);
  105. void LinkDataArrowForBaseActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  106. const KernelWithIndex &from_kernel_with_output_idx,
  107. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  108. // Link data arrows for internal parameter, convert internal parameter to actor by internal parameter cache to link.
  109. void LinkDataArrowForInternalParameter(AbstractActor *const from_actor, AbstractActor *const to_actor,
  110. const KernelWithIndex &from_kernel_with_output_idx,
  111. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  112. void LinkDataArrowForDeviceTensorStore(AbstractActor *const from_actor, AbstractActor *const to_actor,
  113. const KernelWithIndex &from_kernel_with_output_idx,
  114. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  115. void LinkDataArrowForHostDSActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  116. const KernelWithIndex &from_kernel_with_output_idx,
  117. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  118. void LinkDataArrowForKernelActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  119. const KernelWithIndex &from_kernel_with_output_idx,
  120. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  121. // Link data arrows in the copy actor scene, insert the copy actor between from_actor and to_actor.
  122. void LinkDataArrowForCopyActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  123. const KernelWithIndex &from_kernel_with_output_idx,
  124. const KernelWithIndex &to_kernel_with_input_idx);
  125. // 2. The processing of linking control arrows.
  126. void LinkControlArrowByAutoMonad(AbstractActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph,
  127. const ControlNodeParserPtr &parser = nullptr);
  128. // The skipped node doesn't run, so need link the control arrow between the inputs and user of skipped node.
  129. void LinkControlArrowBySkippedNode(AbstractActor *to_actor, const AnfNodePtr &skipped_node);
  130. // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
  131. void LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph);
  132. // The gather of linking the global control arrows, it will call following functions:
  133. void LinkGlobalControlArrow(ActorSet *const actor_set, const std::vector<CNodePtr> &communication_nodes,
  134. const std::vector<AbstractActor *> &auto_monad_actors,
  135. const GraphCompilerInfo &graph_compiler_info);
  136. void LinkControlArrowForCustomActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
  137. // Link the control arrows by the communication nodes in the kernel graph to ensure communication nodes running order.
  138. void LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
  139. const GraphCompilerInfo &graph_compiler_info);
  140. void LinkDeviceTensorStoreForAutoMonadActor(const std::vector<AbstractActor *> &auto_monad_actors);
  141. void LinkControlArrowForDataPrepareActor(DataPrepareActor *data_prepare_actor, const ActorSet *actor_set,
  142. const ControlNodeParserPtr &parser);
  143. void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
  144. const ControlNodeParserPtr &parser);
  145. // 3. The processing of linking output result arrows.
  146. void LinkOutputResultArrowForOutputActor(OutputActor *to_actor, const GraphCompilerInfo &graph_compiler_info);
  147. void AddDeviceTensorStore(const AnfNode *anf_node, const DeviceTensorPtr &device_tensor);
  148. // Add the arrow between from actor and to actor.
  149. void AddDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor, const AnfNodePtr &from_kernel,
  150. size_t from_output_index, size_t to_input_index);
  151. void AddResultArrow(AbstractActor *const from_actor, OutputActor *const to_actor, const AnfNodePtr &from_kernel,
  152. size_t from_output_index, size_t output_position);
  153. void AddControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor);
  154. // Check whether the actor set is valid.
  155. void CheckActorValid(const ActorSet *actor_set) const;
  156. // Persist device tensors of graph's some nodes(such as weights and value nodes).
  157. void PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info);
  158. // Display the actor information of corresponding kernel graph.
  159. void DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const;
  160. void DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const;
  161. // The global maps, only be cleared in the deconstruction.
  162. mindspore::HashMap<ActorInfo, ActorSetPtr> actors_;
  163. // The local maps and vectors, will be cleared at the end of each graph transform:
  164. // 1.The second element of pair represents the output index of op actor corresponding to the graph output front node.
  165. std::map<KernelWithIndex, GraphOutputPair, session::KernelWithIndexCmp> graph_output_to_actor_;
  166. // 2.Beaceuse the copy actors are built in the link, so need record the all copy actors in the link process to push
  167. // into the actor set after link.
  168. std::vector<CopyActorPtr> copy_actors_;
  169. // In the control flow, used to build and link control actor.
  170. ControlNodeScheduler control_node_scheduler_;
  171. // The id of global actor.
  172. AID memory_manager_aid_;
  173. const AID *recorder_aid_{nullptr};
  174. const AID *debug_aid_{nullptr};
  175. bool init_{false};
  176. };
  177. } // namespace runtime
  178. } // namespace mindspore
  179. #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_