From 6e934530a36d48766bc84e0a003c59cf70a46e14 Mon Sep 17 00:00:00 2001 From: chenfei Date: Thu, 14 May 2020 14:51:00 +0800 Subject: [PATCH] new compile graph of new control sink --- mindspore/ccsrc/session/ascend_session.cc | 37 +++++++++++++++++++++++ mindspore/ccsrc/session/ascend_session.h | 9 ++++++ mindspore/ccsrc/session/kernel_graph.cc | 18 +++++++++++ mindspore/ccsrc/session/kernel_graph.h | 14 +++++++++ mindspore/ccsrc/session/session_basic.cc | 2 ++ mindspore/ccsrc/session/session_basic.h | 2 ++ 6 files changed, 82 insertions(+) diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index deec2c648a..4fb46b604d 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -138,6 +138,43 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL return graph_id; } +GraphId AscendSession::CompileGraph(const FuncGraphPtr &func_graph) { + MS_LOG(INFO) << "start"; + auto graph = ConstructKernelGraph(func_graph); + // split switch + SplitSwitch(graph.get()); + // insert goto labels and label_sets + LinkChildGraphs(graph.get()); + // resource initialize + InitRuntimeResource(); + // ir fusion + IRFusion(graph); + // kernel select + SelectKernelGraphKernel(*graph); + // convert model of predict module + ConvertPredictModel(graph); + // hardware optimize + HardwareOptimizeGraphs(graph); + // adjust kernel + AdjustKernel(graph); + // root graph valiate,include genearte execute order and so on + RootGraphExecutorValidate(graph.get()); + // assign stream + AssignStream(graph); + // build kernel if node is cnode + BuildKernel(graph); + // alloc mem + MemoryAlloc(graph.get()); + // task generate + GenerateTaskInfo(graph); + // load task into device + LoadTask(graph); + // return the graph id to backend + auto graph_id = graph->graph_id(); + MS_LOG(INFO) << "Compile graph " << graph_id << " success"; + return graph_id; +} + void AscendSession::BuildGraph(GraphId graph_id) { MS_LOG(INFO) << "start"; auto graph = GetGraph(graph_id); diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h index 4ab7797257..4823d292a4 100755 --- a/mindspore/ccsrc/session/ascend_session.h +++ b/mindspore/ccsrc/session/ascend_session.h @@ -42,6 +42,7 @@ class AscendSession : public SessionBasic { context_ = std::make_shared(kAscendDevice, device_id); } GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; + GraphId CompileGraph(const FuncGraphPtr &func_graph) override; void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; void BuildGraph(GraphId) override; void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, @@ -92,6 +93,14 @@ class AscendSession : public SessionBasic { void SetFinalGraphOutput(const ValuePtr &value); void SetFinalGraphOutput(const VectorRef &vec_output); + void SplitSwitch(KernelGraph *graph) {} + void LinkChildGraphs(KernelGraph *graph) {} + void IRFusion(const KernelGraphPtr &graph) {} + void SelectKernelGraphKernel(const KernelGraph &graph) {} + void ConvertPredictModel(const KernelGraphPtr graph) {} + void HardwareOptimizeGraphs(const KernelGraphPtr graph) {} + void RootGraphExecutorValidate(KernelGraph *graph) {} + // merge execution order list of child graphs void MergeGraphExecOrder(); // insert assion op to sync data bettween different graphs diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index 24b30b233b..8bdb955f79 100755 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -580,5 +580,23 @@ void KernelGraph::UpdateExecuteKernelStreamLabel() { AnfAlgo::SetStreamDistinctionLabel(stream_distinction_label_, kernel.get()); } } + +void KernelGraph::UpdateChildGraphOrder() {} + +std::vector> KernelGraph::GetLeafGraphOrder() { + std::vector> leaf_graph_order; + if (IsLeafGraph()) { + leaf_graph_order.push_back(shared_from_this()->cast()); + } else { + for (const auto &child_graph : child_graph_order_) { + MS_EXCEPTION_IF_NULL(child_graph); + auto child_leaf_graph_order = child_graph->GetLeafGraphOrder(); + std::copy(child_leaf_graph_order.begin(), child_leaf_graph_order.end(), std::back_inserter(leaf_graph_order)); + } + } + return leaf_graph_order; +} + +bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); } } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h index 3425bde9c2..b0f27635d0 100755 --- a/mindspore/ccsrc/session/kernel_graph.h +++ b/mindspore/ccsrc/session/kernel_graph.h @@ -99,6 +99,14 @@ class KernelGraph : public FuncGraph { uint32_t stream_distinction_label() { return stream_distinction_label_; } // refresh execute kernel stream label void UpdateExecuteKernelStreamLabel(); + // calculate the leaf graph order of root graph + std::vector> GetLeafGraphOrder(); + // update the child graph order of graph + void UpdateChildGraphOrder(); + // get the child graph of current graph + std::vector> child_graph_order() const { return child_graph_order_; } + // checkout whether current graph is leaf graph + bool IsLeafGraph() const; private: // remove value node form graph @@ -136,6 +144,12 @@ class KernelGraph : public FuncGraph { bool executable_; // valid inputs std::vector valid_inputs_; + + // new members for control sink process + // all child grahs refers to partial node + std::map> node_to_child_graphs_; + // child graph execute order in root graph + std::vector> child_graph_order_; }; } // namespace session using KernelGraphPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index 8d5ecee79c..b3267db6c7 100755 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -494,6 +494,8 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con return graph; } +std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphPtr &) { return nullptr; } + // run graph steps void SessionBasic::LoadInputData(const std::shared_ptr &kernel_graph, const std::vector &inputs_const) const { diff --git a/mindspore/ccsrc/session/session_basic.h b/mindspore/ccsrc/session/session_basic.h index 072226df8f..0d55b185f4 100755 --- a/mindspore/ccsrc/session/session_basic.h +++ b/mindspore/ccsrc/session/session_basic.h @@ -57,6 +57,7 @@ class SessionBasic { virtual ~SessionBasic() { summary_callback_ = nullptr; } virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0; + virtual GraphId CompileGraph(const FuncGraphPtr &) { return kInvalidGraphId; } // build graph, used to handle multiple child graphs virtual void BuildGraph(GraphId) {} @@ -72,6 +73,7 @@ class SessionBasic { virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback); std::shared_ptr ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs); + std::shared_ptr ConstructKernelGraph(const FuncGraphPtr &func_graph); CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph, std::unordered_map *other_graph_cnode);