/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/jit/parse/parse.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "pipeline/jit/parse/resolve.h"
#include "frontend/operator/ops.h"
#include "pipeline/jit/parse/data_converter.h"
#include "frontend/operator/composite/composite.h"
#include "utils/context/ms_context.h"
#include "debug/trace.h"

namespace mindspore {
namespace parse {
// Parse the Python object `obj` into a FuncGraph.
//
// `python_mod_get_parse_method` is forwarded to ParseAst::InitParseAstInfo.
// Returns nullptr (after logging an error) when `obj` is a null handle or Python None, when the AST
// information cannot be initialized, or when the parser fails to produce a graph.
FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mod_get_parse_method) {
  (void)python_adapter::set_python_scoped();

  // `!obj` tests the underlying handle for null directly instead of going through a py::object comparison.
  if (!obj || py::isinstance<py::none>(obj)) {
    MS_LOG(ERROR) << "Parse the python code failed, obj is nullptr or none";
    return nullptr;
  }

  auto ast = std::make_shared<ParseAst>(obj);
  bool success = ast->InitParseAstInfo(python_mod_get_parse_method);
  if (!success) {
    MS_LOG(ERROR) << "Parse code to ast tree failed.";
    return nullptr;
  }

  auto parser = std::make_shared<Parser>(ast);
  FuncGraphPtr func_graph = parser->ParseFuncGraph();
  if (func_graph == nullptr) {
    MS_LOG(ERROR) << "Parse python code failed, errcode = " << parser->errcode();
    return nullptr;
  }
  return func_graph;
}

// If any mixed precision flag is set, add a cast node after the parameter node.
AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) { TypePtr dst_type; if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) { dst_type = kFloat32; } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) { dst_type = kFloat16; } else { return param; } auto cast_helper = prim::kPrimMixedPrecisionCast; auto cast = func_graph->NewCNode({NewValueNode(cast_helper), NewValueNode(dst_type), param}); return cast; } FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr(); Parser::Parser(const std::shared_ptr &ast) : ast_(ast) { errcode_ = PARSE_SUCCESS; BuildMethodMap(); } void Parser::BuildMethodMap() { stmt_method_map_["Return"] = &Parser::ParseReturn; stmt_method_map_["Expr"] = &Parser::ParseExpr; stmt_method_map_["If"] = &Parser::ParseIf; stmt_method_map_["Assign"] = &Parser::ParseAssign; stmt_method_map_["While"] = &Parser::ParseWhile; stmt_method_map_["For"] = &Parser::ParseFor; stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef; stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign; stmt_method_map_["Global"] = &Parser::ParseGlobal; stmt_method_map_["Break"] = &Parser::ParseBreak; stmt_method_map_["Continue"] = &Parser::ParseContinue; stmt_method_map_["Pass"] = &Parser::ParsePass; expr_method_map_["NoneType"] = &Parser::ParseNone; expr_method_map_["BinOp"] = &Parser::ParseBinOp; expr_method_map_["Name"] = &Parser::ParseName; expr_method_map_["Num"] = &Parser::ParseNum; expr_method_map_["Str"] = &Parser::ParseStr; expr_method_map_["NameConstant"] = &Parser::ParseNameConstant; expr_method_map_["Call"] = &Parser::ParseCall; expr_method_map_["IfExp"] = &Parser::ParseIfExp; expr_method_map_["Attribute"] = &Parser::ParseAttribute; expr_method_map_["Compare"] = &Parser::ParseCompare; expr_method_map_["BoolOp"] = &Parser::ParseBoolOp; expr_method_map_["Lambda"] = &Parser::ParseLambda; expr_method_map_["Tuple"] = &Parser::ParseTuple; expr_method_map_["List"] = &Parser::ParseList; expr_method_map_["Subscript"] 
= &Parser::ParseSubscript; expr_method_map_["Slice"] = &Parser::ParseSlice; expr_method_map_["ExtSlice"] = &Parser::ParseExtSlice; expr_method_map_["Index"] = &Parser::ParseIndex; expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp; expr_method_map_["Dict"] = &Parser::ParseDict; expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis; } void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); } void Parser::InitParserEnvironment(const py::object &obj) { Parser::top_func_graph_ = FuncGraphWeakPtr(); ScopeManager::GetInstance().ClearScope(); (void)python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GENERATE_SCOPE, obj); } void Parser::CleanParserResource() { Parser::top_func_graph_ = FuncGraphWeakPtr(); ScopeManager::GetInstance().ClearScope(); } FuncGraphPtr Parser::ParseFuncGraph() { // get ast FunctionDef node py::object node = ast_->GetAstNode(); FunctionBlockPtr pFnBlock = ParseFunction(node); if (errcode() != PARSE_SUCCESS) { MS_LOG(ERROR) << "Parse function error, code is " << errcode(); return nullptr; } RemoveUnnecessaryPhis(); MS_EXCEPTION_IF_NULL(pFnBlock); return pFnBlock->func_graph(); } void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node) { py::object func_args = python_adapter::GetPyObjAttr(fn_node, "args"); py::object var_arg_node = python_adapter::GetPyObjAttr(func_args, "vararg"); block->func_graph()->set_has_vararg(!py::isinstance(var_arg_node)); py::object kw_arg_node = python_adapter::GetPyObjAttr(func_args, "kwarg"); block->func_graph()->set_has_kwarg(!py::isinstance(kw_arg_node)); py::list kwonly_args = python_adapter::GetPyObjAttr(func_args, "kwonlyargs"); block->func_graph()->set_kwonlyargs_count(SizeToInt(kwonly_args.size())); MS_EXCEPTION_IF_NULL(ast_); py::list args = ast_->GetArgs(fn_node); for (std::size_t i = 0; i < args.size(); i++) { std::string arg_name = py::cast(args[i].attr("arg")); if (ast()->target_type() == 
PARSE_TARGET_OBJECT_INSTANCE) { if (arg_name == "self") { continue; } } TraceManager::DebugTrace(GetLocation(args[i])); auto para_node = std::make_shared(block->func_graph()); MS_EXCEPTION_IF_NULL(para_node); TraceManager::EndTrace(); para_node->set_name(arg_name); para_node->debug_info()->set_name(arg_name); block->func_graph()->add_parameter(para_node); AnfNodePtr para_after_cast = GetMixedPrecisionCastHelp(block->func_graph(), para_node); block->WriteVariable(arg_name, para_after_cast); MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name; } } void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node) { py::list defaults = ast_->GetArgsDefaultValues(fn_node); py::list args = ast_->GetArgs(fn_node); std::vector namelist_for_default_value; std::vector default_values; for (std::size_t i = 0; i < args.size(); i++) { std::string arg_name = py::cast(args[i].attr("arg")); if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { if (arg_name == "self") { continue; } } namelist_for_default_value.push_back(arg_name); if (py::isinstance(defaults[i])) { default_values.push_back(NewValueNode(kNull)); } else { default_values.push_back(ParseExprNode(block, defaults[i])); } } block->func_graph()->SetDefaultValues(namelist_for_default_value, default_values); } ScopePtr Parser::GetScopeForParseFunction() { ScopePtr scope = ScopeManager::GetInstance().GetCurrentScope(); if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { py::object scope_str = python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GET_SCOPE_NAME, ast_->obj()); if (!py::isinstance(scope_str)) { auto scope_name = py::cast(scope_str); scope = std::make_shared(scope_name); } } return scope; } FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlockPtr &block) { ScopePtr scope = GetScopeForParseFunction(); // the node created in the parsefunction context, will inherit the scope created using scope_guard ScopeGuard 
scope_guard(scope); TraceGuard trace_guard(data_converter::GetObjKey(ast()->obj())[0], GetLocation(node)); FunctionBlockPtr pFunBlock = MakeFunctionBlock(*this); if (block != nullptr) { pFunBlock->AddPrevBlock(block); } else { func_graph_ = pFunBlock->func_graph(); } pFunBlock->Mature(); auto current_fg = pFunBlock->func_graph(); auto function_name = py::cast(python_adapter::GetPyObjAttr(node, "name")); MS_LOG(DEBUG) << "The function name is " << function_name; current_fg->debug_info()->set_name(function_name); MS_EXCEPTION_IF_NULL(ast_); py::list deco_list = node.attr("decorator_list"); if (deco_list.size() > 0) { current_fg->debug_info()->set_deco_location(GetLocation(deco_list)); } bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg); if (ast_->obj() != ast_->function()) { set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg); } if (!set_flag) { MS_LOG(ERROR) << "Set flags failed"; return nullptr; } GenerateArgsNodeForFunction(pFunBlock, node); // when parsing the top graph of construct, save the top graph if (GetTopFuncGraph() == nullptr) { UpdateTopFuncGraph(pFunBlock->func_graph()); } // save the function node to block pFunBlock->WriteVariable(function_name, NewValueNode(current_fg)); py::object funcObj = python_adapter::GetPyObjAttr(node, "body"); (void)ParseStatements(pFunBlock, funcObj); if (current_fg->get_return() == nullptr) { MS_LOG(ERROR) << "Graph return node is null, loc:" << GetLocation(node)->ToString(); errcode_ = PARSE_NO_RETURN; return pFunBlock; } GenerateArgsDefaultValueForFunction(pFunBlock, node); return pFunBlock; } FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr fn_block, const py::object &nodes) { py::int_ pcount = python_adapter::CallPyObjMethod(nodes, "__len__"); size_t count = IntToSize(pcount); MS_LOG(DEBUG) << "The nodes count is " << count; for (size_t i = 0; i < count; i++) { auto node = py::cast(nodes)[i]; TraceManager::DebugTrace(GetLocation(node)); fn_block = ParseStatement(fn_block, 
node); TraceManager::EndTrace(); // insert appropriate depended items for the function block if it has a return node if (fn_block->func_graph()->get_return() != nullptr) { fn_block->InsertDependItemsBeforeReturn(); // Skip statements after 'return' (or 'break', 'continue'). break; } } return fn_block; } FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py::object &node) { auto node_type = ast_->GetNodeType(node); // check the node type AstMainType nodeType = node_type->main_type(); if (nodeType != AST_MAIN_TYPE_STMT) { MS_LOG(INFO) << "Node type is error : " << nodeType; return block; } // call the process function std::string node_name = node_type->node_name(); MS_LOG(DEBUG) << "Ast node is " << node_name; if (stmt_method_map_.count(node_name)) { TraceManager::DebugTrace(GetLocation(node)); auto stmt_block = (this->*stmt_method_map_[node_name])(block, node); TraceManager::EndTrace(); return stmt_block; } else { errcode_ = PARSE_NODE_METHOD_UNSUPPORTED; py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); if (location.size() < 2) { MS_LOG(EXCEPTION) << "List size should not be less than 2."; } auto filename = location[0].cast(); auto line_no = location[1].cast(); MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no; } } AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast expr"; auto node_type = ast_->GetNodeType(node); // check the node type AstMainType node_main_type = node_type->main_type(); if (node_main_type != AST_MAIN_TYPE_EXPR) { MS_LOG(ERROR) << "Node type is error : " << node_main_type; errcode_ = PARSE_NODE_TYPE_NO_MATCH; return nullptr; } // call the process function std::string node_name = node_type->node_name(); MS_LOG(DEBUG) << "Ast node is " << node_name; if (expr_method_map_.count(node_name)) { TraceManager::DebugTrace(GetLocation(node)); auto expr_node = 
(this->*expr_method_map_[node_name])(block, node); TraceManager::EndTrace(); return expr_node; } else { errcode_ = PARSE_NODE_METHOD_UNSUPPORTED; py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); auto filename = ret[0].cast(); auto line_no = ret[1].cast(); MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no; } } // process the expr statement and expand it // eg: x.append(y) -> x = x.append(y) FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Expr"; // Expr only have value , no target py::tuple expand_info = ast_->CallParserObjMethod(PYTHON_PARSE_EXPAND_EXPR_STATEMENT, node); // refer python function expand_expr_statement, expand_info is one of the following: // True, expr.value, x // True, expr.value // False, None, None // check the expand info result auto is_expand = py::cast(expand_info[0]); if (is_expand) { // process the expr statement py::object value_object = expand_info[1]; AnfNodePtr value_node = ParseExprNode(block, value_object); if (py::len(expand_info) == 2) { // add to depend list and insert before output block->AddAutoDepend(value_node); } else { // expand the assign statement py::object target_node = expand_info[2]; WriteAssignVars(block, target_node, value_node); } } return block; } LocationPtr Parser::GetLocation(const py::object &node) const { MS_EXCEPTION_IF_NULL(ast_); py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); if (ret.size() < 5) { MS_LOG(EXCEPTION) << "List size should not be less than 5."; } // refer to Location::Location() for each member of ret: line, column, line_end, column_end. 
auto location = std::make_shared(ret[0].cast(), ret[1].cast(), ret[2].cast(), ret[3].cast(), ret[4].cast()); return location; } void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block, const FunctionBlockPtr &false_block) { true_block->AddPrevBlock(pre_block); true_block->Mature(); false_block->AddPrevBlock(pre_block); false_block->Mature(); } FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast return"; MS_EXCEPTION_IF_NULL(block); // create return valuenode AnfNodePtr pReturnValueNode = NewValueNode(prim::kPrimReturn); // parse the return Statements value py::object value = python_adapter::GetPyObjAttr(node, "value"); AnfNodePtr pReturnStatementNode = ParseExprNode(block, value); // Create the cnode CNodePtr pReturnCNode = block->func_graph()->NewCNode({pReturnValueNode, pReturnStatementNode}); block->func_graph()->set_return(pReturnCNode); return block; } // Process binary operators,eg: `a + b`, `a | b`, etc. 
AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast BinOP"; py::object left = python_adapter::GetPyObjAttr(node, "left"); py::object right = python_adapter::GetPyObjAttr(node, "right"); py::object op = python_adapter::GetPyObjAttr(node, "op"); // create left and right ANF node AnfNodePtr left_node = ParseExprNode(block, left); if (left_node == nullptr) { MS_LOG(WARNING) << "DoBinOp process left node failed: " << errcode(); return nullptr; } AnfNodePtr right_node = ParseExprNode(block, right); if (right_node == nullptr) { MS_LOG(WARNING) << "DoBinOp process right node failed:" << errcode(); return nullptr; } // resolve the op AnfNodePtr op_node = block->MakeResolveAstOp(op); // create apply node return block->func_graph()->NewCNode({op_node, left_node, right_node}); } AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Name"; auto name_id = py::cast(python_adapter::GetPyObjAttr(node, "id")); MS_LOG(DEBUG) << "The Name id is " << name_id; TraceGuard trace_guard(GetLocation(node)); if (block->IsGlobalVar(name_id)) { return block->MakeResolveSymbol(name_id); } return block->ReadVariable(name_id); } AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) { MS_LOG(DEBUG) << "Process ast NoneType"; return NewValueNode(kNone); } AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) { MS_LOG(DEBUG) << "Process ast Ellipsis"; return NewValueNode(kEllipsis); } AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) { MS_LOG(DEBUG) << "Process ast Num"; py::object obj = python_adapter::GetPyObjAttr(node, "n"); TraceGuard trace_guard(GetLocation(node)); if (py::isinstance(obj)) { MS_LOG(INFO) << "The Num is int:" << (std::string)py::str(obj); auto data = py::cast(obj); return NewValueNode(data); } else if (py::isinstance(obj)) { MS_LOG(INFO) << "The Num is float:" << 
(std::string)py::str(obj); auto data = py::cast(obj); return NewValueNode(data); } else { // no else actually MS_LOG(ERROR) << "Unsupported Num type : " << (std::string)py::str(obj) << GetLocation(node)->ToString(); errcode_ = PARSE_NODE_TYPE_UNKOWN; return nullptr; } } AnfNodePtr Parser::ParseStr(const FunctionBlockPtr &, const py::object &node) { MS_LOG(DEBUG) << "Process ast Str"; auto str_s = py::cast(python_adapter::GetPyObjAttr(node, "s")); return NewValueNode(str_s); } AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object &node) { MS_LOG(DEBUG) << "Process ast NameConstant"; py::object obj = python_adapter::GetPyObjAttr(node, "value"); TraceGuard trace_guard(GetLocation(node)); if (py::isinstance(obj)) { MS_LOG(INFO) << "The NameConstant is bool:" << (std::string)py::str(obj); auto data = py::cast(obj); return NewValueNode(data); } else if (py::isinstance(obj)) { MS_LOG(INFO) << "The NameConstant is none:" << (std::string)py::str(obj); return NewValueNode(kNone); } else { // no else actually MS_LOG(ERROR) << "Unsupported NameConstant type: " << (std::string)py::str(obj) << GetLocation(node)->ToString(); errcode_ = PARSE_NODE_TYPE_UNKOWN; return nullptr; } } AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector &element_nodes) { AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE); std::vector make_tuple_nodes; make_tuple_nodes.push_back(make_tuple_op); (void)std::transform(element_nodes.begin(), element_nodes.end(), std::back_inserter(make_tuple_nodes), [](AnfNodePtr arg) -> AnfNodePtr { return arg; }); return block->func_graph()->NewCNode(make_tuple_nodes); } AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &args) { py::object father_class; if (args.empty()) { father_class = py::none(); } else if (args.size() == 2) { father_class = args[0]; auto arg_type = AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, args[1]))); if 
(arg_type != AST_SUB_TYPE_NAME || py::cast(python_adapter::GetPyObjAttr(args[1], "id")) != "self") { MS_EXCEPTION(ArgumentError) << "When call 'super', the second arg should be 'self'."; } } else { MS_EXCEPTION(ArgumentError) << "When call 'super', the args number should be 0 or 2, but got" << args.size() << "."; } py::object target_class_instance = ast()->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast()->obj()); py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance); NameSpacePtr name_space = std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); SymbolPtr symbol = std::make_shared("namespace"); return block->MakeResolve(name_space, symbol); } // process function call, eg : f1(x, y) ... AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Call"; // process function call py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func"); py::list args = python_adapter::GetPyObjAttr(node, "args"); auto arg_type = AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, function_ast_node))); if (arg_type == AST_SUB_TYPE_NAME) { auto name_id = py::cast(python_adapter::GetPyObjAttr(function_ast_node, "id")); if (name_id == "super") { return ParseSuper(block, args); } } AnfNodePtr call_function_anf_node = ParseExprNode(block, function_ast_node); // function call arguments should be passed in as groups and unpacked later using unpack call std::vector packed_arguments; std::vector group_arguments; bool need_unpack_args = ParseArgsInCall(block, args, &packed_arguments, &group_arguments); bool need_unpack_keywords = ParseKeywordsInCall(block, node, &packed_arguments); // if there is stared or keyword argument, unpack may be needed bool need_unpack = need_unpack_args || need_unpack_keywords; return GenerateAnfNodeForCall(block, call_function_anf_node, packed_arguments, group_arguments, 
need_unpack); } AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node, const std::vector &packed_arguments, const std::vector &group_arguments, bool need_unpack) const { // if there is keyword arguments or starred, using an unpack_call op to unpack the argument if (need_unpack) { std::vector unpack_call_nodes; auto unpack_call_op = NewValueNode(std::make_shared(NAMED_METAGRAPH_UNPACKCALL)); unpack_call_nodes.push_back(unpack_call_op); unpack_call_nodes.push_back(call_function_anf_node); (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes), [](AnfNodePtr node) -> AnfNodePtr { return node; }); CNodePtr unpack_call = block->func_graph()->NewCNode(unpack_call_nodes); return unpack_call; } // else there is no keyword arguments and starred, parsed as normal arguments without unpack std::vector func_call_nodes; func_call_nodes.push_back(call_function_anf_node); (void)std::transform(group_arguments.begin(), group_arguments.end(), std::back_inserter(func_call_nodes), [](AnfNodePtr node) -> AnfNodePtr { return node; }); CNodePtr call_anf_node = block->func_graph()->NewCNode(func_call_nodes); return call_anf_node; } bool Parser::ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, std::vector *packed_arguments, std::vector *group_arguments) { bool need_unpack = false; for (size_t i = 0; i < args.size(); i++) { auto arg_node = AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, args[i]))); if (arg_node == AST_SUB_TYPE_STARRED) { if (!group_arguments->empty()) { packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments)); } packed_arguments->push_back(ParseExprNode(block, python_adapter::GetPyObjAttr(args[i], "value"))); group_arguments->clear(); need_unpack = true; } else { group_arguments->push_back(ParseExprNode(block, args[i])); } } if (!group_arguments->empty()) { 
packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments)); } return need_unpack; } bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node, std::vector *packed_arguments) { bool need_unpack = false; py::list keywords = python_adapter::GetPyObjAttr(node, "keywords"); if (!keywords.empty()) { need_unpack = true; std::vector keys; std::vector values; for (size_t index = 0; index < keywords.size(); index++) { auto kw_key = python_adapter::GetPyObjAttr(keywords[index], "arg"); auto kw_value = python_adapter::GetPyObjAttr(keywords[index], "value"); if (py::isinstance(kw_key)) { packed_arguments->push_back(ParseExprNode(block, kw_value)); } else { auto kw_key_c = kw_key.cast(); keys.push_back(NewValueNode(kw_key_c)); values.push_back(ParseExprNode(block, kw_value)); } } auto keys_tuple = GenerateMakeTuple(block, keys); auto values_tuple = GenerateMakeTuple(block, values); auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT); std::vector make_dict_nodes; make_dict_nodes.push_back(make_dict_op); make_dict_nodes.push_back(keys_tuple); make_dict_nodes.push_back(values_tuple); packed_arguments->push_back(block->func_graph()->NewCNode(make_dict_nodes)); } return need_unpack; } // process call attributes of class type define, eg: x.y() AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Attribute"; // process class value,eg: self.xx if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { if (ast_->IsClassMember(node)) { std::string var_name = "self."; std::string attr_name = node.attr("attr").cast(); (void)var_name.append(attr_name); auto attr_obj = ast()->obj().attr(attr_name.c_str()); if (py::hasattr(ast()->obj(), attr_name.c_str()) && (py::hasattr(attr_obj, PYTHON_PRIMITIVE_FLAG) || py::isinstance(attr_obj) || py::isinstance(attr_obj) || py::isinstance(attr_obj) || py::isinstance(attr_obj) || data_converter::IsCellInstance(attr_obj))) { 
return block->MakeResolveSymbol(var_name); } else { return block->ReadVariable(var_name); } } } // process the get attr // Use the Primitive replace the operation resolve node (getattr) // because the getattr will eventually be converted to Primitive node AnfNodePtr op_node = NewValueNode(prim::kPrimGetAttr); // process the attr body py::object value_body = python_adapter::GetPyObjAttr(node, "value"); AnfNodePtr value_node = ParseExprNode(block, value_body); if (value_node == nullptr) { MS_LOG(WARNING) << "Parse attribute failed"; return nullptr; } // process the node attr auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast(); MS_LOG(DEBUG) << "Attr = " << attr_str; TraceManager::DebugTrace(GetLocation(python_adapter::GetPyObjAttr(node, "attr"))); AnfNodePtr attr_node = NewValueNode(attr_str); TraceManager::EndTrace(); // create the apply node return block->func_graph()->NewCNode({op_node, value_node, attr_node}); } // Process comparison expression : a == b. a > b etc. AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Compare"; // for python comparison ,there may be if x>y>5 , // which there is two ops , but we only support one now py::list ops = python_adapter::GetPyObjAttr(node, "ops"); if (ops.size() > MAX_COMPARISON_OPS_SUPPORTED) { MS_LOG(ERROR) << "MindSpore does not support comparison with operators more than one now, ops size =" << ops.size(); return nullptr; } py::object left = python_adapter::GetPyObjAttr(node, "left"); py::list comparators = python_adapter::GetPyObjAttr(node, "comparators"); AnfNodePtr left_node = ParseExprNode(block, left); AnfNodePtr right_node = ParseExprNode(block, comparators[0]); MS_EXCEPTION_IF_NULL(block); AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]); return block->func_graph()->NewCNode({op_node, left_node, right_node}); } AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, const py::object 
&op) { // if there is only one bool op now if (value_list.size() == 1) { AnfNodePtr first_node = ParseExprNode(block, value_list[0]); return first_node; } else { py::object first = value_list[0]; py::list rest; for (size_t i = 1; i < value_list.size(); i++) { rest.append(value_list[i]); } AnfNodePtr first_node = ParseExprNode(block, first); AnfNodePtr rest_node = ProcessBoolOpValueList(block, rest, op); auto op_node = block->MakeResolveAstOp(op); return block->func_graph()->NewCNode({op_node, first_node, rest_node}); } } // Process comparison expression : a and b. a or b . AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast BoolOp"; py::object op_node = python_adapter::GetPyObjAttr(node, "op"); py::list op_values = python_adapter::GetPyObjAttr(node, "values"); return ProcessBoolOpValueList(block, op_values, op_node); } // Process a function def FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast FunctionDef"; FunctionBlockPtr function_block = ParseFunction(node, block); MS_EXCEPTION_IF_NULL(function_block); // get function name py::str name = python_adapter::GetPyObjAttr(node, "name"); std::string function_name = name; ValueNodePtr valuenode_graph = NewValueNode(function_block->func_graph()); block->WriteVariable(function_name, valuenode_graph); return block; } // Process a lambda expression . 
like lambda x,y: x + y AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Lambda"; FunctionBlockPtr func_block = MakeFunctionBlock(*this); func_block->AddPrevBlock(block); func_block->Mature(); // get lambda args py::list args = ast_->GetArgs(node); for (std::size_t i = 0; i < args.size(); i++) { std::string arg = py::cast(args[i].attr("arg")); TraceManager::DebugTrace(GetLocation(args[i])); auto para_node = std::make_shared(func_block->func_graph()); TraceManager::EndTrace(); para_node->debug_info()->set_name(arg); func_block->func_graph()->add_parameter(para_node); func_block->WriteVariable(arg, para_node); MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg; } py::object body_node = python_adapter::GetPyObjAttr(node, "body"); AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node); func_block->func_graph()->set_output(lambda_body_node); ValueNodePtr const_graph = NewValueNode(func_block->func_graph()); return const_graph; } // process a tuple AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Tuple"; MS_EXCEPTION_IF_NULL(block); py::tuple elts = python_adapter::GetPyObjAttr(node, "elts"); if (elts.size() == 0) { auto empty_tuple = std::vector(); return NewValueNode(std::make_shared(empty_tuple)); } std::vector tuple_vec; AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE); tuple_vec.emplace_back(make_tuple_op); for (size_t i = 0; i < elts.size(); i++) { AnfNodePtr node_ptr = ParseExprNode(block, elts[i]); tuple_vec.emplace_back(node_ptr); } CNodePtr tuple_app = block->func_graph()->NewCNode(tuple_vec); return tuple_app; } // process a list AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast List"; MS_EXCEPTION_IF_NULL(block); py::tuple elts = python_adapter::GetPyObjAttr(node, "elts"); if (elts.size() == 0) { auto empty_list = 
std::vector();
    return NewValueNode(std::make_shared(empty_list));
  }
  std::vector list_vec;
  AnfNodePtr make_list_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST);
  list_vec.emplace_back(make_list_op);
  for (size_t i = 0; i < elts.size(); i++) {
    AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
    list_vec.emplace_back(node_ptr);
  }
  CNodePtr list_app = block->func_graph()->NewCNode(list_vec);
  return list_app;
}

