From d4a3d0fa149fee8df14954892ab21fcc2dbb9ac2 Mon Sep 17 00:00:00 2001 From: fary86 Date: Fri, 24 Jul 2020 19:45:16 +0800 Subject: [PATCH] Fix large for loop execute fail --- mindspore/ccsrc/frontend/operator/ops.h | 6 ------ .../ccsrc/frontend/operator/prim_statement.cc | 5 ----- .../ccsrc/pipeline/jit/parse/function_block.cc | 7 +------ mindspore/ccsrc/pipeline/jit/parse/parse.cc | 17 ++++++++++++----- mindspore/ccsrc/pipeline/jit/parse/parse.h | 4 ++++ mindspore/core/ir/func_graph_py.cc | 1 - mindspore/core/ir/meta_func_graph.h | 1 - 7 files changed, 17 insertions(+), 24 deletions(-) diff --git a/mindspore/ccsrc/frontend/operator/ops.h b/mindspore/ccsrc/frontend/operator/ops.h index 5b73daba11..9f93159a2a 100755 --- a/mindspore/ccsrc/frontend/operator/ops.h +++ b/mindspore/ccsrc/frontend/operator/ops.h @@ -185,12 +185,6 @@ inline const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared("SparseTensorGetIndices"); inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared("SparseTensorGetDenseShape"); -// attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll -const char SWITCH_UNROLL_FLAG[] = "unroll_flag"; -// max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it -// will be sunk(i.e. 
not unrolled) -const int MAX_FOR_LOOP_COUNT = 600; - class UnpackGraphPrimitive : public Primitive { public: explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args) diff --git a/mindspore/ccsrc/frontend/operator/prim_statement.cc b/mindspore/ccsrc/frontend/operator/prim_statement.cc index bb421bdf8a..e193ff1dab 100644 --- a/mindspore/ccsrc/frontend/operator/prim_statement.cc +++ b/mindspore/ccsrc/frontend/operator/prim_statement.cc @@ -108,11 +108,6 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &p auto fb = args_spec_list[2]; MS_EXCEPTION_IF_NULL(cond); - auto unroll_flag = prim->GetAttr(prim::SWITCH_UNROLL_FLAG); - if (unroll_flag != nullptr && GetValue(unroll_flag) == 0) { - return tb->Join(fb); - } - ValuePtr v = cond->GetValueTrack(); MS_EXCEPTION_IF_NULL(v); // for tensor as condition, keeps both true and false branch. diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc index 9ccaa46fed..52bc1e0588 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc @@ -298,13 +298,8 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr MS_LOG(EXCEPTION) << "Failure: have return node! 
NodeInfo: " << trace::GetDebugInfo(func_graph()->get_return()->debug_info()); } - // Here we need set an attribute to primtive 'switch', so we create a new variable instead of global 'kPrimSwitch' - auto prim_switch = std::make_shared(prim::kPrimSwitch->name()); - if (!unroll_loop) { - prim_switch->AddAttr(prim::SWITCH_UNROLL_FLAG, MakeValue(0)); - } CNodePtr switch_app = - func_graph()->NewCNode({NewValueNode(prim_switch), condNode, NewValueNode(true_block->func_graph()), + func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), condNode, NewValueNode(true_block->func_graph()), NewValueNode(false_block->func_graph())}); CNodePtr switch_app_new = func_graph()->NewCNode({switch_app}); func_graph()->set_output(switch_app_new); diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index 510b4cbd24..ca27824d09 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -1061,13 +1061,13 @@ FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) { FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast For, create an if else statement"; MS_EXCEPTION_IF_NULL(block); - // create statement 'len(xs) < prim::MAX_FOR_LOOP_COUNT' + // create statement 'len(xs) < MAX_FOR_LOOP_COUNT' AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN); py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER); AnfNodePtr iter_node = ParseExprNode(block, iter_obj); CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node}); - CNodePtr bool_node = block->func_graph()->NewCNode( - {NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(prim::MAX_FOR_LOOP_COUNT)}); + CNodePtr bool_node = + block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(MAX_FOR_LOOP_COUNT)}); // create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then 
ParseForIter else ParseForLoop' TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); @@ -1191,7 +1191,12 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter"); AnfNodePtr iter_node = ParseExprNode(block, iter_obj); MS_EXCEPTION_IF_NULL(iter_node); - CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node}); + // Generate a node for the loop count and convert it to a tensor, so that the loop will not be unrolled + CNodePtr scalar_len = block->func_graph()->NewCNode({op_len, iter_node}); + auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations"); + auto scalar_to_tensor_node = block->func_graph()->NewCNode({NewValueNode(scalar_to_tensor)}); + + CNodePtr len_iter = block->func_graph()->NewCNode({scalar_to_tensor_node, scalar_len}); FunctionBlockPtr header_block = GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info())); @@ -1199,7 +1204,9 @@ // create loop variable 'i' ParameterPtr loop_var = header_block->func_graph()->add_parameter(); // create loop condition 'i < len(xs)' - CNodePtr cond_node = header_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), loop_var, len_iter}); + auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations"); + auto less_node = header_block->func_graph()->NewCNode({NewValueNode(prim_less)}); + CNodePtr cond_node = header_block->func_graph()->NewCNode({less_node, loop_var, len_iter}); // generate the body of the for statement FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info())); diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.h b/mindspore/ccsrc/pipeline/jit/parse/parse.h index dc0b43c3a6..afb72ba5c9 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.h @@ -48,6 +48,10 
@@ enum ParseStatusCode : int { PARSE_FAILURE = 0xFF }; +// max loop count of for statement, when loop count is less than this value, the for loop will be unrolled, otherwise it +// will be sunk (i.e. not unrolled) +const int MAX_FOR_LOOP_COUNT = 600; + class AstNodeType; class ParseAst; diff --git a/mindspore/core/ir/func_graph_py.cc b/mindspore/core/ir/func_graph_py.cc index f6bb419a10..cdddb7b08d 100644 --- a/mindspore/core/ir/func_graph_py.cc +++ b/mindspore/core/ir/func_graph_py.cc @@ -24,7 +24,6 @@ namespace mindspore { REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) { // Define python "MetaFuncGraph_" class (void)py::class_>(*m, "MetaFuncGraph_") - // .def_readonly(PYTHON_METAFUNCGRAPH_FLAG, &MetaFuncGraph::parse_info_) .def(py::init()); // Define python "FuncGraph" class (void)py::class_(*m, "FuncGraph") diff --git a/mindspore/core/ir/meta_func_graph.h b/mindspore/core/ir/meta_func_graph.h index 3193c33926..05743bd6a4 100644 --- a/mindspore/core/ir/meta_func_graph.h +++ b/mindspore/core/ir/meta_func_graph.h @@ -72,7 +72,6 @@ class MetaFuncGraph : public FuncGraphBase { return false; } } - // const bool parse_info_ = true; protected: template