Browse Source

!2383 handle summary to adapt new control sink

Merge pull request !2383 from Margaret_wangrui/control_sink_summary
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
aad494a6fc
2 changed files with 19 additions and 5 deletions
  1. +17
    -4
      mindspore/ccsrc/session/ascend_session.cc
  2. +2
    -1
      mindspore/ccsrc/session/ascend_session.h

+ 17
- 4
mindspore/ccsrc/session/ascend_session.cc View File

@@ -751,17 +751,19 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
return final_graph_id_;
}

void AscendSession::GetSummaryNodes(KernelGraph *graph) {
MS_LOG(DEBUG) << "Update summary Start";
void AscendSession::RecurseGetSummaryNodes(KernelGraph *graph,
std::map<std::string, std::pair<AnfNodePtr, int>> *summary) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(summary);
// if final graph have no child graph
auto graph_order_iter = graph_execute_orders_.find(graph->graph_id());
if (graph_order_iter == graph_execute_orders_.end()) {
SessionBasic::GetSummaryNodes(graph);
auto summary_nodes = graph->summary_nodes();
(*summary).insert(summary_nodes.begin(), summary_nodes.end());
return;
}
// for every child graph, find summary nodes
auto summary = graph->summary_nodes();
auto graph_order = GetGraphOrder(graph->graph_id());
for (size_t i = 0; i < graph_order.size(); i++) {
auto child_graph = GetGraph(graph_order[i]);
@@ -770,8 +772,19 @@ void AscendSession::GetSummaryNodes(KernelGraph *graph) {
}
SessionBasic::GetSummaryNodes(child_graph.get());
auto child_graph_summary = child_graph->summary_nodes();
summary.insert(child_graph_summary.begin(), child_graph_summary.end());
(*summary).insert(child_graph_summary.begin(), child_graph_summary.end());
RecurseGetSummaryNodes(child_graph.get(), summary);
}
graph->set_summary_nodes(*summary);
}

void AscendSession::GetSummaryNodes(KernelGraph *graph) {
MS_LOG(DEBUG) << "Update summary Start";
MS_EXCEPTION_IF_NULL(graph);
auto summary_nodes = graph->summary_nodes();
std::map<std::string, std::pair<AnfNodePtr, int>> summary;
summary.insert(summary_nodes.begin(), summary_nodes.end());
RecurseGetSummaryNodes(graph, &summary);
graph->set_summary_nodes(summary);
MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
}


+ 2
- 1
mindspore/ccsrc/session/ascend_session.h View File

@@ -67,7 +67,8 @@ class AscendSession : public SessionBasic {
void SetActive(GraphId, GraphId) override;
// compile child graph when session have multiple child graphs
void CompileChildGraph(const KernelGraphPtr &child_graph);
void GetSummaryNodes(KernelGraph *graph) override;
void RecurseGetSummaryNodes(KernelGraph *graph, std::map<std::string, std::pair<AnfNodePtr, int>> *summary);
void GetSummaryNodes(KernelGraph *graph);

private:
void InitRuntimeResource();


Loading…
Cancel
Save