// Process a subscript expression such as `x[y]`; the AST node has the form value[slice].
// Emits getitem(value, slice), where getitem is the resolved NAMED_PRIMITIVE_GETITEM operation.
AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast Subscript";
  MS_EXCEPTION_IF_NULL(block);
  AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  py::object value_node = python_adapter::GetPyObjAttr(node, "value");
  py::object slice_node = python_adapter::GetPyObjAttr(node, "slice");
  AnfNodePtr value = ParseExprNode(block, value_node);
  AnfNodePtr slice = ParseExprNode(block, slice_node);
  return block->func_graph()->NewCNode({op_getitem, value, slice});
}

// Process a slice `lower:upper:step`.
// The node's `lower`, `upper` and `step` attributes are each parsed as expressions (an omitted
// bound is a Python None, presumably dispatched to ParseNone via the "NoneType" entry of
// expr_method_map_) and combined with the resolved make-slice operation.
AnfNodePtr Parser::ParseSlice(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast Slice";
  MS_EXCEPTION_IF_NULL(block);
  AnfNodePtr op_makeslice = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKESLICE);
  py::object start = python_adapter::GetPyObjAttr(node, "lower");
  py::object stop = python_adapter::GetPyObjAttr(node, "upper");
  py::object step = python_adapter::GetPyObjAttr(node, "step");
  AnfNodePtr start_node = ParseExprNode(block, start);
  AnfNodePtr stop_node = ParseExprNode(block, stop);
  AnfNodePtr step_node = ParseExprNode(block, step);
  return block->func_graph()->NewCNode({op_makeslice, start_node, stop_node, step_node});
}

