You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

graph_scheduler.h 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
  17. #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
  18. #include <vector>
  19. #include <string>
  20. #include <memory>
  21. #include <utility>
  22. #include <map>
  23. #include <set>
  24. #include <algorithm>
  25. #include <fstream>
  26. #include "utils/hash_map.h"
  27. #include "utils/hash_set.h"
  28. #include "runtime/framework/control_node_scheduler.h"
  29. #include "runtime/framework/actor/actor_set.h"
  30. #include "runtime/framework/graph_compiler.h"
  31. #include "runtime/framework/actor/actor_dump.h"
  32. #include "thread/actor_threadpool.h"
  33. namespace mindspore {
  34. namespace runtime {
  35. using mindspore::device::DeviceContext;
  36. using mindspore::session::KernelGraph;
  37. using mindspore::session::KernelWithIndex;
  38. // The second element of pair represents the output node and output index of abstract actor corresponding to the graph
  39. // output node.
  40. using GraphOutputPair = std::pair<AbstractActor *, KernelWithIndex>;
  41. class GraphScheduler {
  42. public:
  43. static GraphScheduler &GetInstance() noexcept {
  44. static GraphScheduler instance;
  45. return instance;
  46. }
  47. // 1. Thread pool creating.
  48. // 2. The global actors creating and scheduling.
  49. void Initialize();
  50. // Clear the members.
  51. void Clear();
  52. void Clear(const ActorInfo &actor_info, const std::vector<KernelGraphPtr> &graphs) noexcept;
  53. // The control flow actors will generate some data in the loop body execution, so need clear on the end of execution.
  54. void ClearActorData(const ActorSet *actor_set);
  55. // Transform graph to actor DAG, contains build and link.
  56. ActorSet *Transform(const GraphCompilerInfo &graph_compiler_info);
  57. // Schedule actors in the actor runtime. Single machine scheduling is supported currently, and distributed scheduling
  58. // will be supported in the future.
  59. void Schedule(const ActorSet *actor_set);
  60. // The processing entry of actors running. The fourth parameter is used only in the step execution strategy.
  61. void Run(ActorSet *constactor_set, const std::vector<DeviceContext *> &device_contexts,
  62. const std::vector<std::vector<TensorPtr>> &input_tensors,
  63. const std::vector<TensorPtr> &input_tensors_with_value_node = {},
  64. GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);
  65. // Fetch the actor set by actor info.
  66. ActorSet *Fetch(const ActorInfo &actor_info) const;
  67. private:
  68. GraphScheduler() = default;
  69. ~GraphScheduler() = default;
  70. DISABLE_COPY_AND_ASSIGN(GraphScheduler);
  71. // Set using the multi thread or single thread to execute the actor set by the execution time compared.
  72. void SetActorExecutionStrategy(ActorSet *const actor_set, GraphExecutionStrategy strategy, double execution_time);
  73. // The Global actors contain memory manager actor, recorder actor and debug actor.
  74. void BuildAndScheduleGlobalActor();
  75. // Transform the nodes of graph to actors.
  76. ActorSetPtr Build(const GraphCompilerInfo &graph_compiler_info);
  77. // Link actors to DAG through the edge connection of graph and graph execution strategy.
  78. void Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info);
  79. // The processing of actors build.
  80. std::vector<DataSourceActorPtr> BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
  81. const HostTensorQueuePtr &host_queue);
  82. std::vector<KernelActorPtr> BuildKernelActor(const GraphCompilerInfo &graph_compiler_info);
  83. std::vector<SuperKernelActorPtr> BuildSuperKernelActor(const GraphCompilerInfo &graph_compiler_info);
  84. LoopCountActorPtr BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info);
  85. OutputActorPtr BuildOutputActor(const GraphCompilerInfo &graph_compiler_info);
  86. DataPrepareActorPtr BuildDataPrepareActor(const GraphCompilerInfo &graph_compiler_info,
  87. const std::vector<DataSourceActorPtr> &data_source_actors,
  88. const HostTensorQueuePtr &host_queue);
  89. std::vector<AbstractActorPtr> BuildNoInputKernelActor(const ActorSet *actor_set, GraphExecutionStrategy strategy);
  90. // Cache the information of graph output node to actor between “build” and “link”, for linking between the tail of
  91. // previous graph and the head of next graph.
  92. void CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info);
  93. // The processing of actors linking.
  94. // 1. The processing of linking data arrows.
  95. void LinkDataArrowInSinkMode(const KernelGraphPtr &graph, const GraphCompilerInfo &graph_compiler_info,
  96. std::vector<AbstractActor *> *const auto_monad_actors);
  97. void LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph, const GraphCompilerInfo &graph_compiler_info,
  98. std::vector<AbstractActor *> *const auto_monad_actors,
  99. std::vector<CNodePtr> *const communication_nodes);
  100. // The gather of linking data arrows of kernel, it will call following functions by the different from actor type.
  101. void LinkDataArrow(AbstractActor *const to_actor, const GraphCompilerInfo &graph_compiler_info,
  102. const KernelGraphPtr &graph, const KernelWithIndex &from_kernel_with_output_idx,
  103. const KernelWithIndex &to_kernel_with_input_idx);
  104. void LinkDataArrowForBaseActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  105. const KernelWithIndex &from_kernel_with_output_idx,
  106. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  107. // Link data arrows for internal parameter, convert internal parameter to actor by internal parameter cache to link.
  108. void LinkDataArrowForInternalParameter(AbstractActor *const from_actor, AbstractActor *const to_actor,
  109. const KernelWithIndex &from_kernel_with_output_idx,
  110. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  111. void LinkDataArrowForDeviceTensorStore(AbstractActor *const from_actor, AbstractActor *const to_actor,
  112. const KernelWithIndex &from_kernel_with_output_idx,
  113. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  114. void LinkDataArrowForHostDSActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  115. const KernelWithIndex &from_kernel_with_output_idx,
  116. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  117. void LinkDataArrowForKernelActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  118. const KernelWithIndex &from_kernel_with_output_idx,
  119. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  120. // Link data arrows in the copy actor scene, insert the copy actor between from_actor and to_actor.
  121. void LinkDataArrowForCopyActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  122. const KernelWithIndex &from_kernel_with_output_idx,
  123. const KernelWithIndex &to_kernel_with_input_idx);
  124. // 2. The processing of linking control arrows.
  125. void LinkControlArrowByAutoMonad(AbstractActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph);
  126. // The skipped node doesn't run, so need link the control arrow between the inputs and user of skipped node.
  127. void LinkControlArrowBySkippedNode(AbstractActor *to_actor, const AnfNodePtr &skipped_node);
  128. // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
  129. void LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph);
  130. // The gather of linking the global control arrows, it will call following functions:
  131. void LinkGlobalControlArrow(ActorSet *const actor_set, const std::vector<CNodePtr> &communication_nodes,
  132. const std::vector<AbstractActor *> &auto_monad_actors,
  133. const GraphCompilerInfo &graph_compiler_info);
  134. // Link the control arrows by the communication nodes in the kernel graph to ensure communication nodes running order.
  135. void LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
  136. const GraphCompilerInfo &graph_compiler_info);
  137. void LinkDeviceTensorStoreForAutoMonadActor(const std::vector<AbstractActor *> &auto_monad_actors);
  138. void LinkControlArrowForDataPrepareActor(DataPrepareActor *data_prepare_actor, const ActorSet *actor_set,
  139. const ControlNodeParserPtr &parser);
  140. void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
  141. const ControlNodeParserPtr &parser);
  142. // 3. The processing of linking output result arrows.
  143. void LinkOutputResultArrowForOutputActor(OutputActor *to_actor, const GraphCompilerInfo &graph_compiler_info);
  144. void AddDeviceTensorStore(const AnfNode *anf_node, const DeviceTensorPtr &device_tensor);
  145. // Add the arrow between from actor and to actor.
  146. void AddDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor, const AnfNodePtr &from_kernel,
  147. size_t from_output_index, size_t to_input_index);
  148. void AddResultArrow(AbstractActor *const from_actor, OutputActor *const to_actor, const AnfNodePtr &from_kernel,
  149. size_t from_output_index, size_t output_position);
  150. void AddControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor);
  151. // Check whether the actor set is valid.
  152. void CheckActorValid(const ActorSet *actor_set) const;
  153. // Persist device tensors of graph's some nodes(such as weights and value nodes).
  154. void PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info);
  155. // Display the actor information of corresponding kernel graph.
  156. void DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const;
  157. void DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const;
  158. // The global maps, only be cleared in the deconstruction.
  159. mindspore::HashMap<ActorInfo, ActorSetPtr> actors_;
  160. // The local maps and vectors, will be cleared at the end of each graph transform:
  161. // 1.The second element of pair represents the output index of op actor corresponding to the graph output front node.
  162. std::map<KernelWithIndex, GraphOutputPair, session::KernelWithIndexCmp> graph_output_to_actor_;
  163. // 2.Beaceuse the copy actors are built in the link, so need record the all copy actors in the link process to push
  164. // into the actor set after link.
  165. std::vector<CopyActorPtr> copy_actors_;
  166. // In the control flow, used to build and link control actor.
  167. ControlNodeScheduler control_node_scheduler_;
  168. // The id of global actor.
  169. AID memory_manager_aid_;
  170. const AID *recorder_aid_{nullptr};
  171. const AID *debug_aid_{nullptr};
  172. bool init_{false};
  173. };
  174. } // namespace runtime
  175. } // namespace mindspore
  176. #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_