Browse Source

Improve performance of finding summary nodes

tags/v0.5.0-beta
Margaret_wangrui 5 years ago
parent
commit
285f225eca
2 changed files with 29 additions and 0 deletions
  1. +7
    -0
      mindspore/ccsrc/session/kernel_graph.h
  2. +22
    -0
      mindspore/ccsrc/session/session_basic.cc

+ 7
- 0
mindspore/ccsrc/session/kernel_graph.h View File

@@ -40,6 +40,7 @@ class KernelGraph : public FuncGraph {
inputs_ = std::make_shared<std::vector<AnfNodePtr>>();
execution_order_ = {};
executable_ = true;
summary_node_exist_ = false;
stream_distinction_label_ = kInvalidDistincLabel;
}
~KernelGraph() override;
@@ -90,6 +91,10 @@ class KernelGraph : public FuncGraph {
bool executable() const { return executable_; }
// set executable of graph
void set_executable(bool executable) { executable_ = executable; }
// set summary_node of graph
void set_summary_node_exist(bool summary_node_exist) { summary_node_exist_ = summary_node_exist; }
// check whether exist summary node in graph
bool summary_node_exist() const { return summary_node_exist_; }
// set invalid inputs for control sink
std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
std::vector<bool> valid_inputs() const { return valid_inputs_; }
@@ -172,6 +177,8 @@ class KernelGraph : public FuncGraph {
std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_;
// graph needn't execute
bool executable_;
// exist summary node in graph
bool summary_node_exist_;
// valid inputs
std::vector<bool> valid_inputs_;



+ 22
- 0
mindspore/ccsrc/session/session_basic.cc View File

@@ -291,6 +291,19 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
(void)tab_str.append(any.ToString());
MS_LOG(INFO) << tab_str;
}

bool ExistSummaryNode(const KernelGraph *graph) {
auto ret = graph->get_return();
MS_EXCEPTION_IF_NULL(ret);
auto all_nodes = DeepLinkedGraphSearch(ret);
for (auto &n : all_nodes) {
if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
return true;
}
}
return false;
}
} // namespace

GraphId SessionBasic::graph_sum_ = 0;
@@ -537,6 +550,9 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
graph->set_manager(manager);
}
graph->SetExecOrderByDefault();
if (ExistSummaryNode(graph.get())) {
graph->set_summary_node_exist(true);
}
opt::BackendCommonOptimization(graph);
return graph;
}
@@ -594,6 +610,9 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
graph->set_manager(manager);
}
graph->SetExecOrderByDefault();
if (ExistSummaryNode(graph.get())) {
graph->set_summary_node_exist(true);
}
return graph;
}

@@ -716,6 +735,9 @@ void SessionBasic::GetSummaryNodes(const KernelGraph *graph,
MS_LOG(DEBUG) << "Update summary Start";
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(summary);
if (!graph->summary_node_exist()) {
return;
}
auto apply_list = TopoSort(graph->get_return());
for (auto &n : apply_list) {
MS_EXCEPTION_IF_NULL(n);


Loading…
Cancel
Save