// Process an extended slice (an ExtSlice node).
// Every entry of the node's `dims` tuple is parsed in order and the results are packed into one
// make_tuple CNode, which is returned.
AnfNodePtr Parser::ParseExtSlice(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast ExtSlice";
  MS_EXCEPTION_IF_NULL(block);
  AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
  py::tuple slice_tuple = python_adapter::GetPyObjAttr(node, "dims");
  std::vector node_vec;
  // The applied operation is input 0 of the CNode; the parsed dims follow it.
  node_vec.emplace_back(make_tuple_op);
  for (size_t i = 0; i < slice_tuple.size(); i++) {
    AnfNodePtr node_ptr = ParseExprNode(block, slice_tuple[i]);
    node_vec.emplace_back(node_ptr);
  }
  CNodePtr tuple_conde = block->func_graph()->NewCNode(node_vec);
  return tuple_conde;
}

// Process an index (an Index node wrapping a single value).
// This function adds no node of its own: the wrapped `value` expression is parsed and returned.
AnfNodePtr Parser::ParseIndex(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast Index";
  py::object value_node = python_adapter::GetPyObjAttr(node, "value");
  return ParseExprNode(block, value_node);
}

// Process a unary operation such as `+a` or `-b`.
// The operator is resolved with MakeResolveAstOp and applied to the parsed operand.
AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast UnaryOp";
  py::object op = python_adapter::GetPyObjAttr(node, "op");
  MS_EXCEPTION_IF_NULL(block);
  // resolve the op
  AnfNodePtr op_node = block->MakeResolveAstOp(op);
  py::object operand = python_adapter::GetPyObjAttr(node, "operand");
  AnfNodePtr operand_node = ParseExprNode(block, operand);
  return block->func_graph()->NewCNode({op_node, operand_node});
}

