Browse Source

Remove the repeats of inferring and optimize the sorting routine.

Total Renormalizes:
-----
69.05010 --> 62.28941
-----
tags/v0.2.0-alpha
Zhang Qinghua 5 years ago
parent
commit
87714b3c7f
1 changed files with 33 additions and 9 deletions
  1. +33
    -9
      mindspore/ccsrc/pipeline/static_analysis/evaluator.cc

+ 33
- 9
mindspore/ccsrc/pipeline/static_analysis/evaluator.cc View File

@@ -17,6 +17,7 @@
#include "pipeline/static_analysis/evaluator.h"

#include <algorithm>
#include <unordered_set>

#include "ir/func_graph_cloner.h"
#include "pipeline/static_analysis/utils.h"
@@ -61,6 +62,29 @@ AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &
return context;
}

static std::vector<AnfNodePtr> FastShadowSort(const AnfNodePtr &ret_node) {
std::vector<AnfNodePtr> sorted_nodes;
std::unordered_set<AnfNodePtr> checked_cnodes;
std::size_t index = 0;
sorted_nodes.emplace_back(ret_node);
while (index < sorted_nodes.size()) {
auto current = sorted_nodes[index];
index++;
MS_EXCEPTION_IF_NULL(current);
if (current->isa<CNode>()) {
auto &inputs = current->cast<CNodePtr>()->inputs();
for (auto it = inputs.begin(); it != inputs.end(); it++) {
AnfNodePtr input = *it;
if (input != nullptr && input->isa<CNode>() && checked_cnodes.find(input) == checked_cnodes.end()) {
sorted_nodes.emplace_back(input);
(void)checked_cnodes.insert(input);
}
}
}
}
return sorted_nodes;
}

AbstractBasePtr BaseFuncGraphEvaluator::Infer(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list);
MS_EXCEPTION_IF_NULL(fg);
@@ -86,20 +110,20 @@ AbstractBasePtr BaseFuncGraphEvaluator::Infer(AnalysisEnginePtr engine, const Ab

MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg->ToString()
<< ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString();
const std::vector<AnfNodePtr> &all_nodes = TopoSort(func_node);
for (const auto &node : all_nodes) {
AbstractBasePtr ret_base = nullptr;
std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
const auto &node = *it;
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString();
AbstractBasePtr base = engine->GetEvaluatedValue(node_conf);
ret_base = engine->GetEvaluatedValue(node_conf);
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString()
<< ", abstract: " << base->ToString();
<< ", abstract: " << ret_base->ToString();
}

AnfNodeConfigPtr ret_conf = engine->MakeConfig(func_node, graph_context_);
AbstractBasePtr base = engine->GetEvaluatedValue(ret_conf);
MS_EXCEPTION_IF_NULL(base);
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " infer end, inferred abstract: " << base->ToString();
return base;
MS_EXCEPTION_IF_NULL(ret_base);
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " infer end, inferred abstract: " << ret_base->ToString();
return ret_base;
}

AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {


Loading…
Cancel
Save