From c2bca5bf4fff126443943a45d00a600d8d4677cb Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Mon, 25 May 2020 11:44:24 +0800 Subject: [PATCH] fix summary nodes in child graph --- mindspore/ccsrc/session/ascend_session.cc | 21 +++++++ mindspore/ccsrc/session/ascend_session.h | 2 + mindspore/ccsrc/session/session_basic.cc | 75 ++++++++++------------- mindspore/ccsrc/session/session_basic.h | 2 + 4 files changed, 56 insertions(+), 44 deletions(-) diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index 9fe9fc9f4b..9b149b1c8c 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -725,6 +725,27 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector &args) { return final_graph_id_; } +void AscendSession::GetSummaryNodes(const KernelGraph *graph, + std::unordered_map> *summary) { + MS_LOG(DEBUG) << "Update summary Start"; + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(summary); + summary->clear(); + // if final graph have no child graph + auto graph_order_iter = graph_execute_orders_.find(graph->graph_id()); + if (graph_order_iter == graph_execute_orders_.end()) { + SessionBasic::GetSummaryNodes(graph, summary); + return; + } + // for every child graph, find summary nodes + auto graph_order = GetGraphOrder(graph->graph_id()); + for (size_t i = 0; i < graph_order.size(); i++) { + auto child_graph = GetGraph(graph_order[i]); + SessionBasic::GetSummaryNodes(child_graph.get(), summary); + } + MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size(); +} + AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) { auto fake_graph = GetGraph(fake_graph_id); auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0); diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h index d8b60cf3b3..a9824e680b 100755 --- a/mindspore/ccsrc/session/ascend_session.h +++ b/mindspore/ccsrc/session/ascend_session.h @@ -67,6 +67,8 @@ class AscendSession : public SessionBasic { void SetActive(GraphId, GraphId) override; // compile child graph when session have multiple child graphs void CompileChildGraph(const KernelGraphPtr &child_graph); + void GetSummaryNodes(const KernelGraph *graph, + std::unordered_map> *summary) override; private: void InitRuntimeResource(); diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index db6257c815..5fc8b16c08 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -54,46 +54,6 @@ PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) { return py_param.ptr(); } -void GetSummaryNodes(const KernelGraph *graph, std::unordered_map> *summary) { - MS_LOG(DEBUG) << "Update summary Start"; - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(summary); - summary->clear(); - auto apply_list = TopoSort(graph->get_return()); - for (auto &n : apply_list) { - MS_EXCEPTION_IF_NULL(n); - if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) || - IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) { - auto cnode = n->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() <= kSummaryGetItem) { - MS_LOG(EXCEPTION) << "the node Summary should have 2 inputs at least!"; - } - auto node = cnode->input(kSummaryGetItem); - MS_EXCEPTION_IF_NULL(node); - auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0); - if (!AnfAlgo::IsRealKernel(item_with_index.first)) { - MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString(); - } - (*summary)[n->fullname_with_scope()] = item_with_index; - } - } - MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size(); -} - -bool ExistSummaryNode(const KernelGraph *graph) { - auto ret = graph->get_return(); - MS_EXCEPTION_IF_NULL(ret); - auto all_nodes = DeepLinkedGraphSearch(ret); - for (auto &n : all_nodes) { - if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) || - IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) { - return true; - } - } - return false; -} - BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph, const std::vector &input_tensors) { MS_EXCEPTION_IF_NULL(node); @@ -751,17 +711,44 @@ void SessionBasic::Reorder(std::vector *node_list) { (void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list)); } +void SessionBasic::GetSummaryNodes(const KernelGraph *graph, + std::unordered_map> *summary) { + MS_LOG(DEBUG) << "Update summary Start"; + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(summary); + auto apply_list = TopoSort(graph->get_return()); + for (auto &n : apply_list) { + MS_EXCEPTION_IF_NULL(n); + if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) || + IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) { + auto cnode = n->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() <= kSummaryGetItem) { + MS_LOG(EXCEPTION) << "the node Summary should have 2 inputs at least!"; + } + auto node = cnode->input(kSummaryGetItem); + MS_EXCEPTION_IF_NULL(node); + auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0); + if (!AnfAlgo::IsRealKernel(item_with_index.first)) { + MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString(); + } + (*summary)[n->fullname_with_scope()] = item_with_index; + } + } + MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size(); +} + void SessionBasic::Summary(KernelGraph *graph) { if (summary_callback_ == nullptr) { return; } MS_EXCEPTION_IF_NULL(graph); - bool exist_summary = ExistSummaryNode(graph); - if (!exist_summary) { - return; - } std::unordered_map> summary_outputs; GetSummaryNodes(graph, &summary_outputs); + // do not exist summary node + if (summary_outputs.empty()) { + return; + } std::map params_list; // fetch outputs apply kernel in session & run callback functions for (auto &output_item : summary_outputs) { diff --git a/mindspore/ccsrc/session/session_basic.h b/mindspore/ccsrc/session/session_basic.h index 2719c9b67d..379a4e96b2 100755 --- a/mindspore/ccsrc/session/session_basic.h +++ b/mindspore/ccsrc/session/session_basic.h @@ -92,6 +92,8 @@ class SessionBasic { virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } virtual void SetActive(GraphId, GraphId) {} + virtual void GetSummaryNodes(const KernelGraph *graph, + std::unordered_map> *summary); protected: virtual void LoadInputData(const std::shared_ptr &kernel_graph,