// Process a dict literal `{k0: v0, k1: v1, ...}`.
// Keys and values are parsed pairwise (key i, then value i), packed into two tuples with
// GenerateMakeTuple, and combined as make_dict(keys_tuple, values_tuple).
// Assumes `values` has at least as many entries as `keys` (values[i] is read for every key).
// NOTE(review): `block` is passed to ParseExprNode/GenerateMakeTuple before the
// MS_EXCEPTION_IF_NULL(block) check below; if they dereference it, that check comes too late.
AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast Dict";
  py::list keys = node.attr("keys");
  py::list values = node.attr("values");
  std::vector key_nodes;
  std::vector value_nodes;
  for (size_t i = 0; i < keys.size(); i++) {
    key_nodes.push_back(ParseExprNode(block, keys[i]));
    value_nodes.push_back(ParseExprNode(block, values[i]));
  }
  auto keys_tuple = GenerateMakeTuple(block, key_nodes);
  auto values_tuple = GenerateMakeTuple(block, value_nodes);
  MS_EXCEPTION_IF_NULL(block);
  auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
  return block->func_graph()->NewCNode({make_dict_op, keys_tuple, values_tuple});
}

// Process an augmented assignment such as `a += b`.
// Reads the current value of the target (only a plain name or a class member is supported),
// applies the resolved binary operator to (current value, parsed right-hand side), and writes
// the result back to the same target through WriteAssignVars.
// Raises via MS_LOG(EXCEPTION) for any other target kind or when the target cannot be read.
FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast AugAssign";
  py::object op = python_adapter::GetPyObjAttr(node, "op");
  MS_EXCEPTION_IF_NULL(block);
  // resolve the op
  AnfNodePtr op_node = block->MakeResolveAstOp(op);
  py::object target_node = python_adapter::GetPyObjAttr(node, "target");
  MS_EXCEPTION_IF_NULL(ast_);
  auto ast_type = AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, target_node)));
  AnfNodePtr read_node = nullptr;
  if (ast_type == AST_SUB_TYPE_NAME) {
    read_node = ParseName(block, target_node);
  } else if (ast_->IsClassMember(target_node)) {
    read_node = ParseAttribute(block, target_node);
  } else {
    // Any other target kind (presumably including subscripts such as `x[i] += 1`) is rejected.
    MS_LOG(EXCEPTION) << "Not supported augassign";
  }
  if (read_node == nullptr) {
    MS_LOG(EXCEPTION) << "Can not get target node ";
  }
  py::object value = python_adapter::GetPyObjAttr(node, "value");
  AnfNodePtr value_node = ParseExprNode(block, value);
  CNodePtr augassign_app = block->func_graph()->NewCNode({op_node, read_node, value_node});
  WriteAssignVars(block, target_node, augassign_app);
  return block;
}

// Process a global declaration such as `global x, y`: every listed name is registered as a
// global variable of `block`.
FunctionBlockPtr Parser::ParseGlobal(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast Global";
  MS_EXCEPTION_IF_NULL(block);
  py::list vars = python_adapter::GetPyObjAttr(node, "names");
  for (auto &item : vars) {
    block->AddGlobalVar(py::cast(item));
  }
  return block;
}

