/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_ #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_ #include #include #include #include #include #include #include #include "runtime/framework/actor/data_source_actor.h" #include "runtime/framework/actor/loop_count_actor.h" #include "runtime/framework/actor/kernel_actor.h" #include "runtime/hardware/device_context.h" #include "backend/session/kernel_graph.h" namespace mindspore { namespace runtime { using mindspore::device::DeviceContext; using mindspore::session::KernelWithIndex; using KernelMapActor = std::unordered_map; enum class GraphExecutionStrategy { kPipeline, // The actor running is triggered only by data. kStep // The actor running need be triggered by control in addition. }; // The actor set generated by graph transformer is the execution unit of actor runtime. // It includes data source actor, kernel actor, loop count actor. // The data source actor is used to obtain data and process them into device tensors, // and then send them to kernel actor. The kernel actor is used to receive the device tensors to luanch kernel. // Specifically notice the no input kernel actor, it means that this actor has no input device tensor, need be triggered // externally. The loop count actor is used to receive the control of tail kernel actor to represent the end of one step // and decide whether to loop execution by loop count. struct ActorSet { std::vector data_source_actors_; std::vector kernel_actors_; // No input kernel actors need be triggered specifically. std::vector no_input_kernel_actors_; LoopCountActorPtr loop_count_actor_{nullptr}; }; using ActorSetPtr = std::shared_ptr; class GraphScheduler { public: static GraphScheduler &GetInstance() { static GraphScheduler instance; return instance; } // 1. Thread pool creating. // 2. The memory manager creating and scheduling. void Initialize(); // Transform graph to actor DAG, contains build and link. ActorSet *Transform(const std::vector &graphs, const std::vector &device_contexts, const std::vector *input_tensors = nullptr, const std::vector *control_nodes = nullptr, GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline); // Schedule actors in the actor runtime. Single machine scheduling is supported currently, and distributed scheduling // will be supported in the future. void Schedule(const ActorSet *actor_set); // The prepare processing before run: // 1. Prepare the data of device tensor store(such as weights and value nodes of graph). // 2. Prepare the data of host tensor queue(such as non weighted parameters of graph). // 3. Prepare the output tensor of graph. // 4.Prepare the continuous memory for communication kernel. void PrepareRun(const KernelGraphPtr &graph, const std::vector *input_tensors, VectorRef *const &outputs); // The processing entry of actors running. bool Run(const ActorSet *actor_set, GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline); // Fetch the actor set by kernel graph. ActorSet *Fetch(const KernelGraphPtr &graph) const; private: GraphScheduler() = default; ~GraphScheduler() = default; DISABLE_COPY_AND_ASSIGN(GraphScheduler); // Transform the nodes of graph to actors. ActorSetPtr Build(const KernelGraphPtr &graph, const DeviceContext *device_context); // Link actors to DAG through the edge connection of graph and graph execution strategy. void Link(ActorSet *actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy); // The processing of actors build. std::vector BuildDataSourceActor(const KernelGraphPtr &graph, const DeviceContext *device_context); std::vector BuildKernelActor(const KernelGraphPtr &graph, const DeviceContext *device_context); std::vector BuildNoInputKernelActor(const KernelGraphPtr &graph); LoopCountActorPtr BuildLoopCountActor(const KernelGraphPtr &graph); // The processing of actors link. void LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor, KernelWithIndex from_kernel_with_output_idx, KernelWithIndex to_to_kernel_with_input_idx); void LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_actor, KernelActor *to_actor, KernelWithIndex from_kernel_with_output_idx, KernelWithIndex to_kernel_with_input_idx); void LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor, KernelWithIndex from_kernel_with_output_idx, KernelWithIndex to_kernel_with_input_idx); void LinkControlArrowForKernelActor(KernelActor *from_actor, LoopCountActor *to_actor, const KernelGraphPtr &graph, GraphExecutionStrategy strategy); void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const KernelGraphPtr &graph); void LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node, const KernelMapActor &kernel_actors_map); // Check whether the actor set is valid. bool CheckActorValid(const ActorSet *actor_set) const; // Persist device tensors of graph's some nodes(such as weights and value nodes). void PersistDeviceTensor(const KernelGraphPtr &graph); // Fetch the hsot tensor queue by kernel graph. HostTensorQueue *FetchHostQueue(const KernelGraphPtr &graph) const; // Display the actor information of corresponding kernel graph. void DumpActor(const KernelGraphPtr &graph) const; void DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const; void DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const; void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const; std::unordered_map graph_to_actors_; std::unordered_map graph_to_host_queue_; // The second element of pair represents the output index of kernel actor corresponding to the device tensor. std::unordered_map> device_address_to_actor_; // The id of memory manager actor. AID memory_manager_aid_; bool init_{false}; }; } // namespace runtime } // namespace mindspore #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_