Merge pull request !3450 from fary86/fix_large_for_loop_execute_errortags/v0.7.0-beta
| @@ -185,12 +185,6 @@ inline const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitiv | |||||
| inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices"); | inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices"); | ||||
| inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape"); | inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape"); | ||||
| // attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll | |||||
| const char SWITCH_UNROLL_FLAG[] = "unroll_flag"; | |||||
| // max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it | |||||
| // will be sunk(i.e. not unrolled) | |||||
| const int MAX_FOR_LOOP_COUNT = 600; | |||||
| class UnpackGraphPrimitive : public Primitive { | class UnpackGraphPrimitive : public Primitive { | ||||
| public: | public: | ||||
| explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args) | explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args) | ||||
| @@ -108,11 +108,6 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &p | |||||
| auto fb = args_spec_list[2]; | auto fb = args_spec_list[2]; | ||||
| MS_EXCEPTION_IF_NULL(cond); | MS_EXCEPTION_IF_NULL(cond); | ||||
| auto unroll_flag = prim->GetAttr(prim::SWITCH_UNROLL_FLAG); | |||||
| if (unroll_flag != nullptr && GetValue<int>(unroll_flag) == 0) { | |||||
| return tb->Join(fb); | |||||
| } | |||||
| ValuePtr v = cond->GetValueTrack(); | ValuePtr v = cond->GetValueTrack(); | ||||
| MS_EXCEPTION_IF_NULL(v); | MS_EXCEPTION_IF_NULL(v); | ||||
| // for tensor as condition, keeps both true and false branch. | // for tensor as condition, keeps both true and false branch. | ||||
| @@ -298,13 +298,8 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr | |||||
| MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: " | MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: " | ||||
| << trace::GetDebugInfo(func_graph()->get_return()->debug_info()); | << trace::GetDebugInfo(func_graph()->get_return()->debug_info()); | ||||
| } | } | ||||
| // Here we need set an attribute to primtive 'switch', so we create a new variable instead of global 'kPrimSwitch' | |||||
| auto prim_switch = std::make_shared<Primitive>(prim::kPrimSwitch->name()); | |||||
| if (!unroll_loop) { | |||||
| prim_switch->AddAttr(prim::SWITCH_UNROLL_FLAG, MakeValue(0)); | |||||
| } | |||||
| CNodePtr switch_app = | CNodePtr switch_app = | ||||
| func_graph()->NewCNode({NewValueNode(prim_switch), condNode, NewValueNode(true_block->func_graph()), | |||||
| func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), condNode, NewValueNode(true_block->func_graph()), | |||||
| NewValueNode(false_block->func_graph())}); | NewValueNode(false_block->func_graph())}); | ||||
| CNodePtr switch_app_new = func_graph()->NewCNode({switch_app}); | CNodePtr switch_app_new = func_graph()->NewCNode({switch_app}); | ||||
| func_graph()->set_output(switch_app_new); | func_graph()->set_output(switch_app_new); | ||||
| @@ -1061,13 +1061,13 @@ FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) { | |||||
| FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) { | FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) { | ||||
| MS_LOG(DEBUG) << "Process ast For, create an if else statement"; | MS_LOG(DEBUG) << "Process ast For, create an if else statement"; | ||||
| MS_EXCEPTION_IF_NULL(block); | MS_EXCEPTION_IF_NULL(block); | ||||
| // create statement 'len(xs) < prim::MAX_FOR_LOOP_COUNT' | |||||
| // create statement 'len(xs) < MAX_FOR_LOOP_COUNT' | |||||
| AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN); | AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN); | ||||
| py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER); | py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER); | ||||
| AnfNodePtr iter_node = ParseExprNode(block, iter_obj); | AnfNodePtr iter_node = ParseExprNode(block, iter_obj); | ||||
| CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node}); | CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node}); | ||||
| CNodePtr bool_node = block->func_graph()->NewCNode( | |||||
| {NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(prim::MAX_FOR_LOOP_COUNT)}); | |||||
| CNodePtr bool_node = | |||||
| block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(MAX_FOR_LOOP_COUNT)}); | |||||
| // create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop' | // create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop' | ||||
| TraceManager::DebugTrace(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info())); | TraceManager::DebugTrace(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info())); | ||||
| @@ -1191,7 +1191,12 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o | |||||
| py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter"); | py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter"); | ||||
| AnfNodePtr iter_node = ParseExprNode(block, iter_obj); | AnfNodePtr iter_node = ParseExprNode(block, iter_obj); | ||||
| MS_EXCEPTION_IF_NULL(iter_node); | MS_EXCEPTION_IF_NULL(iter_node); | ||||
| CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node}); | |||||
| // Generate node for loop count and convert it to tensor, to make the loop not unroll | |||||
| CNodePtr scalar_len = block->func_graph()->NewCNode({op_len, iter_node}); | |||||
| auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations"); | |||||
| auto scalar_to_tensor_node = block->func_graph()->NewCNode({NewValueNode(scalar_to_tensor)}); | |||||
| CNodePtr len_iter = block->func_graph()->NewCNode({scalar_to_tensor_node, scalar_len}); | |||||
| FunctionBlockPtr header_block = | FunctionBlockPtr header_block = | ||||
| GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info())); | GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info())); | ||||
| @@ -1199,7 +1204,9 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o | |||||
| // create loop variable 'i' | // create loop variable 'i' | ||||
| ParameterPtr loop_var = header_block->func_graph()->add_parameter(); | ParameterPtr loop_var = header_block->func_graph()->add_parameter(); | ||||
| // create loop condition 'i < len(xs)' | // create loop condition 'i < len(xs)' | ||||
| CNodePtr cond_node = header_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), loop_var, len_iter}); | |||||
| auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations"); | |||||
| auto less_node = header_block->func_graph()->NewCNode({NewValueNode(prim_less)}); | |||||
| CNodePtr cond_node = header_block->func_graph()->NewCNode({less_node, loop_var, len_iter}); | |||||
| // generate the body of the for statement | // generate the body of the for statement | ||||
| FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info())); | FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info())); | ||||
| @@ -48,6 +48,10 @@ enum ParseStatusCode : int { | |||||
| PARSE_FAILURE = 0xFF | PARSE_FAILURE = 0xFF | ||||
| }; | }; | ||||
| // max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it | |||||
| // will be sunk(i.e. not unrolled) | |||||
| const int MAX_FOR_LOOP_COUNT = 600; | |||||
| class AstNodeType; | class AstNodeType; | ||||
| class ParseAst; | class ParseAst; | ||||
| @@ -24,7 +24,6 @@ namespace mindspore { | |||||
| REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) { | REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) { | ||||
| // Define python "MetaFuncGraph_" class | // Define python "MetaFuncGraph_" class | ||||
| (void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_") | (void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_") | ||||
| // .def_readonly(PYTHON_METAFUNCGRAPH_FLAG, &MetaFuncGraph::parse_info_) | |||||
| .def(py::init<std::string &>()); | .def(py::init<std::string &>()); | ||||
| // Define python "FuncGraph" class | // Define python "FuncGraph" class | ||||
| (void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph") | (void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph") | ||||
| @@ -72,7 +72,6 @@ class MetaFuncGraph : public FuncGraphBase { | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| // const bool parse_info_ = true; | |||||
| protected: | protected: | ||||
| template <typename Derived> | template <typename Derived> | ||||