// Process an if statement.
// Creates a true block, a false block and an after block. Each branch body is parsed into its
// own block and, unless it ended with a return, jumps to the after block. `block` then jumps
// conditionally on the condition converted by ForceToBoolNode.
// Returns the after block, where parsing of the following statements continues.
FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast If";
  py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  AnfNodePtr condition_node = ParseExprNode(block, test_node);
  MS_EXCEPTION_IF_NULL(block);
  CNodePtr bool_node = block->ForceToBoolNode(condition_node);

  // Each new block is created inside a DebugTrace/EndTrace pair built from the parent graph's
  // debug info, presumably so the new graph's debug info records where it came from.
  TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info()));
  FunctionBlockPtr true_block = MakeFunctionBlock(*this);
  TraceManager::EndTrace();

  TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info()));
  FunctionBlockPtr false_block = MakeFunctionBlock(*this);
  TraceManager::EndTrace();
  MakeConditionBlocks(block, true_block, false_block);

  TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info()));
  FunctionBlockPtr after_block = MakeFunctionBlock(*this);
  TraceManager::EndTrace();

  // process the if-true branch
  py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
  FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode);

  // If the branch already ended with a return, it does not fall through to after_block.
  if (true_end->func_graph()->get_return() == nullptr) {
    true_end->Jump(after_block, nullptr);
  }

  // process the orelse branch
  py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
  FunctionBlockPtr false_end = ParseStatements(false_block, orelseNode);

  // Same as above: only fall through to after_block if the branch did not return.
  if (false_end->func_graph()->get_return() == nullptr) {
    false_end->Jump(after_block, nullptr);
  }

  block->ConditionalJump(bool_node, true_block, false_block);
  // after_block is matured only after every jump into it has been emitted.
  after_block->Mature();
  return after_block;
}

// Process a while statement.
// `block` jumps to a header block that evaluates the condition (ForceToWhileCond) and
// conditionally jumps to the body block or the after block. The body jumps back to the header
// unless it returned. Returns the after block, or the loop's end block when the body contains
// `break` (in which case the after block jumps to the end block).
FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast While";
  MS_EXCEPTION_IF_NULL(block);
  MS_LOG(INFO) << "Parse while statement";
  TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info()));
  FunctionBlockPtr header_block = MakeFunctionBlock(*this);
  TraceManager::EndTrace();

  TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info()));
  FunctionBlockPtr body_block = MakeFunctionBlock(*this);
  TraceManager::EndTrace();

  TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info()));
  FunctionBlockPtr after_block = MakeFunctionBlock(*this);
  TraceManager::EndTrace();

  // Register the header as predecessor of the body and after blocks before the header's
  // conditional jump is emitted.
  body_block->AddPrevBlock(header_block);
  after_block->AddPrevBlock(header_block);
  block->Jump(header_block, nullptr);

  // The condition is parsed into the header block, so it is evaluated on every iteration.
  py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  AnfNodePtr condition_node = ParseExprNode(header_block, test_node);
  condition_node = header_block->ForceToWhileCond(condition_node);
  body_block->Mature();
  header_block->ConditionalJump(condition_node, body_block, after_block);

  // Parse loop body statements with loop context. LoopContext presumably exposes this loop on
  // loops_ while the body is parsed (ParseBreak/ParseContinue use loops_.top()); `continue`
  // jumps to the header with no iterator value (nullptr).
  LoopContext loop_context{&loops_, header_block, nullptr};
  py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  FunctionBlockPtr after_body = ParseStatements(body_block, body_node);
  if (after_body->func_graph()->get_return() == nullptr) {
    after_body->Jump(header_block, nullptr);  // back edge
  }

  // The header is matured only after the back edge has been emitted.
  header_block->Mature();
  after_block->Mature();
  auto &end_block = loop_context.EndBlock();
  if (end_block) {
    // end_block exists if we encounter 'break' in loop body.
    after_block->Jump(end_block, nullptr);
    end_block->Mature();
    return end_block;
  }
  // No 'break', no end_block.
  return after_block;
}

// Parse the iterable `node.iter` in `block` and return op_iter(iterable), i.e. `iter(xs)`.
CNodePtr Parser::GenerateIteratorInFor(const FunctionBlockPtr &block, const py::object &node,
                                       const AnfNodePtr &op_iter) {
  py::object iter_node = python_adapter::GetPyObjAttr(node, "iter");
  AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node);
  return block->func_graph()->NewCNode({op_iter, iter_anf_node});
}

// Build the loop condition op_hasnext(iter_param), i.e. `hasnext(it)`, in the header block's graph.
CNodePtr Parser::GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block,
                                   const AnfNodePtr &op_hasnext) {
  MS_EXCEPTION_IF_NULL(header_block);
  return header_block->func_graph()->NewCNode({op_hasnext, iter_param});
}

// Create a new function block while `trace_info` is active on the TraceManager.
FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
  TraceManager::DebugTrace(trace_info);
  FunctionBlockPtr body_block = MakeFunctionBlock(*this);
  TraceManager::EndTrace();
  return body_block;
}

// Process a for statement:
//   for x in xs:
//     body
// It is compiled into a conditional that chooses one of two lowerings:
//   if len(xs) < MAX_FOR_LOOP_COUNT:
//     ParseForIter()  // iterator based (iter/hasnext/next); described as always unrolling the loop
//   else:
//     ParseForLoop()  // loop-variable based; described as always sinking the loop
FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast For, create an if else statement";
  MS_EXCEPTION_IF_NULL(block);
  // create statement 'len(xs) < MAX_FOR_LOOP_COUNT'
  AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
  // NOTE(review): the attribute is looked up with the NAMED_PRIMITIVE_ITER constant, while
  // GenerateIteratorInFor and ParseForLoop use the literal "iter"; presumably both spell "iter".
  py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER);
  AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
  CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node});
  CNodePtr bool_node =
    block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(MAX_FOR_LOOP_COUNT)});

  // create statement 'if len(xs) < MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop'
  TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info()));
  FunctionBlockPtr true_block = MakeFunctionBlock(*this);
  TraceManager::EndTrace();

  TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info()));
  FunctionBlockPtr false_block = MakeFunctionBlock(*this);
  TraceManager::EndTrace();
  MakeConditionBlocks(block, true_block, false_block);

  TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info()));
  FunctionBlockPtr after_block = MakeFunctionBlock(*this);
  TraceManager::EndTrace();

  // Each lowering parses `node.iter` again inside its own block, so the iterable expression is
  // parsed here for len() and once more per branch.
  // The jumps are unconditional: ParseForIter/ParseForLoop return the loop's after block or end
  // block, into which no statements have been parsed.
  FunctionBlockPtr true_end = ParseForIter(true_block, node);
  true_end->Jump(after_block, nullptr);
  FunctionBlockPtr false_end = ParseForLoop(false_block, node);
  false_end->Jump(after_block, nullptr);

  block->ConditionalJump(bool_node, true_block, false_block);
  after_block->Mature();
  return after_block;
}

// Lower a for loop using the iterator protocol:
//   it = iter(xs)
//   while hasnext(it):
//     x, it = next(it)
//     body
// Three blocks are created: a header (its parameter is the current iterator, its condition is
// hasnext), a body and an after block. Returns the after block, or the loop's end block if the
// body contains `break`.
FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast For";
  MS_EXCEPTION_IF_NULL(block);
  AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
  AnfNodePtr op_next = block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
  AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  AnfNodePtr op_hasnext = block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
  // generate the iterator apply `iter(xs)` in the entry block
  CNodePtr iter_apply = GenerateIteratorInFor(block, node, op_iter);
  MS_EXCEPTION_IF_NULL(iter_apply);
  FunctionBlockPtr header_block =
    GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info()));
  MS_EXCEPTION_IF_NULL(header_block);
  // The current iterator is the header's parameter; generate the hasnext apply used as condition.
  ParameterPtr iter_param = header_block->func_graph()->add_parameter();
  CNodePtr cond_apply = GenerateCondInFor(iter_param, header_block, op_hasnext);

  // generate the body of the for statement
  FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info()));
  MS_EXCEPTION_IF_NULL(body_block);
  body_block->AddPrevBlock(header_block);
  // generate the iterator next apply
  // process as following: `app = next(it); target = app[0]; it = app[1];`
  CNodePtr app = body_block->func_graph()->NewCNode({op_next, iter_param});
  CNodePtr target_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(0)});
  py::object target_node = python_adapter::GetPyObjAttr(node, "target");
  CNodePtr iter2_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(1)});
  WriteAssignVars(body_block, target_node, target_app);

  // link the variable name with the target: the iterator parameter, the advanced iterator and
  // the initial iter(xs) all get trace info built from the target's debug info.
  auto it_info = std::make_shared(target_app->debug_info());
  iter_param->debug_info()->set_trace_info(it_info);
  iter2_app->debug_info()->set_trace_info(it_info);
  iter_apply->debug_info()->set_trace_info(it_info);

  TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info()));
  FunctionBlockPtr after_block = MakeFunctionBlock(*this);
  MS_EXCEPTION_IF_NULL(after_block);
  TraceManager::EndTrace();
  after_block->AddPrevBlock(header_block);

  // Enter the loop with iter(xs) as the header's iterator argument.
  block->Jump(header_block, iter_apply);
  body_block->Mature();
  header_block->ConditionalJump(cond_apply, body_block, after_block);

  // Parse loop body statements with loop context; `continue` jumps back to the header passing
  // iter2_app (the advanced iterator).
  LoopContext loop_context{&loops_, header_block, iter2_app};
  py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
  if (after_body_block->func_graph()->get_return() == nullptr) {
    after_body_block->Jump(header_block, iter2_app);  // back edge with the advanced iterator
  }
  header_block->Mature();
  after_block->Mature();
  auto &end_block = loop_context.EndBlock();
  if (end_block) {
    // end_block exists if we encounter 'break' in loop body.
    after_block->Jump(end_block, nullptr);
    end_block->Mature();
    return end_block;
  }
  // No 'break', no end_block.
  return after_block;
}

