You can not select more than 25 topics. Topics must start with a Chinese character, a letter or a number; can include dashes ('-'); and can be up to 35 characters long.

graph_scheduler.h 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
  17. #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
  18. #include <vector>
  19. #include <string>
  20. #include <memory>
  21. #include <utility>
  22. #include <unordered_map>
  23. #include <algorithm>
  24. #include <fstream>
  25. #include "runtime/framework/actor/data_source_actor.h"
  26. #include "runtime/framework/actor/loop_count_actor.h"
  27. #include "runtime/framework/actor/kernel_actor.h"
  28. #include "runtime/hardware/device_context.h"
  29. #include "backend/session/kernel_graph.h"
  30. namespace mindspore {
  31. namespace runtime {
  32. using mindspore::device::DeviceContext;
  33. using mindspore::session::KernelWithIndex;
  34. using KernelMapActor = std::unordered_map<std::string, KernelActorPtr>;
  35. enum class GraphExecutionStrategy {
  36. kPipeline, // The actor running is triggered only by data.
  37. kStep // The actor running need be triggered by control in addition.
  38. };
  39. // The actor set generated by graph transformer is the execution unit of actor runtime.
  40. // It includes data source actor, kernel actor, loop count actor.
  41. // The data source actor is used to obtain data and process them into device tensors,
  42. // and then send them to kernel actor. The kernel actor is used to receive the device tensors to luanch kernel.
  43. // Specifically notice the no input kernel actor, it means that this actor has no input device tensor, need be triggered
  44. // externally. The loop count actor is used to receive the control of tail kernel actor to represent the end of one step
  45. // and decide whether to loop execution by loop count.
  46. struct ActorSet {
  47. std::vector<DataSourceActorPtr> data_source_actors_;
  48. std::vector<KernelActorPtr> kernel_actors_;
  49. // No input kernel actors need be triggered specifically.
  50. std::vector<KernelActorPtr> no_input_kernel_actors_;
  51. LoopCountActorPtr loop_count_actor_{nullptr};
  52. };
  53. using ActorSetPtr = std::shared_ptr<ActorSet>;
  54. class GraphScheduler {
  55. public:
  56. static GraphScheduler &GetInstance() {
  57. static GraphScheduler instance;
  58. return instance;
  59. }
  60. // 1. Thread pool creating.
  61. // 2. The memory manager creating and scheduling.
  62. void Initialize();
  63. // Transform graph to actor DAG, contains build and link.
  64. ActorSet *Transform(const std::vector<KernelGraphPtr> &graphs, const std::vector<DeviceContext *> &device_contexts,
  65. const std::vector<TensorPtr> *input_tensors = nullptr,
  66. const std::vector<AnfNodePtr> *control_nodes = nullptr,
  67. GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);
  68. // Schedule actors in the actor runtime. Single machine scheduling is supported currently, and distributed scheduling
  69. // will be supported in the future.
  70. void Schedule(const ActorSet *actor_set);
  71. // The prepare processing before run:
  72. // 1. Prepare the data of device tensor store(such as weights and value nodes of graph).
  73. // 2. Prepare the data of host tensor queue(such as non weighted parameters of graph).
  74. // 3. Prepare the output tensor of graph.
  75. // 4.Prepare the continuous memory for communication kernel.
  76. void PrepareRun(const KernelGraphPtr &graph, const std::vector<TensorPtr> *input_tensors, VectorRef *const &outputs);
  77. // The processing entry of actors running.
  78. bool Run(const ActorSet *actor_set, GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);
  79. // Fetch the actor set by kernel graph.
  80. ActorSet *Fetch(const KernelGraphPtr &graph) const;
  81. private:
  82. GraphScheduler() = default;
  83. ~GraphScheduler() = default;
  84. DISABLE_COPY_AND_ASSIGN(GraphScheduler);
  85. // Transform the nodes of graph to actors.
  86. ActorSetPtr Build(const KernelGraphPtr &graph, const DeviceContext *device_context);
  87. // Link actors to DAG through the edge connection of graph and graph execution strategy.
  88. void Link(ActorSet *actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy);
  89. // The processing of actors build.
  90. std::vector<DataSourceActorPtr> BuildDataSourceActor(const KernelGraphPtr &graph,
  91. const DeviceContext *device_context);
  92. std::vector<KernelActorPtr> BuildKernelActor(const KernelGraphPtr &graph, const DeviceContext *device_context);
  93. std::vector<KernelActorPtr> BuildNoInputKernelActor(const KernelGraphPtr &graph);
  94. LoopCountActorPtr BuildLoopCountActor(const KernelGraphPtr &graph);
  95. // The processing of actors link.
  96. void LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor,
  97. KernelWithIndex from_kernel_with_output_idx,
  98. KernelWithIndex to_to_kernel_with_input_idx);
  99. void LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_actor, KernelActor *to_actor,
  100. KernelWithIndex from_kernel_with_output_idx,
  101. KernelWithIndex to_kernel_with_input_idx);
  102. void LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,
  103. KernelWithIndex from_kernel_with_output_idx,
  104. KernelWithIndex to_kernel_with_input_idx);
  105. void LinkControlArrowForKernelActor(KernelActor *from_actor, LoopCountActor *to_actor, const KernelGraphPtr &graph,
  106. GraphExecutionStrategy strategy);
  107. void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const KernelGraphPtr &graph);
  108. void LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node,
  109. const KernelMapActor &kernel_actors_map);
  110. // Check whether the actor set is valid.
  111. bool CheckActorValid(const ActorSet *actor_set) const;
  112. // Persist device tensors of graph's some nodes(such as weights and value nodes).
  113. void PersistDeviceTensor(const KernelGraphPtr &graph);
  114. // Fetch the hsot tensor queue by kernel graph.
  115. HostTensorQueue *FetchHostQueue(const KernelGraphPtr &graph) const;
  116. // Display the actor information of corresponding kernel graph.
  117. void DumpActor(const KernelGraphPtr &graph) const;
  118. void DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const;
  119. void DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const;
  120. void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const;
  121. std::unordered_map<KernelGraphPtr, ActorSetPtr> graph_to_actors_;
  122. std::unordered_map<KernelGraphPtr, HostTensorQueuePtr> graph_to_host_queue_;
  123. // The second element of pair represents the output index of kernel actor corresponding to the device tensor.
  124. std::unordered_map<DeviceTensorPtr, std::pair<KernelActorPtr, int>> device_address_to_actor_;
  125. // The id of memory manager actor.
  126. AID memory_manager_aid_;
  127. bool init_{false};
  128. };
  129. } // namespace runtime
  130. } // namespace mindspore
  131. #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_