You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

graph_scheduler.h (20 kB) — last modified 4 years ago
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
  17. #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
  18. #include <vector>
  19. #include <string>
  20. #include <memory>
  21. #include <utility>
  22. #include <unordered_map>
  23. #include <unordered_set>
  24. #include <map>
  25. #include <set>
  26. #include <algorithm>
  27. #include <fstream>
  28. #include "runtime/framework/actor/data_source_actor.h"
  29. #include "runtime/framework/actor/loop_count_actor.h"
  30. #include "runtime/framework/actor/kernel_actor.h"
  31. #include "runtime/framework/actor/output_actor.h"
  32. #include "runtime/framework/actor/switch_actor.h"
  33. #include "runtime/framework/actor/gather_actor.h"
  34. #include "runtime/framework/actor/copy_actor.h"
  35. #include "runtime/hardware/device_context.h"
  36. #include "backend/session/kernel_graph.h"
  37. #include "thread/actor_threadpool.h"
  38. namespace mindspore {
  39. namespace runtime {
  40. using mindspore::device::DeviceContext;
  41. using mindspore::session::KernelGraph;
  42. using mindspore::session::KernelWithIndex;
  43. // Position of kernel with index, the value pair<branch_id, vector<pos>> means the branch id of the kernel and the pos
  44. // of the kernel. Generally, there is only one branch, and the branch id is 0 at this time. In control flow, there are
  45. // multiple branch scenarios, and pos represents the position of the kernel in the branch.
  46. using KernelMapPosition = std::map<KernelWithIndex, std::vector<size_t>, session::KernelWithIndexCmp>;
  47. using ActorInfo = std::string;
  48. // The second element of pair represents the output index of abstract actor corresponding to the graph output node.
  49. using GraphOutputPair = std::pair<AbstractActor *, size_t>;
  50. // The graph compiler info generated by graph compiler is the express of executable graph.
  51. // The device context is unified interface of interaction with device of corresponding graph.
  52. // The tensors mask is used to distinguish input tensor's type.
  53. // The input tensor is used to link graphs in the dynamic build scenario.
  54. // The control node is used to link graphs in the control flow scenario.
  55. // The control node parser is used to parse the edge info in control nodes.
  56. // The origin parameters order is used to correspond to the input args.
  57. // The origin outputs order is used to correspond to the output args.
  58. // The need_erase means need erase this GraphCompilerInfo object after run actor set.
  59. struct GraphCompilerInfo {
  60. GraphCompilerInfo(const std::vector<KernelGraphPtr> &graphs, const std::vector<DeviceContext *> &device_contexts,
  61. const std::vector<std::vector<int64_t> *> &tensors_mask,
  62. const std::vector<std::vector<TensorPtr> *> &input_tensors,
  63. const std::vector<AnfNodePtr> &control_nodes,
  64. const std::vector<AnfNodePtr> &origin_parameters_order, const ControlNodeParserPtr &parser,
  65. const KernelMapPosition &origin_outputs_order, const size_t outputs_num, const std::string &name,
  66. bool need_erase, GraphExecutionStrategy strategy)
  67. : graphs_(graphs),
  68. device_contexts_(device_contexts),
  69. tensors_mask_(tensors_mask),
  70. input_tensors_(input_tensors),
  71. control_nodes_(control_nodes),
  72. control_node_parser_(parser),
  73. origin_parameters_order_(origin_parameters_order),
  74. origin_outputs_order_(origin_outputs_order),
  75. outputs_num_(outputs_num),
  76. name_(name),
  77. need_erase_(need_erase),
  78. strategy_(strategy) {}
  79. ~GraphCompilerInfo();
  80. std::vector<KernelGraphPtr> graphs_;
  81. std::vector<DeviceContext *> device_contexts_;
  82. std::vector<std::vector<int64_t> *> tensors_mask_;
  83. std::vector<std::vector<TensorPtr> *> input_tensors_;
  84. std::vector<AnfNodePtr> control_nodes_;
  85. ControlNodeParserPtr control_node_parser_;
  86. std::vector<AnfNodePtr> origin_parameters_order_;
  87. KernelMapPosition origin_outputs_order_;
  88. size_t outputs_num_;
  89. std::string name_;
  90. bool need_erase_;
  91. GraphExecutionStrategy strategy_;
  92. };
  93. // The actor set generated by graph transformer is the execution unit of actor runtime.
  94. // It includes data source actor, kernel actor, switch actor, copy actor, loop count actor and output actor.
  95. // The data source actor is used to obtain data and process them into device tensors, and send them to kernel actor.
  96. // The kernel actor is used to receive the device tensors to luanch kernel. Specifically notice the no input
  97. // kernel actor, it means that this actor has no input device tensor, need be triggered externally.
  98. // The switch actor is used to run different branches in the control flow scenario.
  99. // The gather actor is used to collect the inputs of graph and send branch id to loop count actor in multi-branch
  100. // output scenario.
  101. // The copy actor is used to convert the device tensor between the different device kernel.
  102. // The loop count actor is used to receive the control of tail kernel actor to represent the end of one step
  103. // and decide whether to loop execution by loop count.
  104. // The output actor is used to receive the output result of actor which represents the graph output.
  105. struct ActorSet {
  106. explicit ActorSet(const ActorInfo &name) : name_(name) {}
  107. std::vector<DataSourceActorPtr> data_source_actors_;
  108. std::vector<KernelActorPtr> kernel_actors_;
  109. // No input kernel actors need be triggered specifically.
  110. std::vector<KernelActorPtr> no_input_kernel_actors_;
  111. std::vector<SwitchActorPtr> switch_actors_;
  112. std::vector<GatherActorPtr> gather_actors_;
  113. std::vector<CopyActorPtr> copy_actors_;
  114. LoopCountActorPtr loop_count_actor_{nullptr};
  115. OutputActorPtr output_actor_{nullptr};
  116. ActorInfo name_;
  117. };
  118. using ActorSetPtr = std::shared_ptr<ActorSet>;
  119. class GraphScheduler {
  120. public:
  121. static GraphScheduler &GetInstance() {
  122. static GraphScheduler instance;
  123. return instance;
  124. }
  125. // 1. Thread pool creating.
  126. // 2. The global actors creating and scheduling.
  127. void Initialize();
  128. // Clear the members.
  129. void Clear();
  130. void Clear(const ActorInfo &actor_info, const std::vector<KernelGraphPtr> &graphs);
  131. // Transform graph to actor DAG, contains build and link.
  132. ActorSet *Transform(const GraphCompilerInfo &graph_compiler_info);
  133. // Schedule actors in the actor runtime. Single machine scheduling is supported currently, and distributed scheduling
  134. // will be supported in the future.
  135. void Schedule(const ActorSet *actor_set);
  136. // The prepare processing before run. (used in pipeline mode):
  137. // 1. Prepare the data of device tensor store(such as weights and value nodes of graph).
  138. // 2. Prepare the data of host tensor queue(such as non weighted parameters of graph).
  139. // 3. Prepare the continuous memory for communication kernel.
  140. void PrepareRun(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info,
  141. const std::vector<std::vector<TensorPtr>> &input_tensors);
  142. // The prepare processing before run. (used in step mode):
  143. // 1. Prepare the data of device tensor store(such as weights and value nodes of graph).
  144. // 2. Prepare the data of host tensor queue(such as non weighted parameters of graph).
  145. void PrepareRunOp(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info,
  146. const std::vector<std::vector<TensorPtr>> &input_tensors);
  147. // The processing entry of actors running.
  148. bool Run(const ActorSet *actor_set, GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline,
  149. const std::vector<TensorPtr> *input_tensors = nullptr);
  150. // Fetch the actor set by actor info.
  151. ActorSet *Fetch(const ActorInfo &actor_info) const;
  152. private:
  153. GraphScheduler() = default;
  154. ~GraphScheduler() = default;
  155. DISABLE_COPY_AND_ASSIGN(GraphScheduler);
  156. // The Global actors contain memory manager actor, recorder actor and debug actor.
  157. void BuildAndScheduleGlobalActor();
  158. // Transform the nodes of graph to actors.
  159. ActorSetPtr Build(const GraphCompilerInfo &graph_compiler_info);
  160. // Link actors to DAG through the edge connection of graph and graph execution strategy.
  161. void Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info);
  162. // The processing of actors build.
  163. std::vector<DataSourceActorPtr> BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
  164. const HostTensorQueuePtr &host_queue);
  165. std::vector<KernelActorPtr> BuildKernelActor(const GraphCompilerInfo &graph_compiler_info);
  166. LoopCountActorPtr BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info);
  167. OutputActorPtr BuildOutputActor(const GraphCompilerInfo &graph_compiler_info);
  168. std::vector<KernelActorPtr> BuildNoInputKernelActor(const ActorSet *actor_set, GraphExecutionStrategy strategy);
  169. std::vector<SwitchActorPtr> BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info);
  170. std::vector<GatherActorPtr> BuildGatherActor(const GraphCompilerInfo &graph_compiler_info);
  171. // Cache the information of graph output node to actor between “build” and “link”, for linking between the tail of
  172. // previous graph and the head of next graph.
  173. void CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info);
  174. // The processing of actors link statically.
  175. // 1. The processing of linking data arrows.
  176. // The gather of linking data arrows of kernel, it will call following functions by the different from actor type.
  177. void LinkDataArrow(KernelActor *const to_actor, const GraphCompilerInfo &graph_compiler_info,
  178. const KernelGraphPtr &graph, const KernelWithIndex &from_kernel_with_output_idx,
  179. const KernelWithIndex &to_kernel_with_input_idx);
  180. void LinkDataArrowForBaseActor(AbstractActor *const from_actor, KernelActor *const to_actor,
  181. const KernelWithIndex &from_kernel_with_output_idx,
  182. const KernelWithIndex &to_kernel_with_input_idx);
  183. // Link data arrows for internal parameter, convert internal parameter to actor by internal parameter cache to link.
  184. void LinkDataArrowForInternalParameter(AbstractActor *const from_actor, KernelActor *const to_actor,
  185. const KernelWithIndex &from_kernel_with_output_idx,
  186. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  187. void LinkDataArrowForDeviceTensorStore(AbstractActor *const from_actor, KernelActor *const to_actor,
  188. const KernelWithIndex &from_kernel_with_output_idx,
  189. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  190. void LinkDataArrowForDeviceDSActor(AbstractActor *const from_actor, KernelActor *const to_actor,
  191. const KernelWithIndex &from_kernel_with_output_idx,
  192. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  193. void LinkDataArrowForHostDSActor(AbstractActor *const from_actor, KernelActor *const to_actor,
  194. const KernelWithIndex &from_kernel_with_output_idx,
  195. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  196. void LinkDataArrowForKernelActor(AbstractActor *const from_actor, KernelActor *const to_actor,
  197. const KernelWithIndex &from_kernel_with_output_idx,
  198. const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
  199. // Link data arrows in the copy actor scene, insert the copy actor between from_actor and to_actor.
  200. void LinkDataArrowForCopyActor(AbstractActor *const from_actor, KernelActor *const to_actor,
  201. const KernelWithIndex &from_kernel_with_output_idx,
  202. const KernelWithIndex &to_kernel_with_input_idx);
  203. // 2. The processing of linking control arrows.
  204. void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
  205. const ControlNodeParserPtr &parser);
  206. void LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph);
  207. // The skipped node doesn't run, so need link the control arrow between the inputs and user of skipped node.
  208. void LinkControlArrowBySkippedNode(KernelActor *to_actor, const AnfNodePtr &skipped_node);
  209. // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
  210. void LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph);
  211. // Link the control arrows by the communication nodes in the kernel graph to ensure communication nodes running order.
  212. void LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
  213. const GraphCompilerInfo &graph_compiler_info);
  214. void LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors);
  215. // 3. The processing of linking output result arrows.
  216. void LinkOutputResultArrowForOutputActor(OutputActor *to_actor, const GraphCompilerInfo &graph_compiler_info);
  217. // 4. The processing of control flow linking.
  218. void LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *const actor_set);
  219. void LinkDataArrowForGatherActor(GatherActor *const from_actor, KernelActor *const to_actor,
  220. const KernelWithIndex &front_node_with_index,
  221. const KernelWithIndex &to_node_with_index);
  222. void LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, SwitchActor *const actor);
  223. // Connect the input of the actor.
  224. void LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, const KernelWithIndex &input_node,
  225. const FuncGraphPtr &from_func_graph, OpActor<DeviceTensor> *const to_actor,
  226. const size_t to_index);
  227. // When the input of the actor is a call node, the output of the funcgraph called by the call node needs to be
  228. // connected.
  229. void LinkDataArrowByCallInput(const KernelWithIndex &call_node_with_index, const ControlNodeParserPtr &parser,
  230. const FuncGraphPtr &from_func_graph, OpActor<DeviceTensor> *const to_actor,
  231. const size_t to_index);
  232. void LinkDataArrowForSwitchActor(SwitchActor *const from_actor, const size_t from_index,
  233. OpActor<DeviceTensor> *const to_actor, const size_t to_index,
  234. const size_t branch_index = SIZE_MAX);
  235. void LinkControlArrowForGatherActor(std::vector<KernelActorPtr> *const kernel_actors,
  236. const std::vector<KernelGraphPtr> &graphs, const ControlNodeParserPtr &parser);
  237. void LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *const switch_actors, LoopCountActor *const to_actor,
  238. const KernelMapPosition &origin_outputs_order);
  239. // In control flow, there are scenarios where there are multi-branch outputs, and the gather actor needs to
  240. // send the branch id to the loop count actor.
  241. void LinkBranchArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info);
  242. void LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info);
  243. void LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
  244. void PrepareDataForControlNode(HostQueueDataSourceActor *const host_data_source_actor,
  245. const ControlNodeParserPtr &control_node_parser,
  246. const std::vector<AnfNodePtr> &origin_parameters,
  247. const std::vector<TensorPtr> &tensors, std::vector<TensorPtr> *const host_tensors);
  248. // Add input for switch actor. Since part of the input of funcgraph is on call node, these inputs need to be added
  249. // to switch actor.
  250. void PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr> &control_nodes);
  251. // Check whether the actor set is valid.
  252. bool CheckActorValid(const ActorSet *actor_set,
  253. GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline) const;
  254. // Persist device tensors of graph's some nodes(such as weights and value nodes).
  255. void PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info);
  256. // Fetch the hsot tensor queue by actor info.
  257. HostTensorQueue *FetchHostQueue(const ActorInfo &actor_info) const;
  258. // The fetch results are kernel_type and kernel_name.
  259. void FetchKernelTransformTypeAndName(const AnfNodePtr &node, const KernelGraphPtr &graph,
  260. const GraphCompilerInfo &graph_compiler_info,
  261. KernelTransformType *const kernel_type, std::string *const kernel_name);
  262. // The operation of the map of actor_name_to_actor_.
  263. void InsertActor(OpActor<DeviceTensor> *actor);
  264. OpActor<DeviceTensor> *FetchActor(const std::string &actor_name) const;
  265. // Display the actor information of corresponding kernel graph.
  266. void DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const;
  267. void DumpAbstractActor(const AbstractActor *actor, std::ofstream &ofs) const;
  268. void DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const;
  269. void DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const;
  270. void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const;
  271. void DumpOutputActor(const OutputActor *actor, std::ofstream &ofs) const;
  272. void DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) const;
  273. void DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) const;
  274. void DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) const;
  275. void DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const;
  276. // The global maps, only be cleared in the deconstruction.
  277. std::unordered_map<ActorInfo, ActorSetPtr> actors_;
  278. std::unordered_map<std::string, OpActor<DeviceTensor> *> actor_name_to_actor_;
  279. std::unordered_map<ActorInfo, HostTensorQueuePtr> actor_to_host_queue_;
  280. // The local maps and vectors, will be cleared at the end of each graph transform:
  281. // 1.The second element of pair represents the output index of op actor corresponding to the graph output front node.
  282. std::map<KernelWithIndex, GraphOutputPair, session::KernelWithIndexCmp> graph_output_to_actor_;
  283. // 2.Since the control node does not have a backend node, it can only be connected through the relationship between
  284. // the front node, so the mapping relationship between the front node and the actor needs to be recorded.
  285. std::unordered_map<AnfNodePtr, KernelActorPtr> front_node_to_actor_;
  286. // 3.Beaceuse the copy actors are built in the link, so need record the all copy actors in the link process to push
  287. // into the actor set after link.
  288. std::vector<CopyActorPtr> copy_actors_;
  289. // The id of global actor.
  290. AID memory_manager_aid_;
  291. const AID *recorder_aid_{nullptr};
  292. const AID *debug_aid_{nullptr};
  293. bool init_{false};
  294. };
  295. } // namespace runtime
  296. } // namespace mindspore
  297. #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_