|
|
|
@@ -429,7 +429,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (IsInBlackList(prim)) { |
|
|
|
MS_LOG(INFO) << "Parallel don't care node: " << prim->name(); |
|
|
|
MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
// get_next is not in the forward graph, we need mark the get_next as the forward node |
|
|
|
@@ -1199,7 +1199,11 @@ std::vector<Shapes> ExtractShape(const CNodePtr &node) { |
|
|
|
return shape_all; |
|
|
|
} |
|
|
|
|
|
|
|
std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) { |
|
|
|
std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node, int32_t recursion_num) { |
|
|
|
if (recursion_num >= RECURSION_LIMIT) { |
|
|
|
return std::make_pair(nullptr, 0); |
|
|
|
} |
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
FuncGraphPtr func_graph = node->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
@@ -1221,8 +1225,11 @@ std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) { |
|
|
|
} |
|
|
|
if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) { |
|
|
|
return node_pair; |
|
|
|
} else if (FindParallelCareNode(node_pair.first).first != nullptr) { |
|
|
|
return FindParallelCareNode(node_pair.first); |
|
|
|
} else { |
|
|
|
auto tmp_pair = FindParallelCareNode(node_pair.first, recursion_num + 1); |
|
|
|
if (tmp_pair.first != nullptr) { |
|
|
|
return tmp_pair; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return std::make_pair(nullptr, 0); |
|
|
|
@@ -1233,7 +1240,7 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode |
|
|
|
MS_EXCEPTION_IF_NULL(parameter); |
|
|
|
FuncGraphManagerPtr manager = graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
std::pair<AnfNodePtr, int> prim_anf_node_pair = FindParallelCareNode(parameter); |
|
|
|
std::pair<AnfNodePtr, int> prim_anf_node_pair = FindParallelCareNode(parameter, 0); |
|
|
|
if (prim_anf_node_pair.first != nullptr) { |
|
|
|
return prim_anf_node_pair; |
|
|
|
} else { |
|
|
|
|