From 87714b3c7f5e334a69e38fce0525b7e9e1d5df4e Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Fri, 10 Apr 2020 13:08:38 +0800 Subject: [PATCH] Remove the repeats of inferring and optimize the sorting routine. Total Renormalizes: ----- 69.05010 --> 62.28941 ----- --- .../pipeline/static_analysis/evaluator.cc | 42 +++++++++++++++---- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc index 9b120f731c..99cb893104 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc @@ -17,6 +17,7 @@ #include "pipeline/static_analysis/evaluator.h" #include +#include #include "ir/func_graph_cloner.h" #include "pipeline/static_analysis/utils.h" @@ -61,6 +62,29 @@ AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr & return context; } +static std::vector FastShadowSort(const AnfNodePtr &ret_node) { + std::vector sorted_nodes; + std::unordered_set checked_cnodes; + std::size_t index = 0; + sorted_nodes.emplace_back(ret_node); + while (index < sorted_nodes.size()) { + auto current = sorted_nodes[index]; + index++; + MS_EXCEPTION_IF_NULL(current); + if (current->isa()) { + auto &inputs = current->cast()->inputs(); + for (auto it = inputs.begin(); it != inputs.end(); it++) { + AnfNodePtr input = *it; + if (input != nullptr && input->isa() && checked_cnodes.find(input) == checked_cnodes.end()) { + sorted_nodes.emplace_back(input); + (void)checked_cnodes.insert(input); + } + } + } + } + return sorted_nodes; +} + AbstractBasePtr BaseFuncGraphEvaluator::Infer(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list); MS_EXCEPTION_IF_NULL(fg); @@ -86,20 +110,20 @@ AbstractBasePtr BaseFuncGraphEvaluator::Infer(AnalysisEnginePtr engine, const Ab MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg->ToString() << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString(); - const std::vector &all_nodes = TopoSort(func_node); - for (const auto &node : all_nodes) { + AbstractBasePtr ret_base = nullptr; + std::vector nodes = FastShadowSort(func_node); + for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { + const auto &node = *it; AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString(); - AbstractBasePtr base = engine->GetEvaluatedValue(node_conf); + ret_base = engine->GetEvaluatedValue(node_conf); MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString() - << ", abstract: " << base->ToString(); + << ", abstract: " << ret_base->ToString(); } - AnfNodeConfigPtr ret_conf = engine->MakeConfig(func_node, graph_context_); - AbstractBasePtr base = engine->GetEvaluatedValue(ret_conf); - MS_EXCEPTION_IF_NULL(base); - MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " infer end, inferred abstract: " << base->ToString(); - return base; + MS_EXCEPTION_IF_NULL(ret_base); + MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " infer end, inferred abstract: " << ret_base->ToString(); + return ret_base; } AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {