Browse Source

add recursion limit

tags/v1.1.0
yangzhenzhang 5 years ago
parent
commit
92d02b7aff
2 changed files with 13 additions and 7 deletions
  1. +12
    -5
      mindspore/ccsrc/frontend/parallel/step_parallel.cc
  2. +1
    -2
      mindspore/ccsrc/frontend/parallel/step_parallel.h

+ 12
- 5
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -429,7 +429,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) {
return false; return false;
} }
if (IsInBlackList(prim)) { if (IsInBlackList(prim)) {
MS_LOG(INFO) << "Parallel don't care node: " << prim->name();
MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name();
return false; return false;
} }
// get_next is not in the forward graph, we need mark the get_next as the forward node // get_next is not in the forward graph, we need mark the get_next as the forward node
@@ -1199,7 +1199,11 @@ std::vector<Shapes> ExtractShape(const CNodePtr &node) {
return shape_all; return shape_all;
} }


std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) {
std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node, int32_t recursion_num) {
if (recursion_num >= RECURSION_LIMIT) {
return std::make_pair(nullptr, 0);
}

MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
FuncGraphPtr func_graph = node->func_graph(); FuncGraphPtr func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
@@ -1221,8 +1225,11 @@ std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) {
} }
if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) { if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
return node_pair; return node_pair;
} else if (FindParallelCareNode(node_pair.first).first != nullptr) {
return FindParallelCareNode(node_pair.first);
} else {
auto tmp_pair = FindParallelCareNode(node_pair.first, recursion_num + 1);
if (tmp_pair.first != nullptr) {
return tmp_pair;
}
} }
} }
return std::make_pair(nullptr, 0); return std::make_pair(nullptr, 0);
@@ -1233,7 +1240,7 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode
MS_EXCEPTION_IF_NULL(parameter); MS_EXCEPTION_IF_NULL(parameter);
FuncGraphManagerPtr manager = graph->manager(); FuncGraphManagerPtr manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
std::pair<AnfNodePtr, int> prim_anf_node_pair = FindParallelCareNode(parameter);
std::pair<AnfNodePtr, int> prim_anf_node_pair = FindParallelCareNode(parameter, 0);
if (prim_anf_node_pair.first != nullptr) { if (prim_anf_node_pair.first != nullptr) {
return prim_anf_node_pair; return prim_anf_node_pair;
} else { } else {


+ 1
- 2
mindspore/ccsrc/frontend/parallel/step_parallel.h View File

@@ -36,6 +36,7 @@ using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
const uint64_t kUSecondInSecond = 1000000; const uint64_t kUSecondInSecond = 1000000;
const int32_t RECURSION_LIMIT = 3;


struct LossNodeInfo { struct LossNodeInfo {
bool has_tuple_getitem = false; bool has_tuple_getitem = false;
@@ -104,8 +105,6 @@ std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const
// Extract shape from anfnode // Extract shape from anfnode
std::vector<Shapes> ExtractShape(const CNodePtr &node); std::vector<Shapes> ExtractShape(const CNodePtr &node);


std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node);

// Find finally sub graph // Find finally sub graph
std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &parameter); std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &parameter);




Loading…
Cancel
Save