| @@ -66,57 +66,6 @@ AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr & | |||||
| return context; | return context; | ||||
| } | } | ||||
| // Return CNodes set that may contain duplicates from a DAG function graph. | |||||
| // Should check the results' second item as ignored flag before use them, to avoid processing repeatedly. | |||||
| static inline std::vector<std::pair<AnfNodePtr, bool>> SortCNodesContainDup(const AnfNodePtr &root_node) { | |||||
| auto current_func_graph = root_node->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(current_func_graph); | |||||
| std::vector<std::pair<AnfNodePtr, bool>> sorted_nodes; // Record {node, ignored_flag}. | |||||
| std::unordered_map<AnfNodePtr, size_t> checked_cnodes; // Record {node, position_in_sorted_nodes} | |||||
| std::size_t index = 0; | |||||
| sorted_nodes.emplace_back(std::pair(root_node, false)); | |||||
| while (index < sorted_nodes.size()) { | |||||
| auto current = sorted_nodes[index].first; | |||||
| MS_EXCEPTION_IF_NULL(current); | |||||
| auto ignored_flag = sorted_nodes[index].second; | |||||
| if (!ignored_flag && current->isa<CNode>()) { | |||||
| auto &inputs = current->cast<CNodePtr>()->inputs(); | |||||
| for (auto it = inputs.crbegin(); it != inputs.crend(); it++) { | |||||
| AnfNodePtr input = *it; | |||||
| if (input == nullptr || !input->isa<CNode>() || input->func_graph() != current_func_graph) { | |||||
| continue; | |||||
| } | |||||
| auto checked_item = checked_cnodes.find(input); | |||||
| if (checked_item == checked_cnodes.end()) { // Not met before. | |||||
| checked_cnodes.insert({input, sorted_nodes.size()}); | |||||
| sorted_nodes.emplace_back(std::pair(input, false)); | |||||
| } else { // Checked, should update flag and new position. | |||||
| auto pos = checked_item->second; | |||||
| sorted_nodes[pos].second = true; // Set ignore flag. | |||||
| checked_cnodes[input] = sorted_nodes.size(); // Update a new position. | |||||
| sorted_nodes.emplace_back(std::pair(input, false)); // Insert duplicate node into new position. | |||||
| } | |||||
| } | |||||
| } | |||||
| index++; | |||||
| } | |||||
| return sorted_nodes; | |||||
| } | |||||
| // Return CNodes set that root at the bottom. | |||||
| static inline std::vector<AnfNodePtr> SortReverseCNodes(const AnfNodePtr &root_node) { | |||||
| std::vector<AnfNodePtr> res; | |||||
| auto nodes_with_flag = SortCNodesContainDup(root_node); | |||||
| for (auto it = nodes_with_flag.crbegin(); it != nodes_with_flag.crend(); it++) { | |||||
| if (it->second) { // Check ignored flag. | |||||
| continue; | |||||
| } | |||||
| res.emplace_back(it->first); | |||||
| } | |||||
| return res; | |||||
| } | |||||
| EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { | EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { | ||||
| FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list); | FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list); | ||||
| MS_EXCEPTION_IF_NULL(fg); | MS_EXCEPTION_IF_NULL(fg); | ||||
| @@ -148,12 +97,8 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr | |||||
| << MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH) | << MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH) | ||||
| << ", please call 'context.set_context(max_call_depth=value)' to adjust this value."; | << ", please call 'context.set_context(max_call_depth=value)' to adjust this value."; | ||||
| } | } | ||||
| auto nodes_with_flag = SortCNodesContainDup(func_node); | |||||
| for (auto it = nodes_with_flag.crbegin(); it != nodes_with_flag.crend(); it++) { | |||||
| if (it->second) { // Check ignored flag. | |||||
| continue; | |||||
| } | |||||
| const auto &node = it->first; | |||||
| const std::vector<AnfNodePtr> &all_nodes = TopoSort(func_node); | |||||
| for (const auto &node : all_nodes) { | |||||
| AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); | AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); | ||||
| MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << fg->ToString() | MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << fg->ToString() | ||||
| << ", node_conf: " << node_conf->ToString(); | << ", node_conf: " << node_conf->ToString(); | ||||