|
|
|
@@ -65,33 +65,57 @@ AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr & |
|
|
|
return context; |
|
|
|
} |
|
|
|
|
|
|
|
static std::vector<AnfNodePtr> FastShadowSort(const AnfNodePtr &ret_node) { |
|
|
|
auto current_func_graph = ret_node->func_graph(); |
|
|
|
// Return CNodes set that may contain duplicates from a DAG function graph. |
|
|
|
// Should check the results' second item as ignored flag before use them, to avoid processing repeatedly. |
|
|
|
static inline std::vector<std::pair<AnfNodePtr, bool>> SortCNodesContainDup(const AnfNodePtr &root_node) { |
|
|
|
auto current_func_graph = root_node->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(current_func_graph); |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> sorted_nodes; |
|
|
|
auto seen = NewSeenGeneration(); |
|
|
|
std::vector<std::pair<AnfNodePtr, bool>> sorted_nodes; // Record {node, ignored_flag}. |
|
|
|
std::unordered_map<AnfNodePtr, size_t> checked_cnodes; // Record {node, position_in_sorted_nodes} |
|
|
|
std::size_t index = 0; |
|
|
|
sorted_nodes.emplace_back(ret_node); |
|
|
|
sorted_nodes.emplace_back(std::pair(root_node, false)); |
|
|
|
while (index < sorted_nodes.size()) { |
|
|
|
auto current = sorted_nodes[index]; |
|
|
|
index++; |
|
|
|
auto current = sorted_nodes[index].first; |
|
|
|
MS_EXCEPTION_IF_NULL(current); |
|
|
|
if (current->isa<CNode>()) { |
|
|
|
auto ignored_flag = sorted_nodes[index].second; |
|
|
|
if (!ignored_flag && current->isa<CNode>()) { |
|
|
|
auto &inputs = current->cast<CNodePtr>()->inputs(); |
|
|
|
for (auto it = inputs.begin(); it != inputs.end(); it++) { |
|
|
|
for (auto it = inputs.crbegin(); it != inputs.crend(); it++) { |
|
|
|
AnfNodePtr input = *it; |
|
|
|
if (input != nullptr && input->isa<CNode>() && input->seen_ != seen && |
|
|
|
input->func_graph() == current_func_graph) { |
|
|
|
sorted_nodes.emplace_back(input); |
|
|
|
input->seen_ = seen; |
|
|
|
if (input == nullptr || !input->isa<CNode>() || input->func_graph() != current_func_graph) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto checked_item = checked_cnodes.find(input); |
|
|
|
if (checked_item == checked_cnodes.end()) { // Not met before. |
|
|
|
checked_cnodes.insert({input, sorted_nodes.size()}); |
|
|
|
sorted_nodes.emplace_back(std::pair(input, false)); |
|
|
|
} else { // Checked, should update flag and new position. |
|
|
|
auto pos = checked_item->second; |
|
|
|
sorted_nodes[pos].second = true; // Set ignore flag. |
|
|
|
checked_cnodes[input] = sorted_nodes.size(); // Update a new position. |
|
|
|
sorted_nodes.emplace_back(std::pair(input, false)); // Insert duplicate node into new position. |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
index++; |
|
|
|
} |
|
|
|
return sorted_nodes; |
|
|
|
} |
|
|
|
|
|
|
|
// Return CNodes set that root at the bottom. |
|
|
|
static inline std::vector<AnfNodePtr> SortReverseCNodes(const AnfNodePtr &root_node) { |
|
|
|
std::vector<AnfNodePtr> res; |
|
|
|
auto nodes_with_flag = SortCNodesContainDup(root_node); |
|
|
|
for (auto it = nodes_with_flag.crbegin(); it != nodes_with_flag.crend(); it++) { |
|
|
|
if (it->second) { // Check ignored flag. |
|
|
|
continue; |
|
|
|
} |
|
|
|
res.emplace_back(it->first); |
|
|
|
} |
|
|
|
return res; |
|
|
|
} |
|
|
|
|
|
|
|
EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { |
|
|
|
FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list); |
|
|
|
MS_EXCEPTION_IF_NULL(fg); |
|
|
|
@@ -123,9 +147,12 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr |
|
|
|
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH) |
|
|
|
<< ", please call 'context.set_context(max_call_depth=value)' to adjust this value."; |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> nodes = FastShadowSort(func_node); |
|
|
|
for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { |
|
|
|
const auto &node = *it; |
|
|
|
auto nodes_with_flag = SortCNodesContainDup(func_node); |
|
|
|
for (auto it = nodes_with_flag.crbegin(); it != nodes_with_flag.crend(); it++) { |
|
|
|
if (it->second) { // Check ignored flag. |
|
|
|
continue; |
|
|
|
} |
|
|
|
const auto &node = it->first; |
|
|
|
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); |
|
|
|
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << fg->ToString() |
|
|
|
<< ", node_conf: " << node_conf->ToString(); |
|
|
|
|