You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

graph_scheduler.h 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
  17. #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
  18. #include <vector>
  19. #include <string>
  20. #include <memory>
  21. #include <utility>
  22. #include <unordered_map>
  23. #include <algorithm>
  24. #include "runtime/framework/actor/data_source_actor.h"
  25. #include "runtime/framework/actor/loop_count_actor.h"
  26. #include "runtime/framework/actor/kernel_actor.h"
  27. #include "runtime/hardware/device_context.h"
  28. #include "backend/session/kernel_graph.h"
  29. namespace mindspore {
  30. namespace runtime {
  31. using mindspore::device::DeviceContext;
  32. using mindspore::session::KernelWithIndex;
// The strategy that decides how the transformed actor DAG is driven at runtime.
enum class GraphExecutionStrategy {
  kPipeline,  // The actor running is triggered only by data.
  kStep       // The actor running need be triggered by control in addition.
};
// The actor set generated by graph transformer is the execution unit of actor runtime.
// It includes data source actor, kernel actor, loop count actor.
// The data source actor is used to obtain data and process them into device tensors,
// and then send them to kernel actor. The kernel actor is used to receive the device tensors to launch kernel.
// Specifically notice the no input kernel actor, it means that this actor has no input device tensor, need be triggered
// externally. The loop count actor is used to receive the control of tail kernel actor to represent the end of one step
// and decide whether to loop execution by loop count.
struct ActorSet {
  std::vector<DataSourceActorPtr> data_source_actors_;
  std::vector<KernelActorPtr> kernel_actors_;
  // No input kernel actors need be triggered specifically.
  std::vector<KernelActorPtr> no_input_kernel_actors_;
  LoopCountActorPtr loop_count_actor_{nullptr};
};
using ActorSetPtr = std::shared_ptr<ActorSet>;
  52. class GraphScheduler {
  53. public:
  54. static GraphScheduler &GetInstance() {
  55. static GraphScheduler instance;
  56. return instance;
  57. }
  58. // The memory manager creating and scheduling.
  59. void Initialize();
  60. // Transform graph to actor DAG, contains build and link.
  61. ActorSet *Transform(const KernelGraphPtr &graph, const DeviceContext *device_context,
  62. const std::vector<TensorPtr> *input_tensors = nullptr,
  63. GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);
  64. // Schedule actors in the actor runtime. Single machine scheduling is supported currently, and distributed scheduling
  65. // will be supported in the future.
  66. void Schedule(const ActorSet *actor_set);
  67. // The prepare processing before run:
  68. // 1. Prepare the data of device tensor store(such as weights and value nodes of graph).
  69. // 2. Prepare the data of host tensor queue(such as non weighted parameters of graph).
  70. // 3. Prepare the output tensor of graph.
  71. void PrepareRun(const KernelGraphPtr &graph, const DeviceContext *device_context,
  72. const std::vector<TensorPtr> *input_tensors, VectorRef *const &outputs);
  73. // The processing entry of actors running.
  74. bool Run(const ActorSet *actor_set, GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);
  75. // Fetch the actor set by kernel graph.
  76. ActorSet *Fetch(const KernelGraphPtr &graph) const;
  77. private:
  78. GraphScheduler() = default;
  79. ~GraphScheduler() = default;
  80. DISABLE_COPY_AND_ASSIGN(GraphScheduler);
  81. // Transform the nodes of graph to actors.
  82. ActorSetPtr Build(const KernelGraphPtr &graph, const DeviceContext *device_context);
  83. // Link actors to DAG through the edge connection of graph and graph execution strategy.
  84. void Link(ActorSet *actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy);
  85. // The processing of actors build.
  86. std::vector<DataSourceActorPtr> BuildDataSourceActor(const KernelGraphPtr &graph,
  87. const DeviceContext *device_context);
  88. std::vector<KernelActorPtr> BuildKernelActor(const KernelGraphPtr &graph, const DeviceContext *device_context);
  89. std::vector<KernelActorPtr> BuildNoInputKernelActor(const KernelGraphPtr &graph);
  90. LoopCountActorPtr BuildLoopCountActor(const KernelGraphPtr &graph);
  91. // The processing of actors link.
  92. void LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor,
  93. KernelWithIndex from_kernel_with_output_idx,
  94. KernelWithIndex to_to_kernel_with_input_idx);
  95. void LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_actor, KernelActor *to_actor,
  96. KernelWithIndex from_kernel_with_output_idx,
  97. KernelWithIndex to_kernel_with_input_idx);
  98. void LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,
  99. KernelWithIndex from_kernel_with_output_idx,
  100. KernelWithIndex to_kernel_with_input_idx);
  101. void LinkControlArrowForKernelActor(KernelActor *from_actor, LoopCountActor *to_actor, const KernelGraphPtr &graph,
  102. GraphExecutionStrategy strategy);
  103. void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const KernelGraphPtr &graph);
  104. // Persist device tensors of graph's some nodes(such as weights and value nodes).
  105. void PersistDeviceTensor(const KernelGraphPtr &graph);
  106. // Fetch the hsot tensor queue by kernel graph.
  107. HostTensorQueue *FetchHostQueue(const KernelGraphPtr &graph) const;
  108. std::unordered_map<KernelGraphPtr, ActorSetPtr> graph_to_actors_;
  109. std::unordered_map<KernelGraphPtr, HostTensorQueuePtr> graph_to_host_queue_;
  110. // The second element of pair represents the output index of kernel actor corresponding to the device tensor.
  111. std::unordered_map<DeviceTensorPtr, std::pair<KernelActorPtr, int>> device_address_to_actor_;
  112. // The id of memory manager actor.
  113. AID memory_manager_aid_;
  114. bool init_{false};
  115. };
  116. } // namespace runtime
  117. } // namespace mindspore
  118. #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_