// Lower a for loop with an integer loop variable:
//   i = 0
//   while i < len(xs):
//     x = xs[i]
//     i = i + 1
//     body
// len(xs) is converted to a tensor and compared with the Python "Less" operation (see the
// comment on the conversion below for why). Returns the after block, or the loop's end block
// if the body contains `break`.
FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast For by loop variable";
  MS_EXCEPTION_IF_NULL(block);
  AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
  AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);

  // get variable name of 'x' in statement 'for x in xs'
  py::object target_node = python_adapter::GetPyObjAttr(node, "target");

  // create statement 'len(xs)'
  py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
  AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
  MS_EXCEPTION_IF_NULL(iter_node);
  // Generate node for loop count and convert it to tensor, to make the loop not unroll.
  // `NewCNode({NewValueNode(op)})` presumably instantiates the Python operation; the resulting
  // node is then applied to the arguments.
  CNodePtr scalar_len = block->func_graph()->NewCNode({op_len, iter_node});
  auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations");
  auto scalar_to_tensor_node = block->func_graph()->NewCNode({NewValueNode(scalar_to_tensor)});
  CNodePtr len_iter = block->func_graph()->NewCNode({scalar_to_tensor_node, scalar_len});

  FunctionBlockPtr header_block =
    GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info()));
  MS_EXCEPTION_IF_NULL(header_block);
  // create loop variable 'i' as the header's parameter
  ParameterPtr loop_var = header_block->func_graph()->add_parameter();
  // create loop condition 'i < len(xs)'
  auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations");
  auto less_node = header_block->func_graph()->NewCNode({NewValueNode(prim_less)});
  CNodePtr cond_node = header_block->func_graph()->NewCNode({less_node, loop_var, len_iter});

  // generate the body of the for statement
  FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info()));
  MS_EXCEPTION_IF_NULL(body_block);
  body_block->AddPrevBlock(header_block);
  // create 'x = xs[i]'
  CNodePtr target_var = body_block->func_graph()->NewCNode({op_getitem, iter_node, loop_var});
  WriteAssignVars(body_block, target_node, target_var);
  // create 'i = i + 1'. It is built before the body is parsed so that both the back edge and
  // `continue` (through the LoopContext below) pass the incremented value to the header.
  CNodePtr loop_var_inc =
    body_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarAdd), loop_var, NewValueNode(1)});
  body_block->WriteVariable(loop_var->name(), loop_var_inc);

  // link the variable name with the target
  auto it_info = std::make_shared(loop_var_inc->debug_info());
  loop_var->debug_info()->set_trace_info(it_info);
  len_iter->debug_info()->set_trace_info(it_info);

  TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info()));
  FunctionBlockPtr after_block = MakeFunctionBlock(*this);
  MS_EXCEPTION_IF_NULL(after_block);
  TraceManager::EndTrace();
  after_block->AddPrevBlock(header_block);

  // Enter the loop with i = 0.
  block->Jump(header_block, NewValueNode(0));
  body_block->Mature();
  // NOTE(review): unlike ParseWhile/ParseForIter this call passes an extra `false` argument;
  // its meaning is defined by FunctionBlock::ConditionalJump (not visible here).
  header_block->ConditionalJump(cond_node, body_block, after_block, false);

  // Parse loop body statements with loop context.
  LoopContext loop_context{&loops_, header_block, loop_var_inc};
  py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
  if (after_body_block->func_graph()->get_return() == nullptr) {
    after_body_block->Jump(header_block, loop_var_inc);  // back edge with i + 1
  }
  header_block->Mature();
  after_block->Mature();
  auto &end_block = loop_context.EndBlock();
  if (end_block) {
    // end_block exists if we encounter 'break' in loop body.
    after_block->Jump(end_block, nullptr);
    end_block->Mature();
    return end_block;
  }
  // No 'break', no end_block.
  return after_block;
}

// Process a conditional expression `body if test else orelse`.
// Each branch is parsed into its own graph whose output is the branch value. A switch on the
// bool condition selects one of the two graphs, and an enclosing CNode calls the selected graph.
AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast IfExp";
  MS_EXCEPTION_IF_NULL(block);
  py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  AnfNodePtr condition_node = ParseExprNode(block, test_node);
  CNodePtr bool_node = block->ForceToBoolNode(condition_node);

  TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info()));
  FunctionBlockPtr true_block = MakeFunctionBlock(*this);
  TraceManager::EndTrace();

  TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info()));
  FunctionBlockPtr false_block = MakeFunctionBlock(*this);
  TraceManager::EndTrace();

  MakeConditionBlocks(block, true_block, false_block);

  // process the if-true branch
  py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
  true_block->func_graph()->debug_info()->set_location(GetLocation(bodyNode));
  AnfNodePtr true_node = ParseExprNode(true_block, bodyNode);

  // process the orelse branch
  py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
  false_block->func_graph()->debug_info()->set_location(GetLocation(orelseNode));
  AnfNodePtr false_node = ParseExprNode(false_block, orelseNode);

  true_block->func_graph()->set_output(true_node);
  false_block->func_graph()->set_output(false_node);

  // Use the switch Primitive directly instead of a resolved operation node,
  // because the switch will eventually be converted to a Primitive node anyway.
  CNodePtr switch_app =
    block->func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), bool_node, NewValueNode(true_block->func_graph()),
                                   NewValueNode(false_block->func_graph())});

  // Call the graph selected by the switch (no arguments).
  std::vector call_graph_nodes{switch_app};
  CNodePtr switch_app_call = block->func_graph()->NewCNode(call_graph_nodes);
  return switch_app_call;
}

// Bind `assigned_node` to the name of the Name target `targ` (its `id` attribute).
// The node's debug name is set to the variable name; when the node is a graph constant whose
// graph has no debug name yet, the graph is named after the variable as well.
void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &targ,
                              const AnfNodePtr &assigned_node) {
  MS_EXCEPTION_IF_NULL(block);
  MS_EXCEPTION_IF_NULL(assigned_node);
  py::str name = python_adapter::GetPyObjAttr(targ, "id");
  std::string name_id = name;
  assigned_node->debug_info()->set_name(name_id);
  // set the debug name of the constant graph
  // NOTE(review): IsValueNode/GetValueNode carry no template argument here; the code below
  // treats the value as a graph — confirm the intended value type.
  if (IsValueNode(assigned_node)) {
    // the value should be graph
    auto fg = GetValueNode(assigned_node);
    if (fg->debug_info()->name().empty()) {
      fg->debug_info()->set_name(name_id);
    }
  }
  block->WriteVariable(name_id, assigned_node);
}

// Destructure a tuple target such as `a, b = value`: element i of `targ.elts` is assigned
// getitem(assigned_node, i) through WriteAssignVars, so nested targets are handled recursively.
void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ,
                               const AnfNodePtr &assigned_node) {
  MS_EXCEPTION_IF_NULL(block);
  AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  py::list items = python_adapter::GetPyObjAttr(targ, "elts");
  for (size_t i = 0; i < items.size(); i++) {
    // i-th element: getitem(assigned_node, i), built with the resolved getitem operation.
    CNodePtr item_apply =
      block->func_graph()->NewCNode({op_getitem, assigned_node, NewValueNode(static_cast(i))});

    py::object elt = items[i];
    WriteAssignVars(block, elt, item_apply);
  }
}

