| @@ -322,6 +322,18 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||||
| return graph_id; | return graph_id; | ||||
| } | } | ||||
| void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||||
| auto graph_order = GetGraphOrder(kernel_graph->graph_id()); | |||||
| for (auto graph_id : graph_order) { | |||||
| auto child_graph = GetGraph(graph_id); | |||||
| if (child_graph->summary_node_exist()) { | |||||
| kernel_graph->set_summary_node_exist(true); | |||||
| return; | |||||
| } | |||||
| } | |||||
| kernel_graph->set_summary_node_exist(false); | |||||
| } | |||||
| void AscendSession::BuildGraph(GraphId graph_id) { | void AscendSession::BuildGraph(GraphId graph_id) { | ||||
| MS_LOG(INFO) << "start"; | MS_LOG(INFO) << "start"; | ||||
| auto graph = GetGraph(graph_id); | auto graph = GetGraph(graph_id); | ||||
| @@ -337,6 +349,7 @@ void AscendSession::BuildGraph(GraphId graph_id) { | |||||
| InsertAllAssigns(); | InsertAllAssigns(); | ||||
| // insert switch and active to child graph | // insert switch and active to child graph | ||||
| MergeSwitchCompile(); | MergeSwitchCompile(); | ||||
| SetFinalGraphSummaryFlag(graph); | |||||
| // OptChildGraphs | // OptChildGraphs | ||||
| auto graph_order = GetGraphOrder(final_graph_id_); | auto graph_order = GetGraphOrder(final_graph_id_); | ||||
| auto &graph_type = GetGraphOrderType(final_graph_id_); | auto &graph_type = GetGraphOrderType(final_graph_id_); | ||||
| @@ -348,6 +361,7 @@ void AscendSession::BuildGraph(GraphId graph_id) { | |||||
| auto child_graph = GetGraph(graph_order[i]); | auto child_graph = GetGraph(graph_order[i]); | ||||
| CompileChildGraph(child_graph); | CompileChildGraph(child_graph); | ||||
| } | } | ||||
| GetSummaryNodes(graph.get()); | |||||
| // merge child graph | // merge child graph | ||||
| MergeGraphExecOrder(); | MergeGraphExecOrder(); | ||||
| } else { | } else { | ||||
| @@ -751,25 +765,26 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) { | |||||
| return final_graph_id_; | return final_graph_id_; | ||||
| } | } | ||||
| void AscendSession::GetSummaryNodes(const KernelGraph *graph, | |||||
| std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) { | |||||
| void AscendSession::GetSummaryNodes(KernelGraph *graph) { | |||||
| MS_LOG(DEBUG) << "Update summary Start"; | MS_LOG(DEBUG) << "Update summary Start"; | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_EXCEPTION_IF_NULL(summary); | |||||
| summary->clear(); | |||||
| // if final graph have no child graph | // if final graph have no child graph | ||||
| auto graph_order_iter = graph_execute_orders_.find(graph->graph_id()); | auto graph_order_iter = graph_execute_orders_.find(graph->graph_id()); | ||||
| if (graph_order_iter == graph_execute_orders_.end()) { | if (graph_order_iter == graph_execute_orders_.end()) { | ||||
| SessionBasic::GetSummaryNodes(graph, summary); | |||||
| SessionBasic::GetSummaryNodes(graph); | |||||
| return; | return; | ||||
| } | } | ||||
| // for every child graph, find summary nodes | // for every child graph, find summary nodes | ||||
| auto summary = graph->summary_nodes(); | |||||
| auto graph_order = GetGraphOrder(graph->graph_id()); | auto graph_order = GetGraphOrder(graph->graph_id()); | ||||
| for (size_t i = 0; i < graph_order.size(); i++) { | for (size_t i = 0; i < graph_order.size(); i++) { | ||||
| auto child_graph = GetGraph(graph_order[i]); | auto child_graph = GetGraph(graph_order[i]); | ||||
| SessionBasic::GetSummaryNodes(child_graph.get(), summary); | |||||
| SessionBasic::GetSummaryNodes(child_graph.get()); | |||||
| auto child_graph_summary = child_graph->summary_nodes(); | |||||
| summary.insert(child_graph_summary.begin(), child_graph_summary.end()); | |||||
| } | } | ||||
| MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size(); | |||||
| graph->set_summary_nodes(summary); | |||||
| MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); | |||||
| } | } | ||||
| AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) { | AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) { | ||||
| @@ -67,8 +67,7 @@ class AscendSession : public SessionBasic { | |||||
| void SetActive(GraphId, GraphId) override; | void SetActive(GraphId, GraphId) override; | ||||
| // compile child graph when session have multiple child graphs | // compile child graph when session have multiple child graphs | ||||
| void CompileChildGraph(const KernelGraphPtr &child_graph); | void CompileChildGraph(const KernelGraphPtr &child_graph); | ||||
| void GetSummaryNodes(const KernelGraph *graph, | |||||
| std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) override; | |||||
| void GetSummaryNodes(KernelGraph *graph) override; | |||||
| private: | private: | ||||
| void InitRuntimeResource(); | void InitRuntimeResource(); | ||||
| @@ -149,6 +148,7 @@ class AscendSession : public SessionBasic { | |||||
| AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output); | AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output); | ||||
| // sync initial tensors' data to device | // sync initial tensors' data to device | ||||
| void SyncInitialTenosrToDevice(); | void SyncInitialTenosrToDevice(); | ||||
| void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph); | |||||
| // member variables | // member variables | ||||
| // key is final_graph_id,value is child graph execute order of final graph | // key is final_graph_id,value is child graph execute order of final graph | ||||
| @@ -73,7 +73,8 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten | |||||
| kernel_graph->set_execution_order(execution_order); | kernel_graph->set_execution_order(execution_order); | ||||
| NamedSummaryOutputs summary_outputs; | NamedSummaryOutputs summary_outputs; | ||||
| if (enable_summary) { | if (enable_summary) { | ||||
| GetSummaryNodes(kernel_graph.get(), &summary_outputs); | |||||
| GetSummaryNodes(kernel_graph.get()); | |||||
| summary_outputs = kernel_graph->summary_nodes(); | |||||
| runtime_.IncreaseSummaryRefCount(summary_outputs); | runtime_.IncreaseSummaryRefCount(summary_outputs); | ||||
| } | } | ||||
| @@ -142,6 +142,8 @@ class KernelGraph : public FuncGraph { | |||||
| bool get_output_null() { return null_output_; } | bool get_output_null() { return null_output_; } | ||||
| void set_output_null(bool is_output_null) { null_output_ = is_output_null; } | void set_output_null(bool is_output_null) { null_output_ = is_output_null; } | ||||
| void PrintGraphExecuteOrder() const; | void PrintGraphExecuteOrder() const; | ||||
| std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() { return summary_nodes_; } | |||||
| void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; } | |||||
| private: | private: | ||||
| // remove value node from graph | // remove value node from graph | ||||
| @@ -175,6 +177,7 @@ class KernelGraph : public FuncGraph { | |||||
| // record map between ref final output anf with index and ref origin input with index | // record map between ref final output anf with index and ref origin input with index | ||||
| std::map<AnfWithOutIndex, AnfWithOutIndex> ref_out_in_map_; | std::map<AnfWithOutIndex, AnfWithOutIndex> ref_out_in_map_; | ||||
| std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_; | std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_; | ||||
| std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes_; | |||||
| // graph needn't execute | // graph needn't execute | ||||
| bool executable_; | bool executable_; | ||||
| // exist summary node in graph | // exist summary node in graph | ||||
| @@ -745,13 +745,13 @@ void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { | |||||
| (void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list)); | (void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list)); | ||||
| } | } | ||||
| void SessionBasic::GetSummaryNodes(const KernelGraph *graph, NamedSummaryOutputs *summary) { | |||||
| void SessionBasic::GetSummaryNodes(KernelGraph *graph) { | |||||
| MS_LOG(DEBUG) << "Update summary Start"; | MS_LOG(DEBUG) << "Update summary Start"; | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_EXCEPTION_IF_NULL(summary); | |||||
| if (!graph->summary_node_exist()) { | if (!graph->summary_node_exist()) { | ||||
| return; | return; | ||||
| } | } | ||||
| auto summary = graph->summary_nodes(); | |||||
| auto apply_list = TopoSort(graph->get_return()); | auto apply_list = TopoSort(graph->get_return()); | ||||
| for (auto &n : apply_list) { | for (auto &n : apply_list) { | ||||
| MS_EXCEPTION_IF_NULL(n); | MS_EXCEPTION_IF_NULL(n); | ||||
| @@ -764,14 +764,15 @@ void SessionBasic::GetSummaryNodes(const KernelGraph *graph, NamedSummaryOutputs | |||||
| } | } | ||||
| auto node = cnode->input(kSummaryGetItem); | auto node = cnode->input(kSummaryGetItem); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0); | |||||
| auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); | |||||
| if (!AnfAlgo::IsRealKernel(item_with_index.first)) { | if (!AnfAlgo::IsRealKernel(item_with_index.first)) { | ||||
| MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString(); | MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString(); | ||||
| } | } | ||||
| (*summary)[n->fullname_with_scope()] = item_with_index; | |||||
| summary[n->fullname_with_scope()] = item_with_index; | |||||
| } | } | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size(); | |||||
| graph->set_summary_nodes(summary); | |||||
| MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); | |||||
| } | } | ||||
| void SessionBasic::Summary(KernelGraph *graph) { | void SessionBasic::Summary(KernelGraph *graph) { | ||||
| @@ -779,8 +780,8 @@ void SessionBasic::Summary(KernelGraph *graph) { | |||||
| return; | return; | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| NamedSummaryOutputs summary_outputs; | |||||
| GetSummaryNodes(graph, &summary_outputs); | |||||
| GetSummaryNodes(graph); | |||||
| auto summary_outputs = graph->summary_nodes(); | |||||
| // do not exist summary node | // do not exist summary node | ||||
| if (summary_outputs.empty()) { | if (summary_outputs.empty()) { | ||||
| return; | return; | ||||
| @@ -93,8 +93,7 @@ class SessionBasic { | |||||
| virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } | virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } | ||||
| virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } | virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } | ||||
| virtual void SetActive(GraphId, GraphId) {} | virtual void SetActive(GraphId, GraphId) {} | ||||
| virtual void GetSummaryNodes(const KernelGraph *graph, | |||||
| std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary); | |||||
| virtual void GetSummaryNodes(KernelGraph *graph); | |||||
| protected: | protected: | ||||
| virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | ||||
| @@ -130,7 +129,7 @@ class SessionBasic { | |||||
| }; | }; | ||||
| using SessionPtr = std::shared_ptr<session::SessionBasic>; | using SessionPtr = std::shared_ptr<session::SessionBasic>; | ||||
| using NamedSummaryOutputs = std::unordered_map<std::string, std::pair<AnfNodePtr, int>>; | |||||
| using NamedSummaryOutputs = std::map<std::string, std::pair<AnfNodePtr, int>>; | |||||
| } // namespace session | } // namespace session | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H | #endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H | ||||