|
|
|
@@ -1061,13 +1061,13 @@ FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) { |
|
|
|
FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) { |
|
|
|
MS_LOG(DEBUG) << "Process ast For, create an if else statement"; |
|
|
|
MS_EXCEPTION_IF_NULL(block); |
|
|
|
// create statement 'len(xs) < prim::MAX_FOR_LOOP_COUNT' |
|
|
|
// create statement 'len(xs) < MAX_FOR_LOOP_COUNT' |
|
|
|
AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN); |
|
|
|
py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER); |
|
|
|
AnfNodePtr iter_node = ParseExprNode(block, iter_obj); |
|
|
|
CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node}); |
|
|
|
CNodePtr bool_node = block->func_graph()->NewCNode( |
|
|
|
{NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(prim::MAX_FOR_LOOP_COUNT)}); |
|
|
|
CNodePtr bool_node = |
|
|
|
block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(MAX_FOR_LOOP_COUNT)}); |
|
|
|
|
|
|
|
// create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop' |
|
|
|
TraceManager::DebugTrace(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info())); |
|
|
|
@@ -1191,7 +1191,12 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o |
|
|
|
py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter"); |
|
|
|
AnfNodePtr iter_node = ParseExprNode(block, iter_obj); |
|
|
|
MS_EXCEPTION_IF_NULL(iter_node); |
|
|
|
CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node}); |
|
|
|
// Generate node for loop count and convert it to tensor, to make the loop not unroll |
|
|
|
CNodePtr scalar_len = block->func_graph()->NewCNode({op_len, iter_node}); |
|
|
|
auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations"); |
|
|
|
auto scalar_to_tensor_node = block->func_graph()->NewCNode({NewValueNode(scalar_to_tensor)}); |
|
|
|
|
|
|
|
CNodePtr len_iter = block->func_graph()->NewCNode({scalar_to_tensor_node, scalar_len}); |
|
|
|
|
|
|
|
FunctionBlockPtr header_block = |
|
|
|
GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info())); |
|
|
|
@@ -1199,7 +1204,9 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o |
|
|
|
// create loop variable 'i' |
|
|
|
ParameterPtr loop_var = header_block->func_graph()->add_parameter(); |
|
|
|
// create loop condition 'i < len(xs)' |
|
|
|
CNodePtr cond_node = header_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), loop_var, len_iter}); |
|
|
|
auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations"); |
|
|
|
auto less_node = header_block->func_graph()->NewCNode({NewValueNode(prim_less)}); |
|
|
|
CNodePtr cond_node = header_block->func_graph()->NewCNode({less_node, loop_var, len_iter}); |
|
|
|
|
|
|
|
// generate the body of the for statement |
|
|
|
FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info())); |
|
|
|
|