// Handle an attribute assignment `self.xxx = value`.
// Only attributes of the parsed object are supported (not `x.y = ...`). The attribute must
// already exist on the object and have a `__parameter__` attribute; otherwise a TypeError that
// names the source file and line is raised. On success the value is written to the variable
// "self.<attr>" and SetStateAssgin is called with the parsed target node.
void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ,
                                     const AnfNodePtr &assigned_node) {
  // Now only support the self.xx = xxxxx, can't support x.y = xxxx
  AnfNodePtr target_node = ParseExprNode(block, targ);
  MS_EXCEPTION_IF_NULL(target_node);

  std::string attr_name = targ.attr("attr").cast();
  std::string var_name = "self.";
  (void)var_name.append(attr_name);
  MS_LOG(DEBUG) << "assign " << var_name;

  // Get targ location info for error printing; entries 0 and 1 are the file name and line number.
  py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, targ);
  if (location.size() < 2) {
    MS_LOG(EXCEPTION) << "List size should not be less than 2.";
  }
  auto filename = location[0].cast();
  auto line_no = location[1].cast();
  // Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
  if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
    MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but not defined, at " << filename << ":"
                            << line_no;
  }
  auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
  auto obj_type = obj.attr("__class__").attr("__name__");
  if (!py::hasattr(obj, "__parameter__")) {
    MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '"
                            << py::str(obj).cast() << "' with type '"
                            << py::str(obj_type).cast() << "' at " << filename << ":" << line_no;
  }

  // NOTE(review): `block` was already passed to ParseExprNode above, so this null check is late.
  MS_EXCEPTION_IF_NULL(block);
  block->WriteVariable(var_name, assigned_node);
  MS_LOG(DEBUG) << "SetState write " << var_name << " : " << target_node->ToString();
  block->SetStateAssgin(target_node, var_name);
}

// Handle a subscript assignment such as `x[i] = value` or `self.x[i] = value`.
// Builds setitem(x, i, value) and rebinds the container variable ("x" or "self.x") to it.
// A class-member container must exist on the parsed object and have `__parameter__`, otherwise
// a TypeError is raised. Any other container is read as a Name through its `id` attribute, so
// targets such as `x[i][j] = v` are presumably unsupported — confirm.
void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ,
                                   const AnfNodePtr &assigned_node) {
  MS_EXCEPTION_IF_NULL(block);
  AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM);
  py::object value_obj = python_adapter::GetPyObjAttr(targ, "value");
  py::object slice_obj = python_adapter::GetPyObjAttr(targ, "slice");
  AnfNodePtr value_node = ParseExprNode(block, value_obj);
  AnfNodePtr slice_node = ParseExprNode(block, slice_obj);
  CNodePtr setitem_app = block->func_graph()->NewCNode({op_setitem, value_node, slice_node, assigned_node});
  // The setitem apply is expected to return the updated sequence itself, so it becomes the new
  // value of the container variable.
  std::string var_name = "";
  if (ast_->IsClassMember(value_obj)) {
    std::string attr_name = value_obj.attr("attr").cast();
    var_name = "self." + attr_name;
    if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
      MS_EXCEPTION(TypeError) << "'" << var_name << "' was not defined in the class '__init__' function.";
    }
    auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
    auto obj_type = obj.attr("__class__").attr("__name__");
    if (!py::hasattr(obj, "__parameter__")) {
      MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '"
                              << py::str(obj).cast() << "' with type '"
                              << py::str(obj_type).cast() << "'.";
    }
  } else {
    var_name = value_obj.attr("id").cast();
  }
  block->WriteVariable(var_name, setitem_app);
}

// Assign `value_node` to the assignment target `targ`, dispatching on the target's AST sub-type:
// Name, Tuple (destructuring), Subscript, or, checked last, a class member attribute.
// Any other target raises via MS_LOG(EXCEPTION).
void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(value_node);
  MS_LOG(DEBUG) << "Process WriteAssignVars";
  auto ast_type = AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, targ)));
  if (ast_type == AST_SUB_TYPE_NAME) {
    HandleAssignName(block, targ, value_node);
  } else if (ast_type == AST_SUB_TYPE_TUPLE) {
    HandleAssignTuple(block, targ, value_node);
  } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
    HandleAssignSubscript(block, targ, value_node);
  } else if (ast_->IsClassMember(targ)) {
    HandleAssignClassMember(block, targ, value_node);
  } else {
    MS_LOG(EXCEPTION) << "Not supported assign type: " << ast_type
                      << " NodeInfo: " << trace::GetDebugInfo(value_node->debug_info());
  }
}

// Process an assignment statement such as `a = b` or `a, b = tup`.
// The right-hand side is parsed once and then assigned to every entry of `targets`, in order.
FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast assgin";
  py::object value_object = python_adapter::GetPyObjAttr(node, "value");
  AnfNodePtr value_node = ParseExprNode(block, value_object);
  py::object targets_object = python_adapter::GetPyObjAttr(node, "targets");
  py::int_ pcount = python_adapter::CallPyObjMethod(targets_object, "__len__");
  size_t count = IntToSize(pcount);
  MS_LOG(DEBUG) << "The nodes count is " << count;
  for (size_t i = 0; i < count; i++) {
    auto target_node = py::cast(targets_object)[i];
    WriteAssignVars(block, target_node, value_node);
  }

  return block;
}

// Process `break`: raise if there is no enclosing loop; otherwise create the innermost loop's
// end block on the first `break` and jump to it. Returns `block`.
FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::object &node) {
  if (loops_.empty()) {
    // Report error if loop context not set for the 'break' statement.
    py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
    if (location.size() < 2) {
      MS_LOG(EXCEPTION) << "List size should not be less than 2.";
    }
    auto filename = location[0].cast();
    auto line_no = location[1].cast();
    MS_LOG(EXCEPTION) << "Unexpected 'break' at " << filename << ":" << line_no;
  }
  // Get current loop.
  Loop &loop = loops_.top();
  if (loop.end == nullptr) {
    // Create end_block if it does not exist yet.
    TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info()));
    loop.end = MakeFunctionBlock(*this);
    TraceManager::EndTrace();
  }
  // Jump to the end_block.
  block->Jump(loop.end, nullptr);
  return block;
}

// Process `continue`: raise if there is no enclosing loop; otherwise jump to the innermost
// loop's header, passing the loop's iterator value (nullptr for while loops, the advanced
// iterator or the incremented index for the two for-loop lowerings). Returns `block`.
FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::object &node) {
  if (loops_.empty()) {
    // Report error if loop context not set for the 'continue' statement.
    py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
    if (location.size() < 2) {
      MS_LOG(EXCEPTION) << "List size should not be less than 2.";
    }
    auto filename = location[0].cast();
    auto line_no = location[1].cast();
    MS_LOG(EXCEPTION) << "Unexpected 'continue' at " << filename << ":" << line_no;
  }
  // Jump to the header of the loop with iterator called.
  Loop &loop = loops_.top();
  block->Jump(loop.header, loop.iterator);
  return block;
}

// Process `pass`: nothing is emitted and the current block is returned unchanged.
FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::object &node) {
  // We just bypass 'pass' statement.
  return block;
}

// Resolve a removable phi to its final replacement: starting from `node`, follow the mapping in
// `removable_phis` until reaching a node that is not itself a key of the map, and return it.
// NOTE(review): recursive with no cycle detection, so it assumes the phi mapping is acyclic.
AnfNodePtr FindPhis(const std::unordered_map &removable_phis, const AnfNodePtr &node) {
  const auto &inp = node->cast();
  const auto &iter = removable_phis.find(inp);
  if (iter == removable_phis.end()) {
    return node;
  }
  return FindPhis(removable_phis, iter->second);
}

