Merge pull request !23906 from limingqi107/new_actor_runtime tags/v1.6.0
| @@ -377,6 +377,9 @@ class KernelGraph : public FuncGraph { | |||
| bool IsDatasetGraph() const; | |||
| bool is_sink() const { return is_sink_; } | |||
| void set_is_sink(bool is_sink) { is_sink_ = is_sink; } | |||
| private: | |||
| // Remove value node from graph. | |||
| bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); | |||
| @@ -487,6 +490,9 @@ class KernelGraph : public FuncGraph { | |||
| // Indicate whether the kernels in the graphs acquire Python GIL. | |||
| bool is_need_gil_{false}; | |||
| // Indicate whether the kernel graph is sunk to the device for execution. | |||
| bool is_sink_{false}; | |||
| }; | |||
| } // namespace session | |||
| using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; | |||
| @@ -313,6 +313,27 @@ GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePt | |||
| return CompileGraphImpl(graph, device_context); | |||
| } | |||
| GraphId GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph, const DeviceContext *device_context) { | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| // Generate kernel graph. | |||
| std::vector<KernelGraphPtr> all_graphs; | |||
| KernelGraphPtr root_graph = session_->ConstructKernelGraph(func_graph, &all_graphs); | |||
| MS_EXCEPTION_IF_NULL(root_graph); | |||
| for (const auto &graph : all_graphs) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| graph->set_root_graph_id(root_graph->graph_id()); | |||
| } | |||
| // Cache the backend graph output nodes to front nodes with output index. | |||
| auto output = func_graph->output(); | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| auto backend_node = root_graph->GetBackendAnfByFrontAnf(output); | |||
| MS_EXCEPTION_IF_NULL(backend_node); | |||
| root_graph->CacheGraphOutputToFrontNodeWithIndex(backend_node, output); | |||
| return CompileGraphImpl(root_graph, device_context); | |||
| } | |||
| GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| @@ -339,6 +360,10 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic | |||
| // Adjust kernel graph before run graph. | |||
| device_context->PreprocessBeforeRunGraph(graph); | |||
| // Set the graph sink flag. | |||
| auto is_sink = device_context->IsGraphSink(graph); | |||
| graph->set_is_sink(is_sink); | |||
| MS_LOG(INFO) << "Get graph outputs after optimizer, graph id: " << graph->graph_id(); | |||
| auto outputs_after_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output()); | |||
| // Update the output map of kernel graph by modified output nodes. | |||
| @@ -99,6 +99,10 @@ class GraphCompiler { | |||
| // the detailed implementation of compiling graph is in 'CompileGraphImpl'. | |||
| GraphId CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs, const DeviceContext *device_context); | |||
| // Construct kernel graph from function graph and compile kernel graph in Graph mode, | |||
| // the detailed implementation of compiling graph is in 'CompileGraphImpl'. | |||
| GraphId CompileGraph(const FuncGraphPtr &func_graph, const DeviceContext *device_context); | |||
| // Construct single op kernel graph and compile the kernel graph in PyNative mode. | |||
| GraphId CompileGraph(const session::OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<int64_t> *tensors_mask, std::vector<TensorPtr> *const input_tensors, | |||
| @@ -55,6 +55,16 @@ class DeviceContext { | |||
| // Destroy device context and release device resource. | |||
| virtual void Destroy() {} | |||
| // Partition the function graph according to the device capability and return the partition segments. | |||
| // The second parameter is the default partition segments, which are provided by the framework. | |||
| // A device can reprocess the default partition segments into new segments, or partition the function graph again. | |||
| // If the device can launch the whole graph and does not expect the function graph to be partitioned, it returns | |||
| // empty segments. The default behavior is to return the default partition segments. | |||
| virtual std::vector<GraphSegmentPtr> PartitionGraph(const FuncGraphPtr &func_graph, | |||
| const std::vector<GraphSegmentPtr> &default_partition_segments) { | |||
| return default_partition_segments; | |||
| } | |||
| // Relevant function to allocate and free device memory. | |||
| virtual bool AllocateMemory(DeviceAddress *const &address, size_t size) const = 0; | |||
| virtual void FreeMemory(DeviceAddress *const &address) const = 0; | |||
| @@ -97,10 +107,18 @@ class DeviceContext { | |||
| // Infer kernel shape and update abstract info for dynamic shape kernel. | |||
| virtual void UpdateDynamicShape(const CNodePtr &kernel) const { AnfAlgo::InferShape(kernel); } | |||
| // Whether the graph is sunk to the device for execution, decided by the device capability. The default is not to sink, so false is returned. | |||
| virtual bool IsGraphSink(const KernelGraphPtr &graph) const { return false; } | |||
| // Launch graph. Devices such as Ascend support sinking the whole graph to the device for execution. | |||
| virtual bool LaunchGraph(const KernelGraphPtr &graph) const { return true; } | |||
| // Launch a kernel via 'KernelMod' of the kernel. | |||
| virtual bool LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs, | |||
| const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs, | |||
| bool is_dynamic_shape = false) const = 0; | |||
| bool is_dynamic_shape = false) const { | |||
| return true; | |||
| } | |||
| // Synchronize stream. Devices such as GPU and Ascend need a stream to launch kernels asynchronously; | |||
| // 'SyncStream' is used to block the thread and wait for all tasks in the stream to complete. | |||
| @@ -379,14 +379,16 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) { | |||
| // Compile root graph. | |||
| graph_id_to_device_context_.clear(); | |||
| control_nodes_.clear(); | |||
| CompileGraph(root_graph); | |||
| auto subgraph_need_compile = CompileGraph(root_graph); | |||
| // Compile sub graphs. | |||
| MS_EXCEPTION_IF_NULL(root_graph->manager()); | |||
| FuncGraphSet sub_graphs = root_graph->manager()->func_graphs(); | |||
| for (auto sub_graph : sub_graphs) { | |||
| if (sub_graph != func_graph && sub_graph != nullptr) { | |||
| CompileGraph(sub_graph); | |||
| if (subgraph_need_compile) { | |||
| MS_EXCEPTION_IF_NULL(root_graph->manager()); | |||
| FuncGraphSet sub_graphs = root_graph->manager()->func_graphs(); | |||
| for (auto sub_graph : sub_graphs) { | |||
| if (sub_graph != func_graph && sub_graph != nullptr) { | |||
| (void)CompileGraph(sub_graph); | |||
| } | |||
| } | |||
| } | |||
| @@ -404,7 +406,7 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) { | |||
| return actor_info; | |||
| } | |||
| void MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) { | |||
| bool MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(graph_partition_); | |||
| MS_EXCEPTION_IF_NULL(graph_compiler_); | |||
| @@ -413,53 +415,68 @@ void MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) { | |||
| // Split graph to segments. | |||
| const auto &segments = graph_partition_->Partition(func_graph, &contain_multi_target); | |||
| MS_LOG(INFO) << "Compile graph: " << func_graph->ToString() << ", Split segments size:" << segments.size(); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| const auto &device_context = | |||
| device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_}); | |||
| const auto &new_segments = device_context->PartitionGraph(func_graph, segments); | |||
| // Compile the whole function graph if not split graph. | |||
| if (new_segments.size() == 0) { | |||
| auto graph_id = graph_compiler_->CompileGraph(func_graph, device_context); | |||
| graph_id_to_device_context_[graph_id] = device_context; | |||
| return false; | |||
| } | |||
| // Iterate over the segments to compile the graph. | |||
| for (const auto &segment : segments) { | |||
| MS_EXCEPTION_IF_NULL(segment); | |||
| // Compile the normal nodes, which don't contain the cut node. | |||
| if (!segment->is_cut_) { | |||
| if (segment->nodes_.size() == 0) { | |||
| MS_LOG(EXCEPTION) << "The segments size is 0."; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(segment->nodes_[0]); | |||
| MS_LOG(INFO) << "Compile normal segment, the first node: " << segment->nodes_[0]->fullname_with_scope(); | |||
| // Get the device context. | |||
| const auto &cur_device_name = GetCNodeTarget(segment->nodes_[0]); | |||
| const auto &device_context = | |||
| device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({cur_device_name, device_id_}); | |||
| device_context->Initialize(); | |||
| // Transform nodes to inputs and outputs. | |||
| FuncGraphPtr fg; | |||
| AnfNodePtrList inputs; | |||
| AnfNodePtrList outputs; | |||
| std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_); | |||
| // There will be more than one kernel graph in heterogeneous scenario in a ms function of PyNative Mode. | |||
| if (contain_multi_target && ms_execution_mode_ == kPynativeMode) { | |||
| real_execution_mode_ = kGraphMode; | |||
| context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||
| } | |||
| for (const auto &segment : new_segments) { | |||
| CompileGraph(segment, contain_multi_target); | |||
| } | |||
| return true; | |||
| } | |||
| // Compile graph. | |||
| auto graph_id = graph_compiler_->CompileGraph(segment->nodes_, outputs, device_context); | |||
| void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target) { | |||
| MS_EXCEPTION_IF_NULL(segment); | |||
| // Compile the normal nodes, which don't contain the cut node. | |||
| if (!segment->is_cut_) { | |||
| if (segment->nodes_.size() == 0) { | |||
| MS_LOG(EXCEPTION) << "The segments size is 0."; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(segment->nodes_[0]); | |||
| MS_LOG(INFO) << "Compile normal segment, the first node: " << segment->nodes_[0]->fullname_with_scope(); | |||
| // Get the device context. | |||
| const auto &cur_device_name = GetCNodeTarget(segment->nodes_[0]); | |||
| const auto &device_context = | |||
| device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({cur_device_name, device_id_}); | |||
| device_context->Initialize(); | |||
| // Transform nodes to inputs and outputs. | |||
| FuncGraphPtr fg; | |||
| AnfNodePtrList inputs; | |||
| AnfNodePtrList outputs; | |||
| std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| // There will be more than one kernel graph in heterogeneous scenario in a ms function of PyNative Mode. | |||
| if (contain_multi_target && ms_execution_mode_ == kPynativeMode) { | |||
| real_execution_mode_ = kGraphMode; | |||
| context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||
| } | |||
| if (ms_execution_mode_ != real_execution_mode_) { | |||
| context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, ms_execution_mode_); | |||
| } | |||
| // Compile graph. | |||
| auto graph_id = graph_compiler_->CompileGraph(segment->nodes_, outputs, device_context); | |||
| graph_id_to_device_context_[graph_id] = device_context; | |||
| } else { | |||
| // Compile the cut node. | |||
| auto cut_node = segment->nodes_[0]; | |||
| MS_EXCEPTION_IF_NULL(cut_node); | |||
| MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->fullname_with_scope(); | |||
| control_nodes_.push_back(cut_node); | |||
| if (ms_execution_mode_ != real_execution_mode_) { | |||
| context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, ms_execution_mode_); | |||
| } | |||
| graph_id_to_device_context_[graph_id] = device_context; | |||
| } else { | |||
| // Compile the cut node. | |||
| auto cut_node = segment->nodes_[0]; | |||
| MS_EXCEPTION_IF_NULL(cut_node); | |||
| MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->fullname_with_scope(); | |||
| control_nodes_.push_back(cut_node); | |||
| } | |||
| } | |||
| @@ -125,7 +125,11 @@ class MindRTBackend : public Backend { | |||
| private: | |||
| // The parameter func_graph is a graph; it can be either a root graph or a sub graph. | |||
| // The result of the graph compiler is stored in graph_id_to_device_context_ and control_nodes_. | |||
| void CompileGraph(const FuncGraphPtr &func_graph); | |||
| // The return value indicates whether the subgraph needs to be compiled recursively. | |||
| bool CompileGraph(const FuncGraphPtr &func_graph); | |||
| // Compile the kernel graph by the segment which is from the function graph partition. | |||
| void CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target); | |||
| // Restore the outputs tuple by the origin funcGraph output node and output tensors. | |||
| void ConstructOutputs(const AnfNodePtr &output_node, const std::vector<tensor::TensorPtr> &output_tensors, | |||