| @@ -322,6 +322,18 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||
| return graph_id; | |||
| } | |||
| void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||
| auto graph_order = GetGraphOrder(kernel_graph->graph_id()); | |||
| for (auto graph_id : graph_order) { | |||
| auto child_graph = GetGraph(graph_id); | |||
| if (child_graph->summary_node_exist()) { | |||
| kernel_graph->set_summary_node_exist(true); | |||
| return; | |||
| } | |||
| } | |||
| kernel_graph->set_summary_node_exist(false); | |||
| } | |||
| void AscendSession::BuildGraph(GraphId graph_id) { | |||
| MS_LOG(INFO) << "start"; | |||
| auto graph = GetGraph(graph_id); | |||
| @@ -337,6 +349,7 @@ void AscendSession::BuildGraph(GraphId graph_id) { | |||
| InsertAllAssigns(); | |||
| // insert switch and active to child graph | |||
| MergeSwitchCompile(); | |||
| SetFinalGraphSummaryFlag(graph); | |||
| // OptChildGraphs | |||
| auto graph_order = GetGraphOrder(final_graph_id_); | |||
| auto &graph_type = GetGraphOrderType(final_graph_id_); | |||
| @@ -348,6 +361,7 @@ void AscendSession::BuildGraph(GraphId graph_id) { | |||
| auto child_graph = GetGraph(graph_order[i]); | |||
| CompileChildGraph(child_graph); | |||
| } | |||
| GetSummaryNodes(graph.get()); | |||
| // merge child graph | |||
| MergeGraphExecOrder(); | |||
| } else { | |||
| @@ -751,25 +765,26 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) { | |||
| return final_graph_id_; | |||
| } | |||
| void AscendSession::GetSummaryNodes(const KernelGraph *graph, | |||
| std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) { | |||
| void AscendSession::GetSummaryNodes(KernelGraph *graph) { | |||
| MS_LOG(DEBUG) << "Update summary Start"; | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(summary); | |||
| summary->clear(); | |||
| // if final graph have no child graph | |||
| auto graph_order_iter = graph_execute_orders_.find(graph->graph_id()); | |||
| if (graph_order_iter == graph_execute_orders_.end()) { | |||
| SessionBasic::GetSummaryNodes(graph, summary); | |||
| SessionBasic::GetSummaryNodes(graph); | |||
| return; | |||
| } | |||
| // for every child graph, find summary nodes | |||
| auto summary = graph->summary_nodes(); | |||
| auto graph_order = GetGraphOrder(graph->graph_id()); | |||
| for (size_t i = 0; i < graph_order.size(); i++) { | |||
| auto child_graph = GetGraph(graph_order[i]); | |||
| SessionBasic::GetSummaryNodes(child_graph.get(), summary); | |||
| SessionBasic::GetSummaryNodes(child_graph.get()); | |||
| auto child_graph_summary = child_graph->summary_nodes(); | |||
| summary.insert(child_graph_summary.begin(), child_graph_summary.end()); | |||
| } | |||
| MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size(); | |||
| graph->set_summary_nodes(summary); | |||
| MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); | |||
| } | |||
| AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) { | |||
| @@ -67,8 +67,7 @@ class AscendSession : public SessionBasic { | |||
| void SetActive(GraphId, GraphId) override; | |||
| // compile child graph when session have multiple child graphs | |||
| void CompileChildGraph(const KernelGraphPtr &child_graph); | |||
| void GetSummaryNodes(const KernelGraph *graph, | |||
| std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) override; | |||
| void GetSummaryNodes(KernelGraph *graph) override; | |||
| private: | |||
| void InitRuntimeResource(); | |||
| @@ -149,6 +148,7 @@ class AscendSession : public SessionBasic { | |||
| AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output); | |||
| // sync intial tensors' data to device | |||
| void SyncInitialTenosrToDevice(); | |||
| void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph); | |||
| // member variables | |||
| // key is final_graph_id,value is child graph execute order of final graph | |||
| @@ -73,7 +73,8 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten | |||
| kernel_graph->set_execution_order(execution_order); | |||
| NamedSummaryOutputs summary_outputs; | |||
| if (enable_summary) { | |||
| GetSummaryNodes(kernel_graph.get(), &summary_outputs); | |||
| GetSummaryNodes(kernel_graph.get()); | |||
| summary_outputs = kernel_graph->summary_nodes(); | |||
| runtime_.IncreaseSummaryRefCount(summary_outputs); | |||
| } | |||
| @@ -142,6 +142,8 @@ class KernelGraph : public FuncGraph { | |||
| bool get_output_null() { return null_output_; } | |||
| void set_output_null(bool is_output_null) { null_output_ = is_output_null; } | |||
| void PrintGraphExecuteOrder() const; | |||
| std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() { return summary_nodes_; } | |||
| void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; } | |||
| private: | |||
| // remove value node form graph | |||
| @@ -175,6 +177,7 @@ class KernelGraph : public FuncGraph { | |||
| // record map between ref final output anf with index and ref origin input with index | |||
| std::map<AnfWithOutIndex, AnfWithOutIndex> ref_out_in_map_; | |||
| std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_; | |||
| std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes_; | |||
| // graph needn't execute | |||
| bool executable_; | |||
| // exist summary node in graph | |||
| @@ -745,13 +745,13 @@ void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { | |||
| (void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list)); | |||
| } | |||
| void SessionBasic::GetSummaryNodes(const KernelGraph *graph, NamedSummaryOutputs *summary) { | |||
| void SessionBasic::GetSummaryNodes(KernelGraph *graph) { | |||
| MS_LOG(DEBUG) << "Update summary Start"; | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(summary); | |||
| if (!graph->summary_node_exist()) { | |||
| return; | |||
| } | |||
| auto summary = graph->summary_nodes(); | |||
| auto apply_list = TopoSort(graph->get_return()); | |||
| for (auto &n : apply_list) { | |||
| MS_EXCEPTION_IF_NULL(n); | |||
| @@ -764,14 +764,15 @@ void SessionBasic::GetSummaryNodes(const KernelGraph *graph, NamedSummaryOutputs | |||
| } | |||
| auto node = cnode->input(kSummaryGetItem); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0); | |||
| auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); | |||
| if (!AnfAlgo::IsRealKernel(item_with_index.first)) { | |||
| MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString(); | |||
| } | |||
| (*summary)[n->fullname_with_scope()] = item_with_index; | |||
| summary[n->fullname_with_scope()] = item_with_index; | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size(); | |||
| graph->set_summary_nodes(summary); | |||
| MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); | |||
| } | |||
| void SessionBasic::Summary(KernelGraph *graph) { | |||
| @@ -779,8 +780,8 @@ void SessionBasic::Summary(KernelGraph *graph) { | |||
| return; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| NamedSummaryOutputs summary_outputs; | |||
| GetSummaryNodes(graph, &summary_outputs); | |||
| GetSummaryNodes(graph); | |||
| auto summary_outputs = graph->summary_nodes(); | |||
| // do not exist summary node | |||
| if (summary_outputs.empty()) { | |||
| return; | |||
| @@ -93,8 +93,7 @@ class SessionBasic { | |||
| virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } | |||
| virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } | |||
| virtual void SetActive(GraphId, GraphId) {} | |||
| virtual void GetSummaryNodes(const KernelGraph *graph, | |||
| std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary); | |||
| virtual void GetSummaryNodes(KernelGraph *graph); | |||
| protected: | |||
| virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||
| @@ -130,7 +129,7 @@ class SessionBasic { | |||
| }; | |||
| using SessionPtr = std::shared_ptr<session::SessionBasic>; | |||
| using NamedSummaryOutputs = std::unordered_map<std::string, std::pair<AnfNodePtr, int>>; | |||
| using NamedSummaryOutputs = std::map<std::string, std::pair<AnfNodePtr, int>>; | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H | |||