// Remove the phi parameters that the function blocks recorded as removable:
//   1. merge every block's removable_phis map into one map and collect the phis in order;
//   2. replace each phi with its final replacement (see FindPhis) through a manager, visiting
//      the collected phis in reverse order;
//   3. drop the removed phis from the parameter list of each block's graph;
//   4. clear the manager info of every graph the manager tracks.
void Parser::RemoveUnnecessaryPhis() {
  // merge all removable phis to one map;
  std::unordered_map removable_phis;
  std::vector phis;
  for (FunctionBlockPtr &block : func_block_list_) {
    MS_EXCEPTION_IF_NULL(block);
    removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end());
    // NOTE(review): the lambda takes each map entry by value, copying it (presumably two smart
    // pointers) on every call; a const reference would avoid that.
    std::transform(block->removable_phis().begin(), block->removable_phis().end(), std::back_inserter(phis),
                   [](std::pair pair) { return pair.first; });
  }
  if (removable_phis.size() == 0) {
    return;
  }
  // NOTE(review): fg_name is never used below.
  auto fg_name = func_graph_->ToString();
  auto mng = Manage(func_graph_, false);
  // replace the nodes
  // remove from inside to outside (reverse of the collection order)
  for (int idx = SizeToInt(phis.size() - 1); idx >= 0; idx--) {
    auto phi = phis[IntToSize(idx)];
    auto new_node = FindPhis(removable_phis, phi);
    MS_LOG(DEBUG) << "phi " << phi->DebugString() << " to " << new_node->DebugString();
    mng->Replace(phi, new_node);
  }
  // remove the parameter
  for (FunctionBlockPtr &block : func_block_list_) {
    MS_EXCEPTION_IF_NULL(block);
    auto &local_removable_phis = block->removable_phis();
    if (local_removable_phis.size() == 0) {
      continue;
    }
    auto func_graph = block->func_graph();
    // NOTE(review): "¶meters" looks like a mis-decoded "&parameters" (HTML entity "&para;");
    // restore the reference declaration that the `parameters` uses below expect.
    auto ¶meters = func_graph->parameters();
    std::vector new_parameters(parameters.size());
    // Keep only the parameters that are not removable phis of this block.
    auto it = std::copy_if(
      parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](AnfNodePtr param) {
        return local_removable_phis.find(param->cast()) == local_removable_phis.end();
      });

    // shrink container to new size
    new_parameters.resize(std::distance(new_parameters.begin(), it));
    func_graph->set_parameters(new_parameters);
  }
  for (auto fg : mng->func_graphs()) {
    fg->ClearAllManagerInfo();
  }
}

// ParseAst class code

// Prepare obj_ for parsing.
// Steps:
//   1. classify obj_;
//   2. resolve the Python function to parse into function_;
//   3. run the Python-side parser to obtain ast_tree_;
//   4. cache the function's module, name, filename and line offset.
// How obj_ is resolved, by kind:
//   - plain function: parsed directly;
//   - bound method: the method is parsed and obj_ is replaced by the method's self object;
//   - class instance: the Python function named by `python_mod_get_parse_method` is called with
//     (obj_, parse_method) to obtain the method to parse, which must be a method.
// Returns false (after logging) if obj_ is of another kind or the function cannot be resolved.
bool ParseAst::InitParseAstInfo(const std::string &python_mod_get_parse_method) {
  // init the type
  target_type_ = PARSE_TARGET_UNKNOW;

  // call python parse, get the parser fn
  module_ = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  py::object parse_method = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_PARSE_METHOD);

  // get the obj type
  auto type = data_converter::GetObjType(obj_);
  if (type == RESOLVE_TYPE_FUNCTION) {
    target_type_ = PARSE_TARGET_FUNCTION;
    function_ = obj_;
  } else if (type == RESOLVE_TYPE_METHOD) {
    // process the method, need get the method's self obj
    // (PARSE_TARGET_METHOD is overwritten with PARSE_TARGET_OBJECT_INSTANCE below on success)
    target_type_ = PARSE_TARGET_METHOD;
    py::object method_object = python_adapter::GetPyObjAttr(obj_, PYTHON_GET_METHOD_SELF_CLASS);
    // Presumably a None check (the checked type is not spelled out here).
    if (py::isinstance(method_object)) {
      MS_LOG(ERROR) << "Get method's self object instance failed.";
      return false;
    }
    target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
    function_ = obj_;
    obj_ = method_object;
  } else if (type == RESOLVE_TYPE_CLASS_INSTANCE) {
    // obj is class instance, get the method to parse.
    function_ = python_adapter::CallPyModFn(module_, python_mod_get_parse_method, obj_, parse_method);
    if (py::isinstance(function_)) {
      MS_LOG(ERROR) << "Get obj method function failed.";
      return false;
    }
    target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
    // check the fn is method
    auto obj_type = data_converter::GetObjType(function_);
    if (obj_type != RESOLVE_TYPE_METHOD) {
      MS_LOG(WARNING) << "Parse method function is invalid.";
      return false;
    }
  } else {
    MS_LOG(WARNING) << "Parse obj is invalid, only can parse function and obj, type = " << type;
    return false;
  }

  // call python parse get ast tree
  parser_ = python_adapter::CallPyModFn(module_, PYTHON_MOD_PARSE_OBJECT_FUNCTION, function_, parse_method);
  ast_tree_ = python_adapter::CallPyObjMethod(parser_, "parse");

  // get fn name and module
  function_module_ = py::cast(python_adapter::GetPyObjAttr(parser_, "function_module"));
  function_name_ = py::cast(python_adapter::GetPyObjAttr(parser_, "function_name"));
  function_filename_ = py::cast(python_adapter::GetPyObjAttr(parser_, "filename"));
  function_line_offset_ = py::cast(python_adapter::GetPyObjAttr(parser_, "line_offset"));

  return true;
}

// Return the first entry of the parsed tree's `body` list (tree body list[0]).
// NOTE(review): the body list is indexed without checking that it is non-empty.
py::object ParseAst::GetAstNode() {
  py::list tree_body = python_adapter::GetPyObjAttr(ast_tree_, "body");
  py::object ast_node = tree_body[0];
  return ast_node;
}

// Return the argument nodes of `func_node`, as computed by the Python-side parser.
py::list ParseAst::GetArgs(const py::object &func_node) {
  py::list ret = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_ARGS, func_node);
  return ret;
}

// Return the default values of the arguments of `func_node`, as computed by the Python-side parser.
py::list ParseAst::GetArgsDefaultValues(const py::object &func_node) {
  py::list ret = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES, func_node);
  return ret;
}

// Classify `node` via the Python-side parser, which is expected to return [node_name, main_type].
// Returns nullptr (after logging) if fewer than two values are returned.
AstNodeTypePtr ParseAst::GetNodeType(const py::object &node) {
  py::list list_value = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_NODE_TYPE, node);
  if (list_value.size() < 2) {
    MS_LOG(ERROR) << "The node of python method must has 2 values.";
    return nullptr;
  }
  auto node_name = py::cast(list_value[0]);
  auto type = AstMainType(py::cast(list_value[1]));
  return std::make_shared(node, node_name, type);
}

// Return the AST sub-type of `node` as computed by the Python-side parser.
AstSubType ParseAst::GetOpType(const py::object &node) {
  auto op_type = AstSubType(python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_AST_TYPE, node).cast());
  return op_type;
}

// Ask the Python parse module whether `node` is a class member (presumably an attribute of self).
// Returns false, after logging, when the Python result is not a bool.
bool ParseAst::IsClassMember(const py::object &node) {
  py::object ret = CallParseModFunction(PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER, node);
  if (!py::isinstance(ret)) {
    MS_LOG(ERROR) << "The result of mod function parse, should be bool type.";
    return false;
  }
  return ret.cast();
}

// Copy the flags dict stored on `obj` (attribute PYTHON_EXTERN_MINDSPORE_FLAG) onto `func_graph`.
// Entries of the first accepted value type are applied with set_flag; entries of the second
// accepted type are applied with set_attr(MakeValue(...)).
// Returns true when `obj` has no flags or every entry was applied. Returns false (after logging)
// for a null graph, an unaccepted key type or an unsupported value type. On failure, entries
// processed earlier in the loop have already been applied.
// NOTE(review): the py::isinstance/py::cast type arguments are not visible here, so the two value
// checks below look identical; presumably they test bool (flag) and str (attr) — confirm.
bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
  if (func_graph == nullptr) {
    MS_LOG(ERROR) << "FuncGraph is null";
    return false;
  }

  if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) {
    MS_LOG(DEBUG) << "No flags";
    return true;
  }
  py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG);
  for (auto &item : flags) {
    if (!py::isinstance(item.first)) {
      MS_LOG(ERROR) << "Type error in flags dict convert";
      return false;
    }
    auto name = py::cast(item.first);
    if (py::isinstance(item.second)) {
      auto value = py::cast(item.second);
      MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
      func_graph->set_flag(name, value);
    } else if (py::isinstance(item.second)) {
      auto value = py::cast(item.second);
      MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
      func_graph->set_attr(name, MakeValue(value));
    } else {
      MS_LOG(ERROR) << "Type error in flags/attrs dict convert";
      return false;
    }
  }

  return true;
}

} // namespace parse
} // namespace mindspore