/** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "pipeline/jit/parse/parse.h" #include #include #include #include #include #include "utils/hash_map.h" #include "pybind_api/pybind_patch.h" #include "pipeline/jit/parse/resolve.h" #include "pipeline/jit/parse/data_converter.h" #include "frontend/operator/ops.h" #include "frontend/operator/composite/composite.h" #include "utils/ms_context.h" #include "utils/interpret_node_recorder.h" #include "debug/trace.h" #include "mindspore/core/ir/cell.h" namespace mindspore { namespace parse { FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mod_get_parse_method) { (void)python_adapter::set_python_scoped(); if (!obj || py::isinstance(obj)) { MS_LOG(ERROR) << "Parse the python code failed, obj is nullptr or none"; return nullptr; } auto ast = std::make_shared(obj); bool success = ast->InitParseAstInfo(python_mod_get_parse_method); if (!success) { MS_LOG(ERROR) << "Parse code to ast tree failed."; return nullptr; } auto parser = std::make_shared(ast); FuncGraphPtr func_graph = parser->ParseFuncGraph(); if (func_graph == nullptr) { MS_LOG(ERROR) << "Parse python code failed, errcode = " << parser->errcode(); return nullptr; } return func_graph; } TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) { return kFloat32; } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) { return kFloat16; } else { return nullptr; } } // If any mixed precision flag add a cast node after the parameter node. AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) { MS_EXCEPTION_IF_NULL(func_graph); TypePtr dst_type; if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) { dst_type = kFloat32; } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) { dst_type = kFloat16; } else { return param; } auto cast_helper = prim::kPrimMixedPrecisionCast; auto cast = func_graph->NewCNodeAfter(param, {NewValueNode(cast_helper), NewValueNode(dst_type), param}); return cast; } FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr(); Parser::Parser(const std::shared_ptr &ast) : ast_(ast) { max_for_loop_count_str_ = common::GetEnv("ENV_FOR_TO_WHILE_LOOP"); support_fallback_ = common::GetEnv("ENV_SUPPORT_FALLBACK"); errcode_ = PARSE_SUCCESS; BuildMethodMap(); } void Parser::BuildMethodMap() { stmt_method_map_["Return"] = &Parser::ParseReturn; stmt_method_map_["Expr"] = &Parser::ParseExpr; stmt_method_map_["If"] = &Parser::ParseIf; stmt_method_map_["Assign"] = &Parser::ParseAssign; stmt_method_map_["While"] = &Parser::ParseWhile; stmt_method_map_["For"] = &Parser::ParseFor; stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef; stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign; stmt_method_map_["Global"] = &Parser::ParseGlobal; stmt_method_map_["Break"] = &Parser::ParseBreak; stmt_method_map_["Continue"] = &Parser::ParseContinue; stmt_method_map_["Pass"] = &Parser::ParsePass; expr_method_map_["NoneType"] = &Parser::ParseNone; expr_method_map_["BinOp"] = &Parser::ParseBinOp; expr_method_map_["Name"] = &Parser::ParseName; expr_method_map_["Num"] = &Parser::ParseNum; expr_method_map_["Str"] = &Parser::ParseStr; expr_method_map_["Constant"] = &Parser::ParseConstant; expr_method_map_["NameConstant"] = &Parser::ParseNameConstant; expr_method_map_["Call"] = &Parser::ParseCall; expr_method_map_["IfExp"] = &Parser::ParseIfExp; expr_method_map_["Attribute"] = &Parser::ParseAttribute; expr_method_map_["Compare"] = &Parser::ParseCompare; expr_method_map_["BoolOp"] = &Parser::ParseBoolOp; expr_method_map_["Lambda"] = &Parser::ParseLambda; expr_method_map_["Tuple"] = &Parser::ParseTuple; expr_method_map_["List"] = &Parser::ParseList; expr_method_map_["Subscript"] = &Parser::ParseSubscript; expr_method_map_["Slice"] = &Parser::ParseSlice; expr_method_map_["ExtSlice"] = &Parser::ParseExtSlice; expr_method_map_["Index"] = &Parser::ParseIndex; expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp; expr_method_map_["Dict"] = &Parser::ParseDict; expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis; expr_method_map_["ListComp"] = &Parser::ParseListComp; expr_method_map_["GeneratorExp"] = &Parser::ParseListComp; // We treat 'GeneratorExp' the same as 'ListComp'. } void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); } void Parser::InitParserEnvironment(const py::object &obj) { Parser::top_func_graph_ = FuncGraphWeakPtr(); ScopeManager::GetInstance().ClearScope(); (void)python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GENERATE_SCOPE, obj); } void Parser::CleanParserResource() { Parser::top_func_graph_ = FuncGraphWeakPtr(); ScopeManager::GetInstance().ClearScope(); } void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr &ast) { // Check whether the functions referred by this function and itself are missing 'return' statement auto mng = Manage(fn, false); MS_EXCEPTION_IF_NULL(ast); for (const auto &func_graph : mng->func_graphs()) { MS_EXCEPTION_IF_NULL(func_graph); if (func_graph->get_return() != nullptr) { continue; } py::object node = ast->GetAstNode(); py::list ret = ast->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); constexpr auto min_list_size = 2; if (ret.size() < min_list_size) { MS_LOG(EXCEPTION) << "list size:" << ret.size() << " is less than 2."; } py::str desc = python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]); MS_EXCEPTION(TypeError) << "Function must has 'return' statement, but missing in " << desc.cast() << ". FuncGraph: " << func_graph->ToString(); } } FuncGraphPtr Parser::ParseFuncGraph() { // Get ast FunctionDef node py::object node = ast_->GetAstNode(); constexpr char function_def_name[] = "FunctionDef"; constexpr char lambda_name[] = "Lambda"; FunctionBlockPtr fn_block = nullptr; if (ast_->GetNodeType(node)->node_name() == function_def_name) { fn_block = ParseDefFunction(node); } else { auto lambda_node = python_adapter::GetPyObjAttr(node, "value"); if (py::isinstance(lambda_node) || ast_->GetNodeType(lambda_node)->node_name() != lambda_name) { MS_EXCEPTION(TypeError) << "Parse Lambda Function Fail. Node type must be Lambda, but got " << ast_->GetNodeType(lambda_node)->node_name() << "."; } fn_block = ParseLambdaFunction(lambda_node); } if (errcode() != PARSE_SUCCESS) { MS_LOG(ERROR) << "Parse function error, code is " << errcode(); return nullptr; } RemoveUnnecessaryPhis(); MS_EXCEPTION_IF_NULL(fn_block); CheckFuncReturn(fn_block->func_graph(), ast_); return fn_block->func_graph(); } void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node) { py::object func_args = python_adapter::GetPyObjAttr(fn_node, "args"); py::object var_arg_node = python_adapter::GetPyObjAttr(func_args, "vararg"); MS_EXCEPTION_IF_NULL(block); auto block_fg = block->func_graph(); block_fg->set_has_vararg(!py::isinstance(var_arg_node)); py::object kw_arg_node = python_adapter::GetPyObjAttr(func_args, "kwarg"); block_fg->set_has_kwarg(!py::isinstance(kw_arg_node)); py::list kwonly_args = python_adapter::GetPyObjAttr(func_args, "kwonlyargs"); block_fg->set_kwonlyargs_count(SizeToInt(kwonly_args.size())); MS_EXCEPTION_IF_NULL(ast_); py::list args = ast_->GetArgs(fn_node); for (std::size_t i = 0; i < args.size(); i++) { std::string arg_name = py::cast(args[i].attr("arg")); if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { if (arg_name == "self") { continue; } } TraceGuard guard(GetLocation(args[i])); auto para_node = std::make_shared(block_fg); MS_EXCEPTION_IF_NULL(para_node); para_node->set_name(arg_name); para_node->debug_info()->set_name(arg_name); block_fg->add_parameter(para_node); AnfNodePtr para_after_cast = GetMixedPrecisionCastHelp(block_fg, para_node); MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name; block->WriteVariable(arg_name, para_after_cast); } } void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node) { MS_EXCEPTION_IF_NULL(block); py::list defaults = ast_->GetArgsDefaultValues(fn_node); py::list args = ast_->GetArgs(fn_node); std::vector namelist_for_default_value; std::vector default_values; for (std::size_t i = 0; i < args.size(); i++) { std::string arg_name = py::cast(args[i].attr("arg")); if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { if (arg_name == "self") { continue; } } namelist_for_default_value.push_back(arg_name); if (i >= defaults.size()) { MS_LOG(EXCEPTION) << "Index:" << i << " out of range:" << defaults.size(); } if (py::isinstance(defaults[i])) { default_values.push_back(NewValueNode(kNull)); } else { default_values.push_back(ParseExprNode(block, defaults[i])); } } block->func_graph()->SetDefaultValues(namelist_for_default_value, default_values); } ScopePtr Parser::GetScopeForParseFunction() { ScopePtr scope = ScopeManager::GetInstance().GetCurrentScope(); if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { py::object scope_str = python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GET_SCOPE_NAME, ast_->obj()); if (!py::isinstance(scope_str)) { auto scope_name = py::cast(scope_str); scope = std::make_shared(scope_name); } } return scope; } FunctionBlockPtr Parser::ParseDefFunction(const py::object &node, const FunctionBlockPtr &block) { ScopePtr scope = GetScopeForParseFunction(); // The node created in the parsefunction context, will inherit the scope created using scope_guard ScopeGuard scope_guard(scope); TraceGuard trace_guard(data_converter::GetObjKey(ast_->obj())[0], GetLocation(node)); FunctionBlockPtr func_block = MakeFunctionBlock(*this); if (block != nullptr) { func_block->AddPrevBlock(block); } else { func_graph_ = func_block->func_graph(); } func_block->Mature(); auto current_fg = func_block->func_graph(); auto function_name = py::cast(python_adapter::GetPyObjAttr(node, "name")); MS_LOG(DEBUG) << "The function name is " << function_name; current_fg->debug_info()->set_name(function_name); MS_EXCEPTION_IF_NULL(ast_); py::list deco_list = node.attr("decorator_list"); if (!deco_list.empty()) { current_fg->debug_info()->set_deco_location(GetLocation(deco_list)); } bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg); if (!ast_->obj().is(ast_->function())) { set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg); } if (!set_flag) { MS_LOG(ERROR) << "Set flags failed"; return nullptr; } GenerateArgsNodeForFunction(func_block, node); // When parsing the top graph of construct, save the top graph if (GetTopFuncGraph() == nullptr) { UpdateTopFuncGraph(func_block->func_graph()); } // Save the function node to block func_block->WriteVariable(function_name, NewValueNode(current_fg)); py::object funcObj = python_adapter::GetPyObjAttr(node, "body"); (void)ParseStatements(func_block, funcObj); // Add unused variables as isolate nodes. for (auto &func_block_item : func_block_list_) { MS_EXCEPTION_IF_NULL(func_block_item); if (func_block_item->func_graph()->get_return() != nullptr) { // Find unused variables. func_block_item->FindIsolatedNodes(); // Attach all isolated nodes. func_block_item->AttachIsolatedNodesBeforeReturn(); } } if (current_fg->get_return() == nullptr) { py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, node, ret[0], ret[1]); MS_EXCEPTION(TypeError) << "Function must has 'return' statement, but missing in " << desc.cast() << "."; } GenerateArgsDefaultValueForFunction(func_block, node); return func_block; } FunctionBlockPtr Parser::ParseLambdaFunction(const py::object &node, const FunctionBlockPtr &block) { MS_EXCEPTION_IF_NULL(ast_); ScopePtr scope = GetScopeForParseFunction(); ScopeGuard scope_guard(scope); TraceGuard trace_guard(data_converter::GetObjKey(ast_->obj())[0], GetLocation(node)); FunctionBlockPtr func_block = MakeFunctionBlock(*this); if (block != nullptr) { func_block->AddPrevBlock(block); } else { func_graph_ = func_block->func_graph(); } func_block->Mature(); auto current_fg = func_block->func_graph(); auto function_name = ast_->function_name(); MS_LOG(DEBUG) << "The function name is " << function_name; current_fg->debug_info()->set_name(function_name); GenerateArgsNodeForFunction(func_block, node); py::object body_node = python_adapter::GetPyObjAttr(node, "body"); AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node); current_fg->set_output(lambda_body_node); GenerateArgsDefaultValueForFunction(func_block, node); return func_block; } FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr block, const py::object &nodes) { auto node_list = py::cast(nodes); size_t count = py::len(node_list); MS_LOG(DEBUG) << "The nodes count is " << count; for (size_t i = 0; i < count; ++i) { MS_LOG(DEBUG) << "Start parse statement[" << i << "]: " << py::str(node_list[i]); auto node = node_list[i]; block = ParseStatement(block, node); MS_EXCEPTION_IF_NULL(block); // Insert appropriate depended items for the function block if it has a return node if (block->func_graph()->get_return() != nullptr || block->is_dead_block()) { // If break is not the last expr. if (i != count - 1) { TraceGuard trace_guard(GetLocation(node_list[i + 1])); MS_LOG(EXCEPTION) << "Dead code exist, please remove it."; } // Skip statements after 'return' (or 'break', 'continue'). break; } } return block; } FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py::object &node) { TraceGuard trace_guard(GetLocation(node)); auto node_type = ast_->GetNodeType(node); // Check the node type AstMainType nodeType = node_type->main_type(); if (nodeType != AST_MAIN_TYPE_STMT) { MS_LOG(INFO) << "Node type is error : " << nodeType; return block; } // Call the process function std::string node_name = node_type->node_name(); MS_LOG(DEBUG) << "Ast node is " << node_name; if (stmt_method_map_.count(node_name)) { auto stmt_block = (this->*stmt_method_map_[node_name])(block, node); TraceManager::ClearParseOrResolveDebugInfo(); return stmt_block; } else { errcode_ = PARSE_NODE_METHOD_UNSUPPORTED; MS_LOG(EXCEPTION) << "Unsupported statement '" << node_name << "'.\nMore details please refer to syntax support at https://www.mindspore.cn"; } } AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast expr."; TraceGuard trace_guard(GetLocation(node)); auto node_type = ast_->GetNodeType(node); // Check the node type AstMainType node_main_type = node_type->main_type(); if (node_main_type != AST_MAIN_TYPE_EXPR) { errcode_ = PARSE_NODE_TYPE_NO_MATCH; MS_LOG(EXCEPTION) << "Node type is error : " << node_main_type; } // Call the process function std::string node_name = node_type->node_name(); MS_LOG(DEBUG) << "Ast node is " << node_name; if (expr_method_map_.count(node_name)) { auto expr_node = (this->*expr_method_map_[node_name])(block, node); TraceManager::ClearParseOrResolveDebugInfo(); return expr_node; } else { errcode_ = PARSE_NODE_METHOD_UNSUPPORTED; MS_LOG(EXCEPTION) << "Unsupported expression '" << node_name << "'.\nMore details please refer to syntax support at https://www.mindspore.cn"; } } // Process the expr statement and expand it FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Expr"; // Expr only have value, no target py::tuple expand_info = ast_->CallParseModFunction(PYTHON_PARSE_EXPAND_EXPR_STATEMENT, node); // Refer python function expand_expr_statement, expand_info is one of the following: // True, expr.value, x // True, expr.value // False, None, None // // Check the expand info result if (expand_info.empty()) { MS_LOG(EXCEPTION) << "Empty expand_info."; } auto is_expand = py::cast(expand_info[0]); if (is_expand) { // Process the expr statement constexpr size_t expect_size = 2; if (expand_info.size() < expect_size) { MS_LOG(EXCEPTION) << "expand_info size:" << expand_info.size() << " less than " << expect_size << "."; } py::object value_object = expand_info[1]; // Make a Expr CNode. AnfNodePtr call_node = ParseExprNode(block, value_object); if (py::len(expand_info) == 2) { // Expression that not assigned to any variable. // This is usually a call with side effects. // e.g.: print(x) // We save it as an isolated node. auto &no_return_node = call_node; MS_LOG(INFO) << "Isolated node found(NoReturn), no_return_node: " << no_return_node->DebugString(2) << ", block: " << block << "/" << (block->func_graph() ? block->func_graph()->ToString() : "FG(Null)") << ", Line: " << trace::GetDebugInfo(no_return_node->debug_info(), "", kSourceLineTipDiscard); block->AddIsolatedNode(no_return_node); } else { // Expand the assign statement, // e.g.: x.append(y) -> x = x.append(y) py::object target_node = expand_info[2]; WriteAssignVars(block, target_node, call_node); } } return block; } LocationPtr Parser::GetLocation(const py::object &node) const { MS_EXCEPTION_IF_NULL(ast_); py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); constexpr size_t list_size = 5; if (ret.size() < list_size) { MS_LOG(EXCEPTION) << "List size should not be less than 5."; } const size_t file_name_index = 0; const size_t line_index = 1; const size_t column_index = 2; const size_t line_end_index = 3; const size_t column_end_index = 4; // Refer to Location::Location() for each member of ret: line, column, line_end, column_end. auto location = std::make_shared(ret[file_name_index].cast(), ret[line_index].cast(), ret[column_index].cast(), ret[line_end_index].cast(), ret[column_end_index].cast()); return location; } void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block, const FunctionBlockPtr &false_block) { MS_EXCEPTION_IF_NULL(true_block); MS_EXCEPTION_IF_NULL(false_block); true_block->AddPrevBlock(pre_block); true_block->Mature(); false_block->AddPrevBlock(pre_block); false_block->Mature(); } FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast return"; MS_EXCEPTION_IF_NULL(block); // Parse the return Statements value. py::object value_object = python_adapter::GetPyObjAttr(node, "value"); AnfNodePtr return_expr_node = ParseExprNode(block, value_object); // Check if need interpreting. return_expr_node = HandleInterpret(block, return_expr_node, value_object); // Create the `return` CNode. auto func_graph = block->func_graph(); CNodePtr return_cnode = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimReturn), return_expr_node}); func_graph->set_return(return_cnode); return block; } // Process binary operators,eg: `a + b`, `a | b`, etc. AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast BinOP"; py::object left = python_adapter::GetPyObjAttr(node, "left"); py::object right = python_adapter::GetPyObjAttr(node, "right"); py::object op = python_adapter::GetPyObjAttr(node, "op"); // Create left and right ANF node AnfNodePtr left_node = ParseExprNode(block, left); if (left_node == nullptr) { MS_LOG(EXCEPTION) << "DoBinOp process left node failed: " << errcode(); } AnfNodePtr right_node = ParseExprNode(block, right); if (right_node == nullptr) { MS_LOG(EXCEPTION) << "DoBinOp process right node failed:" << errcode(); } // Resolve the op MS_EXCEPTION_IF_NULL(block); AnfNodePtr op_node = block->MakeResolveAstOp(op); // Create apply node MS_EXCEPTION_IF_NULL(block->func_graph()); return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node}); } AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Name"; auto name_id = py::cast(python_adapter::GetPyObjAttr(node, "id")); MS_LOG(DEBUG) << "The Name id is " << name_id; MS_EXCEPTION_IF_NULL(block); if (block->IsGlobalVar(name_id)) { MS_LOG(DEBUG) << "name_id: " << name_id; return block->MakeResolveSymbol(name_id); } return block->ReadVariable(name_id); } AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) { MS_LOG(DEBUG) << "Process ast NoneType"; return NewValueNode(kNone); } AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) { MS_LOG(DEBUG) << "Process ast Ellipsis"; return NewValueNode(kEllipsis); } AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) { MS_LOG(DEBUG) << "Process ast Num"; py::object obj = python_adapter::GetPyObjAttr(node, "n"); if (py::isinstance(obj)) { MS_LOG(INFO) << "The Num is int64_t:" << (std::string)py::str(obj); auto data = py::cast(obj); return NewValueNode(data); } else if (py::isinstance(obj)) { MS_LOG(INFO) << "The Num is float:" << (std::string)py::str(obj); auto data = py::cast(obj); return NewValueNode(data); } else { // no else actually errcode_ = PARSE_NODE_TYPE_UNKNOWN; MS_EXCEPTION(TypeError) << "Only support 'Number' type of 'int` and 'float', but got type: " << obj.get_type() << " Value:" << py::str(obj); } } AnfNodePtr Parser::ParseStr(const FunctionBlockPtr &, const py::object &node) { MS_LOG(DEBUG) << "Process ast Str"; auto str_s = py::cast(python_adapter::GetPyObjAttr(node, "s")); return NewValueNode(str_s); } AnfNodePtr Parser::ParseConstant(const FunctionBlockPtr &, const py::object &node) { MS_LOG(DEBUG) << "Process ast Constant"; py::object obj = python_adapter::GetPyObjAttr(node, "value"); if (py::isinstance(obj)) { MS_LOG(INFO) << "The Constant is bool:" << (std::string)py::str(obj); return NewValueNode(py::cast(obj)); } else if (py::isinstance(obj)) { MS_LOG(INFO) << "The Constant is int64_t:" << (std::string)py::str(obj); return NewValueNode(py::cast(obj)); } else if (py::isinstance(obj)) { MS_LOG(INFO) << "The Constant is float:" << (std::string)py::str(obj); return NewValueNode(py::cast(obj)); } else if (py::isinstance(obj)) { MS_LOG(INFO) << "The Constant is string:" << (std::string)py::str(obj); return NewValueNode(py::cast(obj)); } else if (py::isinstance(obj)) { MS_LOG(INFO) << "The Constant is none:" << (std::string)py::str(obj); return NewValueNode(kNone); } else if (py::isinstance(obj)) { MS_LOG(INFO) << "The Constance is ellipsis:" << (std::string)py::str(obj); return NewValueNode(kEllipsis); } else { // no else actually MS_EXCEPTION(TypeError) << "Unsupported Constant type : " << (std::string)py::str(obj); } } AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object &node) { MS_LOG(DEBUG) << "Process ast NameConstant"; py::object obj = python_adapter::GetPyObjAttr(node, "value"); if (py::isinstance(obj)) { MS_LOG(INFO) << "The NameConstant is bool:" << (std::string)py::str(obj); auto data = py::cast(obj); return NewValueNode(data); } else if (py::isinstance(obj)) { MS_LOG(INFO) << "The NameConstant is none:" << (std::string)py::str(obj); return NewValueNode(kNone); } // no else actually errcode_ = PARSE_NODE_TYPE_UNKNOWN; MS_LOG(EXCEPTION) << "Unsupported NameConstant type: " << (std::string)py::str(obj); } AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector &element_nodes) { MS_EXCEPTION_IF_NULL(block); AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE); std::vector make_tuple_nodes; make_tuple_nodes.push_back(make_tuple_op); (void)std::transform(element_nodes.begin(), element_nodes.end(), std::back_inserter(make_tuple_nodes), [](AnfNodePtr arg) -> AnfNodePtr { return arg; }); return block->func_graph()->NewCNodeInOrder(std::move(make_tuple_nodes)); } AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &args) { MS_EXCEPTION_IF_NULL(block); py::object father_class; const size_t expect_args_size = 2; if (args.empty()) { father_class = py::none(); } else if (args.size() == expect_args_size) { father_class = args[0]; auto arg_type = AstSubType(py::cast(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, args[1]))); if (arg_type != AST_SUB_TYPE_NAME || py::cast(python_adapter::GetPyObjAttr(args[1], "id")) != "self") { MS_EXCEPTION(ArgumentError) << "Argument 2 of 'super()' must be 'self', but got '" << py::cast(python_adapter::GetPyObjAttr(args[1], "id")) << "'."; } } else { MS_EXCEPTION(ArgumentError) << "Arguments number of 'super()' should be 0 or 2, but got " << args.size() << "."; } py::object target_class_instance = ast_->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast_->obj()); py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance); NameSpacePtr name_space = std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); SymbolPtr symbol = std::make_shared("namespace"); MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString(); return block->MakeResolve(name_space, symbol); } // Process function call, eg : f1(x, y) ... AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Call"; // Process function call py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func"); py::list args = python_adapter::GetPyObjAttr(node, "args"); auto arg_type = AstSubType(py::cast(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, function_ast_node))); if (arg_type == AST_SUB_TYPE_NAME) { auto name_id = py::cast(python_adapter::GetPyObjAttr(function_ast_node, "id")); if (name_id == "super") { return ParseSuper(block, args); } } AnfNodePtr call_function_node = ParseExprNode(block, function_ast_node); // Function call arguments should be passed in as groups and unpacked later using unpack call std::vector packed_arguments; std::vector group_arguments; bool need_unpack_args = ParseArgsInCall(block, args, &packed_arguments, &group_arguments); bool need_unpack_keywords = ParseKeywordsInCall(block, node, &packed_arguments); // If there is stared or keyword argument, unpack may be needed bool need_unpack = need_unpack_args || need_unpack_keywords; auto call_cnode = GenerateAnfNodeForCall(block, call_function_node, packed_arguments, group_arguments, need_unpack); if (call_function_node->interpret()) { call_cnode->set_interpret(true); } return call_cnode; } CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_node, const std::vector &packed_arguments) { MS_EXCEPTION_IF_NULL(func_graph); std::vector unpack_call_nodes; auto unpack_call_op = NewValueNode(std::make_shared(NAMED_METAGRAPH_UNPACKCALL)); unpack_call_nodes.push_back(unpack_call_op); unpack_call_nodes.push_back(call_function_node); (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes), [](AnfNodePtr node) -> AnfNodePtr { return node; }); CNodePtr unpack_call = func_graph->NewCNodeInOrder(std::move(unpack_call_nodes)); return unpack_call; } AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_node, const std::vector &packed_arguments, const std::vector &group_arguments, bool need_unpack) const { // If there is keyword arguments or starred, using an unpack_call op to unpack the argument MS_EXCEPTION_IF_NULL(block); if (need_unpack) { return MakeUnpackCall(block->func_graph(), call_function_node, packed_arguments); } // else there is no keyword arguments and starred, parsed as normal arguments without unpack std::vector func_call_nodes; func_call_nodes.push_back(call_function_node); (void)std::transform(group_arguments.begin(), group_arguments.end(), std::back_inserter(func_call_nodes), [](AnfNodePtr node) -> AnfNodePtr { return node; }); CNodePtr call_anf_node = block->func_graph()->NewCNodeInOrder(std::move(func_call_nodes)); return call_anf_node; } bool Parser::ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, std::vector *packed_arguments, std::vector *group_arguments) { MS_LOG(DEBUG) << "Process ast args in call"; MS_EXCEPTION_IF_NULL(packed_arguments); MS_EXCEPTION_IF_NULL(group_arguments); bool need_unpack = false; for (size_t i = 0; i < args.size(); i++) { auto arg_node = AstSubType(py::cast(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, args[i]))); if (arg_node == AST_SUB_TYPE_STARRED) { if (!group_arguments->empty()) { packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments)); } packed_arguments->push_back(ParseExprNode(block, python_adapter::GetPyObjAttr(args[i], "value"))); group_arguments->clear(); need_unpack = true; } else { auto node = ParseExprNode(block, args[i]); node = HandleInterpret(block, node, args[i]); group_arguments->push_back(node); } } if (!group_arguments->empty()) { packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments)); } return need_unpack; } bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node, std::vector *packed_arguments) { MS_LOG(DEBUG) << "Process ast key words in call"; bool need_unpack = false; py::list keywords = python_adapter::GetPyObjAttr(node, "keywords"); if (!keywords.empty()) { MS_EXCEPTION_IF_NULL(block); MS_EXCEPTION_IF_NULL(packed_arguments); need_unpack = true; std::vector keys; std::vector values; for (size_t index = 0; index < keywords.size(); index++) { auto kw_key = python_adapter::GetPyObjAttr(keywords[index], "arg"); auto kw_value = python_adapter::GetPyObjAttr(keywords[index], "value"); if (py::isinstance(kw_key)) { packed_arguments->push_back(ParseExprNode(block, kw_value)); } else { auto kw_key_c = kw_key.cast(); keys.push_back(NewValueNode(kw_key_c)); auto ret_node = ParseExprNode(block, kw_value); ret_node = HandleInterpret(block, ret_node, kw_value); values.push_back(ret_node); } } auto keys_tuple = GenerateMakeTuple(block, keys); auto values_tuple = GenerateMakeTuple(block, values); auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT); std::vector make_dict_nodes; make_dict_nodes.push_back(make_dict_op); make_dict_nodes.push_back(keys_tuple); make_dict_nodes.push_back(values_tuple); packed_arguments->push_back(block->func_graph()->NewCNodeInOrder(std::move(make_dict_nodes))); } return need_unpack; } // Process call attributes of class type define, eg: x.y() AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Attribute"; // Process class value, eg: self.xx if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { if (ast_->IsClassMember(node)) { std::string var_name = "self."; std::string attr_name = node.attr("attr").cast(); (void)var_name.append(attr_name); auto attr_obj = ast()->obj().attr(attr_name.c_str()); MS_EXCEPTION_IF_NULL(block); if (py::hasattr(ast()->obj(), attr_name.c_str()) && (py::hasattr(attr_obj, PYTHON_PRIMITIVE_FLAG) || py::isinstance(attr_obj) || py::isinstance(attr_obj) || py::isinstance(attr_obj) || py::isinstance(attr_obj) || data_converter::IsCellInstance(attr_obj))) { MS_LOG(DEBUG) << "var_name: " << var_name; return block->MakeResolveSymbol(var_name); } else { return block->ReadVariable(var_name); } } } // Process the get attr // Use the Primitive replace the operation resolve node (getattr), // because the getattr will eventually be converted to Primitive node AnfNodePtr op_node = NewValueNode(prim::kPrimGetAttr); // Process the attr body py::object value_body = python_adapter::GetPyObjAttr(node, "value"); AnfNodePtr value_node = ParseExprNode(block, value_body); if (value_node == nullptr) { MS_LOG(EXCEPTION) << "Parse attribute failed"; } // Process the node attr auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast(); MS_LOG(DEBUG) << "Attr = " << attr_str; AnfNodePtr attr_node = nullptr; { TraceGuard guard(GetLocation(python_adapter::GetPyObjAttr(node, "attr"))); attr_node = NewValueNode(attr_str); } // Create the apply node auto attr_cnode = block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node}); if (value_node->interpret() || IsPrimitiveCNode(value_node, prim::kPrimPyInterpret)) { attr_cnode->set_interpret(true); } return attr_cnode; } // Process comparison expression : a == b. a > b etc. AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Compare"; // For python comparison ,there may be if x>y>5 , // Which there is two ops , but we only support one now py::list ops = python_adapter::GetPyObjAttr(node, "ops"); if (ops.size() != MAX_COMPARISON_OPS_SUPPORTED) { MS_EXCEPTION(NotSupportError) << "Only support comparison with 1 operator, but got " << ops.size() << ", which is " << py::str(ops); } py::object left = python_adapter::GetPyObjAttr(node, "left"); py::list comparators = python_adapter::GetPyObjAttr(node, "comparators"); if (comparators.empty()) { MS_LOG(EXCEPTION) << "Comparators can't be empty."; } AnfNodePtr left_node = ParseExprNode(block, left); AnfNodePtr right_node = ParseExprNode(block, comparators[0]); MS_EXCEPTION_IF_NULL(block); AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]); return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node}); } AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode) { // If there is only one bool op now MS_EXCEPTION_IF_NULL(block); if (value_list.empty()) { MS_LOG(EXCEPTION) << "value list is empty."; } if (value_list.size() == 1) { AnfNodePtr first_node = ParseExprNode(block, value_list[0]); return first_node; } else { py::object first = value_list[0]; py::list rest; for (size_t i = 1; i < value_list.size(); i++) { rest.append(value_list[i]); } FunctionBlockPtr true_block = nullptr; FunctionBlockPtr false_block = nullptr; auto block_fg = block->func_graph(); { TraceGuard guard(std::make_shared(block_fg->debug_info())); true_block = MakeFunctionBlock(*this); } { TraceGuard guard(std::make_shared(block_fg->debug_info())); false_block = MakeFunctionBlock(*this); } MakeConditionBlocks(block, true_block, false_block); FunctionBlockPtr b1, b2; // If it is and, we need to process the rest nodes; // If it is or, we continue to next if (mode == AST_SUB_TYPE_AND) { b1 = true_block; b2 = false_block; } else if (mode == AST_SUB_TYPE_OR) { b2 = true_block; b1 = false_block; } else { MS_LOG(ERROR) << "Not supported mode: " << mode; return nullptr; } AnfNodePtr test_node = ParseExprNode(block, first); AnfNodePtr rest_node = ProcessBoolOpValueList(b1, rest, mode); b1->func_graph()->set_output(rest_node); b2->func_graph()->set_output(test_node); auto cond_node = block->ForceToBoolNode(test_node); auto switch_app = block_fg->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), cond_node, NewValueNode(true_block->func_graph()), NewValueNode(false_block->func_graph())}); std::vector call_graph_nodes{switch_app}; auto switch_app_call = block_fg->NewCNodeInOrder(std::move(call_graph_nodes)); return switch_app_call; } } // Process comparison expression : a and b. a or b . AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast BoolOp"; py::object op_node = python_adapter::GetPyObjAttr(node, "op"); AstSubType op_type = ast_->GetOpType(op_node); if (op_type == AST_SUB_TYPE_UNKNOWN) { MS_LOG(EXCEPTION) << "ProcessBoolOp, got unknown op type"; } py::list op_values = python_adapter::GetPyObjAttr(node, "values"); return ProcessBoolOpValueList(block, op_values, op_type); } // Process a function def FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast FunctionDef"; FunctionBlockPtr function_block = ParseDefFunction(node, block); MS_EXCEPTION_IF_NULL(function_block); // Get function name py::str name = python_adapter::GetPyObjAttr(node, "name"); std::string function_name = name; ValueNodePtr valuenode_graph = NewValueNode(function_block->func_graph()); block->WriteVariable(function_name, valuenode_graph); return block; } // Process a lambda expression . like lambda x,y: x + y AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Lambda"; FunctionBlockPtr function_block = ParseLambdaFunction(node, block); MS_EXCEPTION_IF_NULL(function_block); auto block_fg = function_block->func_graph(); ValueNodePtr const_graph = NewValueNode(block_fg); return const_graph; } // Process a tuple AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Tuple"; MS_EXCEPTION_IF_NULL(block); py::tuple elts = python_adapter::GetPyObjAttr(node, "elts"); if (elts.empty()) { auto empty_tuple = std::vector(); return NewValueNode(std::make_shared(empty_tuple)); } std::vector tuple_vec; AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE); tuple_vec.emplace_back(make_tuple_op); for (size_t i = 0; i < elts.size(); i++) { AnfNodePtr node_ptr = ParseExprNode(block, elts[i]); tuple_vec.emplace_back(node_ptr); } CNodePtr tuple_app = block->func_graph()->NewCNodeInOrder(std::move(tuple_vec)); return tuple_app; } // Process a list AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast List"; MS_EXCEPTION_IF_NULL(block); py::list elts = python_adapter::GetPyObjAttr(node, "elts"); if (elts.empty()) { auto empty_list = std::vector(); return NewValueNode(std::make_shared(empty_list)); } std::vector list_vec; AnfNodePtr make_list_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST); list_vec.emplace_back(make_list_op); for (size_t i = 0; i < elts.size(); i++) { AnfNodePtr node_ptr = ParseExprNode(block, elts[i]); list_vec.emplace_back(node_ptr); } CNodePtr list_app = block->func_graph()->NewCNodeInOrder(std::move(list_vec)); return list_app; } // Process a subscript, such as x[y] , node expressed as value[slice] AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Subscript"; MS_EXCEPTION_IF_NULL(block); AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); py::object value_node = python_adapter::GetPyObjAttr(node, "value"); py::object slice_node = python_adapter::GetPyObjAttr(node, "slice"); AnfNodePtr value = ParseExprNode(block, value_node); AnfNodePtr slice = ParseExprNode(block, slice_node); return block->func_graph()->NewCNodeInOrder({op_getitem, value, slice}); } // Process a slice, get the slice value AnfNodePtr Parser::ParseSlice(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Slice"; MS_EXCEPTION_IF_NULL(block); AnfNodePtr op_makeslice = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKESLICE); py::object start = python_adapter::GetPyObjAttr(node, "lower"); py::object stop = python_adapter::GetPyObjAttr(node, "upper"); py::object step = python_adapter::GetPyObjAttr(node, "step"); AnfNodePtr start_node = ParseExprNode(block, start); AnfNodePtr stop_node = ParseExprNode(block, stop); AnfNodePtr step_node = ParseExprNode(block, step); return block->func_graph()->NewCNodeInOrder({op_makeslice, start_node, stop_node, step_node}); } // Process a extslice AnfNodePtr Parser::ParseExtSlice(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast ExtSlice"; MS_EXCEPTION_IF_NULL(block); AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE); py::tuple slice_tuple = python_adapter::GetPyObjAttr(node, "dims"); std::vector node_vec; node_vec.emplace_back(make_tuple_op); for (size_t i = 0; i < slice_tuple.size(); i++) { AnfNodePtr node_ptr = ParseExprNode(block, slice_tuple[i]); node_vec.emplace_back(node_ptr); } CNodePtr tuple_conde = block->func_graph()->NewCNodeInOrder(std::move(node_vec)); return tuple_conde; } // Process a index, get the index number AnfNodePtr Parser::ParseIndex(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Index"; py::object value_node = python_adapter::GetPyObjAttr(node, "value"); return ParseExprNode(block, value_node); } // Process a UnaryOp, +a, -b AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast UnaryOp"; py::object op = python_adapter::GetPyObjAttr(node, "op"); MS_EXCEPTION_IF_NULL(block); // Resolve the op AnfNodePtr op_node = block->MakeResolveAstOp(op); py::object operand = python_adapter::GetPyObjAttr(node, "operand"); AnfNodePtr operand_node = ParseExprNode(block, operand); return block->func_graph()->NewCNodeInOrder({op_node, operand_node}); } // Process a dict ast node expression AnfNodePtr Parser::ParseDictByKeysAndValues(const FunctionBlockPtr &block, const std::vector &key_nodes, const std::vector &value_nodes) { auto keys_tuple = GenerateMakeTuple(block, key_nodes); auto values_tuple = GenerateMakeTuple(block, value_nodes); MS_EXCEPTION_IF_NULL(block); auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT); return block->func_graph()->NewCNodeInOrder({make_dict_op, keys_tuple, values_tuple}); } AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Dict"; py::list keys = node.attr("keys"); py::list values = node.attr("values"); std::vector key_nodes; std::vector value_nodes; for (size_t i = 0; i < keys.size(); i++) { key_nodes.push_back(ParseExprNode(block, keys[i])); value_nodes.push_back(ParseExprNode(block, values[i])); } return ParseDictByKeysAndValues(block, key_nodes, value_nodes); } // Process a augment assign such as a += b or mat[stride_slice] += b. FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast AugAssign"; MS_EXCEPTION_IF_NULL(block); MS_EXCEPTION_IF_NULL(ast_); py::object target_object = python_adapter::GetPyObjAttr(node, "target"); py::object op_object = python_adapter::GetPyObjAttr(node, "op"); py::object value_object = python_adapter::GetPyObjAttr(node, "value"); AnfNodePtr target_node = nullptr; AnfNodePtr op_node = block->MakeResolveAstOp(op_object); AnfNodePtr value_node = ParseExprNode(block, value_object); auto ast_type = AstSubType(py::cast(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_object))); if (ast_type == AST_SUB_TYPE_NAME) { target_node = ParseName(block, target_object); } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) { target_node = ParseSubscript(block, target_object); } else if (ast_->IsClassMember(target_object)) { target_node = ParseAttribute(block, target_object); } else if (ast_type == AST_SUB_TYPE_ATTRIBUTE) { TraceGuard(GetLocation(target_object)); MS_EXCEPTION(TypeError) << "Only support augassign to attribute of self, but got attribute of " << py::str(target_object.attr("value").attr("id")) << ".\n" << "More details please refer to syntax support at https://www.mindspore.cn"; } else { TraceGuard(GetLocation(target_object)); MS_EXCEPTION(TypeError) << "Only supported augassign to attribute of self, variable and index value, but got " << target_object.get_type() << ".\nMore details please refer to syntax support at https://www.mindspore.cn"; } if (target_node == nullptr) { MS_LOG(EXCEPTION) << "Can not get target node "; } CNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, value_node}); WriteAssignVars(block, target_object, augassign_app); return block; } // Process global declaration such as 'global x'; FunctionBlockPtr Parser::ParseGlobal(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Global"; MS_EXCEPTION_IF_NULL(block); py::list vars = python_adapter::GetPyObjAttr(node, "names"); for (auto &item : vars) { block->AddGlobalVar(py::cast(item)); } return block; } // Process a if statement FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast If"; py::object test_node = python_adapter::GetPyObjAttr(node, "test"); AnfNodePtr condition_node = ParseExprNode(block, test_node); MS_EXCEPTION_IF_NULL(block); CNodePtr bool_node = block->ForceToBoolNode(condition_node); FunctionBlockPtr true_block = nullptr; FunctionBlockPtr false_block = nullptr; auto block_fg = block->func_graph(); { TraceGuard guard(std::make_shared(block_fg->debug_info())); true_block = MakeFunctionBlock(*this); } { TraceGuard guard(std::make_shared(block_fg->debug_info())); false_block = MakeFunctionBlock(*this); } MakeConditionBlocks(block, true_block, false_block); FunctionBlockPtr after_block = nullptr; { TraceGuard guard(std::make_shared(block_fg->debug_info())); after_block = MakeFunctionBlock(*this); } if (MsContext::GetInstance()->backend_policy() != "ge") { // For backends excludes 'ge', it can handle multi graph call, use this flag to // generate call not inline `after_block` graph to reduce if by if switch expansion. after_block->func_graph()->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true); } // Process the if-true branch py::object bodyNode = python_adapter::GetPyObjAttr(node, "body"); FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode); // If the return_ is set, it has its own continuation block if (true_end->func_graph()->get_return() == nullptr) { MS_LOG(DEBUG) << "true end jump to after."; true_end->Jump(after_block, {}); } // Process the orelse branch py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse"); FunctionBlockPtr false_end = ParseStatements(false_block, orelseNode); // If the return_ is set, it has its own continuation block if (false_end->func_graph()->get_return() == nullptr) { MS_LOG(DEBUG) << "false_end jump to after."; false_end->Jump(after_block, {}); } block->ConditionalJump(bool_node, true_block, false_block); if (after_block->prev_blocks().empty()) { after_block->SetAsDeadBlock(); } after_block->Mature(); return after_block; } FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast While"; MS_EXCEPTION_IF_NULL(block); FunctionBlockPtr header_block = nullptr; FunctionBlockPtr body_block = nullptr; FunctionBlockPtr after_block = nullptr; { TraceGuard guard(std::make_shared(block->func_graph()->debug_info())); header_block = MakeFunctionBlock(*this); auto func_graph = header_block->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); func_graph->set_flag(GRAPH_FLAG_IS_WHILE_HEADER, true); } { TraceGuard guard(std::make_shared(block->func_graph()->debug_info())); body_block = MakeFunctionBlock(*this); } { TraceGuard guard(std::make_shared(block->func_graph()->debug_info())); after_block = MakeFunctionBlock(*this); } body_block->AddPrevBlock(header_block); after_block->AddPrevBlock(header_block); block->Jump(header_block, {}); py::object test_node = python_adapter::GetPyObjAttr(node, "test"); AnfNodePtr condition_node = ParseExprNode(header_block, test_node); condition_node = header_block->ForceToWhileCond(condition_node); body_block->Mature(); header_block->ConditionalJump(condition_node, body_block, after_block); // Parse loop body statements with loop context. LoopContext loop_context{&loops_, header_block, nullptr}; py::object body_node = python_adapter::GetPyObjAttr(node, "body"); FunctionBlockPtr after_body = ParseStatements(body_block, body_node); if (after_body->func_graph()->get_return() == nullptr) { after_body->Jump(header_block, {}); } header_block->Mature(); after_block->Mature(); auto &end_block = loop_context.EndBlock(); // end_block exists if we encounter 'break' in loop body. if (end_block) { after_block->Jump(end_block, {}); end_block->Mature(); return end_block; } // No 'break', no end_block. return after_block; } CNodePtr Parser::GenerateIteratorInFor(const FunctionBlockPtr &block, const py::object &node, const AnfNodePtr &op_iter) { py::object iter_node = python_adapter::GetPyObjAttr(node, "iter"); AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node); return block->func_graph()->NewCNodeInOrder({op_iter, iter_anf_node}); } CNodePtr Parser::GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block, const AnfNodePtr &op_hasnext) { MS_EXCEPTION_IF_NULL(header_block); return header_block->func_graph()->NewCNodeInOrder({op_hasnext, iter_param}); } FunctionBlockPtr Parser::GenerateBlock(const TraceInfoPtr &trace_info) { TraceGuard trace_guard(trace_info); FunctionBlockPtr block = MakeFunctionBlock(*this); MS_EXCEPTION_IF_NULL(block); return block; } int64_t Parser::GetForTransToWhileLoop() { // int64 support 63bits positive num mostly. constexpr auto max_num_length = 10; if (max_for_loop_count_str_.size() > max_num_length || max_for_loop_count_str_.empty()) { return MAX_FOR_LOOP_COUNT; } if (std::any_of(max_for_loop_count_str_.begin(), max_for_loop_count_str_.end(), [](char c) { return c < '0' || c > '9'; })) { return MAX_FOR_LOOP_COUNT; } int64_t loop_count; std::stringstream ss; ss << max_for_loop_count_str_; ss >> loop_count; return loop_count; } // A for loop will generate 3 functions :the test, the body, and the continuation // for x in xs: // body // It is compiled to be following statement // if len(xs) < max_loop_cnt, ParseForIter. Use iter to implement for loop, which always unroll loop // else, ParseForLoop. Use loop var to implement for loop, which always sink loop FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast For, create an if else statement"; MS_EXCEPTION_IF_NULL(block); // Create statement 'len(xs) < MAX_FOR_LOOP_COUNT' AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN); py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER); AnfNodePtr iter_node = ParseExprNode(block, iter_obj); CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({op_len, iter_node}); CNodePtr bool_node = block->func_graph()->NewCNodeInOrder( {NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(GetForTransToWhileLoop())}); // Create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop' FunctionBlockPtr true_block = nullptr; FunctionBlockPtr false_block = nullptr; { TraceGuard guard(std::make_shared(block->func_graph()->debug_info())); true_block = MakeFunctionBlock(*this); } { TraceGuard guard(std::make_shared(block->func_graph()->debug_info())); false_block = MakeFunctionBlock(*this); } MakeConditionBlocks(block, true_block, false_block); FunctionBlockPtr after_block = nullptr; { TraceGuard guard(std::make_shared(block->func_graph()->debug_info())); after_block = MakeFunctionBlock(*this); } FunctionBlockPtr true_end = ParseForIter(true_block, node); true_end->Jump(after_block, {}); FunctionBlockPtr false_end = ParseForLoop(false_block, node); false_end->Jump(after_block, {}); block->ConditionalJump(bool_node, true_block, false_block); after_block->Mature(); return after_block; } // A for loop will generate 3 functions :the test, the body, and the continuation // for x in xs: // body // It is compiled to be following statement // it = iter(xs) // while hastnext(it) // x, it = next(it) // body FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast For"; MS_EXCEPTION_IF_NULL(block); AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER); AnfNodePtr op_next = block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT); AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); AnfNodePtr op_hasnext = block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT); // Generate the iterator apply CNodePtr iter_apply = GenerateIteratorInFor(block, node, op_iter); MS_EXCEPTION_IF_NULL(iter_apply); FunctionBlockPtr header_block = GenerateBlock(std::make_shared(block->func_graph()->debug_info())); MS_EXCEPTION_IF_NULL(header_block); // Generate the hasnext apply which is a condition ParameterPtr iter_param = header_block->func_graph()->add_parameter(); CNodePtr cond_apply = GenerateCondInFor(iter_param, header_block, op_hasnext); // Generate the body of the for statement FunctionBlockPtr body_block = GenerateBlock(std::make_shared(block->func_graph()->debug_info())); MS_EXCEPTION_IF_NULL(body_block); body_block->AddPrevBlock(header_block); // Generate the iterator next apply // Process as following: `app = next(it); target = app[0]; it = app[1];` CNodePtr app = body_block->func_graph()->NewCNodeInOrder({op_next, iter_param}); CNodePtr target_app = body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast(0))}); py::object target_node = python_adapter::GetPyObjAttr(node, "target"); CNodePtr iter2_app = body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast(1))}); WriteAssignVars(body_block, target_node, target_app); // Link the variable name with the target auto it_info = std::make_shared(target_app->debug_info()); iter_param->debug_info()->set_trace_info(it_info); iter2_app->debug_info()->set_trace_info(it_info); iter_apply->debug_info()->set_trace_info(it_info); FunctionBlockPtr after_block = nullptr; { TraceGuard guard(std::make_shared(block->func_graph()->debug_info())); after_block = MakeFunctionBlock(*this); } MS_EXCEPTION_IF_NULL(after_block); after_block->AddPrevBlock(header_block); block->Jump(header_block, {iter_apply}); body_block->Mature(); header_block->ConditionalJump(cond_apply, body_block, after_block); // Parse loop body statements with loop context. LoopContext loop_context{&loops_, header_block, iter2_app}; py::object body_node = python_adapter::GetPyObjAttr(node, "body"); FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node); if (after_body_block->func_graph()->get_return() == nullptr) { after_body_block->Jump(header_block, {iter2_app}); } header_block->Mature(); after_block->Mature(); auto &end_block = loop_context.EndBlock(); if (end_block) { // end_block exists if we encounter 'break' in loop body. after_block->Jump(end_block, {}); end_block->Mature(); return end_block; } // No 'break', no end_block. return after_block; } // A for loop will generate 3 functions :the test, the body, and the continuation // for x in xs: // body // It is compiled to be following statement // i = 0 // while i < len(xs) // x = xs[i] // i = i + 1 // body FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast For by loop variable"; MS_EXCEPTION_IF_NULL(block); AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN); AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); // Get variable name of 'x' in statement 'for x in xs' py::object target_node = python_adapter::GetPyObjAttr(node, "target"); // Create statement 'len(xs)' py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter"); AnfNodePtr iter_node = ParseExprNode(block, iter_obj); MS_EXCEPTION_IF_NULL(iter_node); // Generate node for loop count and convert it to tensor, to make the loop not unroll CNodePtr scalar_len = block->func_graph()->NewCNodeInOrder({op_len, iter_node}); auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations"); auto scalar_to_tensor_node = block->func_graph()->NewCNodeInOrder({NewValueNode(scalar_to_tensor)}); CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, scalar_len}); FunctionBlockPtr header_block = GenerateBlock(std::make_shared(block->func_graph()->debug_info())); MS_EXCEPTION_IF_NULL(header_block); // Create loop variable 'i' ParameterPtr loop_var = header_block->func_graph()->add_parameter(); // Create loop condition 'i < len(xs)' auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations"); auto less_node = header_block->func_graph()->NewCNodeInOrder({NewValueNode(prim_less)}); CNodePtr cond_node = header_block->func_graph()->NewCNodeInOrder({less_node, loop_var, len_iter}); // Generate the body of the for statement FunctionBlockPtr body_block = GenerateBlock(std::make_shared(block->func_graph()->debug_info())); MS_EXCEPTION_IF_NULL(body_block); body_block->AddPrevBlock(header_block); // Create 'x = xs[i]' auto body_func_graph = body_block->func_graph(); CNodePtr target_var = body_func_graph->NewCNodeInOrder({op_getitem, iter_node, loop_var}); WriteAssignVars(body_block, target_node, target_var); // Create 'i = i + 1' auto prim_add = prim::GetPythonOps("Add", "mindspore.ops.operations"); auto add_node = body_func_graph->NewCNodeInOrder({NewValueNode(prim_add)}); auto body_scalar_to_tensor_node = body_func_graph->NewCNodeInOrder({NewValueNode(scalar_to_tensor)}); auto add_tensor_node = body_func_graph->NewCNodeInOrder({body_scalar_to_tensor_node, NewValueNode(static_cast(1))}); CNodePtr loop_var_inc = body_func_graph->NewCNodeInOrder({add_node, loop_var, add_tensor_node}); body_block->WriteVariable(loop_var->name(), loop_var_inc); // Link the variable name with the target auto it_info = std::make_shared(loop_var_inc->debug_info()); loop_var->debug_info()->set_trace_info(it_info); len_iter->debug_info()->set_trace_info(it_info); FunctionBlockPtr after_block = nullptr; { TraceGuard guard(std::make_shared(block->func_graph()->debug_info())); after_block = MakeFunctionBlock(*this); } MS_EXCEPTION_IF_NULL(after_block); after_block->AddPrevBlock(header_block); CNodePtr zero_tensor = block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, NewValueNode(static_cast(0))}); block->Jump(header_block, {zero_tensor}); body_block->Mature(); header_block->ConditionalJump(cond_node, body_block, after_block, false); // Parse loop body statements with loop context. LoopContext loop_context{&loops_, header_block, loop_var_inc}; py::object body_node = python_adapter::GetPyObjAttr(node, "body"); FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node); if (after_body_block->func_graph()->get_return() == nullptr) { after_body_block->Jump(header_block, {loop_var_inc}); } header_block->Mature(); after_block->Mature(); auto &end_block = loop_context.EndBlock(); if (end_block) { // end_block exists if we encounter 'break' in loop body. after_block->Jump(end_block, {}); end_block->Mature(); return end_block; } // No 'break', no end_block. return after_block; } AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast IfExp"; MS_EXCEPTION_IF_NULL(block); py::object test_node = python_adapter::GetPyObjAttr(node, "test"); AnfNodePtr condition_node = ParseExprNode(block, test_node); CNodePtr bool_node = block->ForceToBoolNode(condition_node); FunctionBlockPtr true_block = nullptr; FunctionBlockPtr false_block = nullptr; { TraceGuard guard(std::make_shared(block->func_graph()->debug_info())); true_block = MakeFunctionBlock(*this); } { TraceGuard guard(std::make_shared(block->func_graph()->debug_info())); false_block = MakeFunctionBlock(*this); } MakeConditionBlocks(block, true_block, false_block); // Process the if-true branch py::object bodyNode = python_adapter::GetPyObjAttr(node, "body"); true_block->func_graph()->debug_info()->set_location(GetLocation(bodyNode)); AnfNodePtr true_node = ParseExprNode(true_block, bodyNode); // Process the orelse branch py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse"); false_block->func_graph()->debug_info()->set_location(GetLocation(orelseNode)); AnfNodePtr false_node = ParseExprNode(false_block, orelseNode); true_block->func_graph()->set_output(true_node); false_block->func_graph()->set_output(false_node); // Use the Primitive replace the operation resolve node (switch), // because the switch will eventually be converted to Primitive node CNodePtr switch_app = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), bool_node, NewValueNode(true_block->func_graph()), NewValueNode(false_block->func_graph())}); std::vector call_graph_nodes{switch_app}; CNodePtr switch_app_call = block->func_graph()->NewCNodeInOrder(std::move(call_graph_nodes)); return switch_app_call; } FunctionBlockPtr Parser::ParseListCompIter(const FunctionBlockPtr &block, const py::object &node, const py::object &generator_node) { // Create a header block. FunctionBlockPtr top_block = GenerateBlock(std::make_shared(block->func_graph()->debug_info())); // Handle iter attribute. py::object iter_node = python_adapter::GetPyObjAttr(generator_node, "iter"); AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node); AnfNodePtr op_iter = top_block->MakeResolveOperation(NAMED_PRIMITIVE_ITER); CNodePtr iter_apply = top_block->func_graph()->NewCNodeInOrder({op_iter, iter_anf_node}); // Create header graph. FunctionBlockPtr list_header_block = GenerateBlock(std::make_shared(block->func_graph()->debug_info())); list_header_block->AddPrevBlock(top_block); // Create hasNext apply. AnfNodePtr op_hasnext = top_block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT); ParameterPtr iter_param = list_header_block->func_graph()->add_parameter(); constexpr auto iter_param_name = "iter"; iter_param->set_name(iter_param_name); iter_param->debug_info()->set_name(iter_param_name); CNodePtr cond_apply = list_header_block->func_graph()->NewCNodeInOrder({op_hasnext, iter_param}); // Call the header graph with iter. ParameterPtr list_param = list_header_block->func_graph()->add_parameter(); constexpr auto list_param_name = "list"; list_param->set_name(list_param_name); list_param->debug_info()->set_name(list_param_name); auto empty_list = std::vector(); AnfNodePtr empty_list_node = NewValueNode(std::make_shared(empty_list)); top_block->Jump(list_header_block, {iter_apply, empty_list_node}); // Create body graph. FunctionBlockPtr list_body_block = GenerateBlock(std::make_shared(block->func_graph()->debug_info())); list_body_block->AddPrevBlock(list_header_block); AnfNodePtr op_next = top_block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT); CNodePtr next_apply = list_body_block->func_graph()->NewCNodeInOrder({op_next, iter_param}); AnfNodePtr op_getitem = top_block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); CNodePtr item_apply = list_body_block->func_graph()->NewCNodeInOrder({op_getitem, next_apply, NewValueNode(static_cast(0))}); CNodePtr new_iter = list_body_block->func_graph()->NewCNodeInOrder({op_getitem, next_apply, NewValueNode(static_cast(1))}); // Save the `target` in a variable. py::object gen_target_node = python_adapter::GetPyObjAttr(generator_node, "target"); WriteAssignVars(list_body_block, gen_target_node, item_apply); auto ifs_new_list = ParseListCompIfs(list_body_block, list_param, node, generator_node); list_body_block->Jump(list_header_block, {new_iter, ifs_new_list}); // Create after graph. FunctionBlockPtr list_after_block = GenerateBlock(std::make_shared(block->func_graph()->debug_info())); list_after_block->AddPrevBlock(list_header_block); // Return the list in after graph. list_after_block->func_graph()->set_output(list_param); // Run the branches. list_header_block->ConditionalJump(cond_apply, list_body_block, list_after_block); top_block->Mature(); list_header_block->Mature(); list_body_block->Mature(); list_after_block->Mature(); return top_block; } AnfNodePtr Parser::ParseListCompIfs(const FunctionBlockPtr &list_body_block, const ParameterPtr &list_param, const py::object &node, const py::object &generator_node) { // Handle ifs attribute. py::list ifs_node = python_adapter::GetPyObjAttr(generator_node, "ifs"); AnfNodePtr ifs_bool_node; if (ifs_node.empty()) { ifs_bool_node = NewValueNode(true); } else { ifs_bool_node = ProcessBoolOpValueList(list_body_block, ifs_node, AST_SUB_TYPE_AND); } // Create if-true graph. FunctionBlockPtr if_true_block = GenerateBlock(std::make_shared(list_body_block->func_graph()->debug_info())); if_true_block->AddPrevBlock(list_body_block); // Handle elt attribute in body block. py::object elt_obj = python_adapter::GetPyObjAttr(node, "elt"); AnfNodePtr elt_node = ParseExprNode(list_body_block, elt_obj); // Append the element. auto list_append_op = prim::kPrimListAppend; auto new_list = list_body_block->func_graph()->NewCNodeInOrder({NewValueNode(list_append_op), list_param, elt_node}); // Return new list in true branch graph. if_true_block->func_graph()->set_output(new_list); // Create if-false graph. FunctionBlockPtr if_false_block = GenerateBlock(std::make_shared(list_body_block->func_graph()->debug_info())); if_false_block->AddPrevBlock(list_body_block); // Return original list in false branch graph. if_false_block->func_graph()->set_output(list_param); // We don't want to create a header graph, where to get and wrap the result of Switch(). // So just call ConditionalJump() to set Switch() as output, and reset it later, as tricky. list_body_block->ConditionalJump(ifs_bool_node, if_true_block, if_false_block); // Output is Switch() result, i.e. updated list. auto switch_apply_node = list_body_block->func_graph()->output(); auto ifs_new_list = switch_apply_node; // Since we call ConditionalJump() above, to reset the Return as null before call Jump(). list_body_block->func_graph()->set_return(nullptr); if_true_block->Mature(); if_false_block->Mature(); return ifs_new_list; } // A ListComp contains: `elt` and `generators`. // `generators` contains: `target`, `iter` and `ifs`. // For example: // [x * x for x in range(0, 10) if x % 2 == 0] // It is compiled to be following statement: // list = [] // for x in range(0, 10): // if x % 2 == 0: // list.append(x * x) // return list AnfNodePtr Parser::ParseListComp(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast ListComp"; MS_EXCEPTION_IF_NULL(block); // Handle generators attribute. py::list generators_node = python_adapter::GetPyObjAttr(node, "generators"); if (generators_node.size() != 1) { MS_EXCEPTION(TypeError) << "The 'generators' supports 1 'comprehension' in ListComp/GeneratorExp, but got " << generators_node.size() << " comprehensions."; } py::object generator_node = generators_node[0]; auto generator_node_type = ast_->GetNodeType(generator_node); auto generator_node_name = generator_node_type->node_name(); constexpr auto comprehension_name = "comprehension"; if (generator_node_name != comprehension_name) { MS_LOG(EXCEPTION) << "Generator node name should be " << comprehension_name << ", but got " << generator_node_name; } // Parse ListComp's `iter` and add `elt` in it. auto top_block = ParseListCompIter(block, node, generator_node); // Call the top graph and return the list. auto call_function_node = NewValueNode(top_block->func_graph()); std::vector func_call_nodes; func_call_nodes.push_back(call_function_node); AnfNodePtr output = block->func_graph()->NewCNodeInOrder(std::move(func_call_nodes)); return output; } void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &target_object, const AnfNodePtr &assigned_node) { MS_EXCEPTION_IF_NULL(block); MS_EXCEPTION_IF_NULL(assigned_node); py::str name = python_adapter::GetPyObjAttr(target_object, "id"); std::string name_id = name; assigned_node->debug_info()->set_name(name_id); // Set the debug name of the constant graph if (IsValueNode(assigned_node)) { // The value should be graph auto fg = GetValueNode(assigned_node); if (fg->debug_info()->name().empty()) { fg->debug_info()->set_name(name_id); } } MS_LOG(DEBUG) << "Assign name: `" << name_id << "` to node: " << assigned_node->DebugString(); block->WriteVariable(name_id, assigned_node); } void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &target_object, const AnfNodePtr &assigned_node) { MS_EXCEPTION_IF_NULL(block); AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); py::list items = python_adapter::GetPyObjAttr(target_object, "elts"); for (size_t i = 0; i < items.size(); i++) { // Use the Primitive replace the operation resolve node (getitem), // because the getitem will eventually be converted to Primitive node CNodePtr item_apply = block->func_graph()->NewCNodeInOrder({op_getitem, assigned_node, NewValueNode(static_cast(i))}); py::object elt = items[i]; WriteAssignVars(block, elt, item_apply); } } void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &target_object, const AnfNodePtr &assigned_node) { // Now only support the self.xx = xxxxx, can't support x.y = xxxx AnfNodePtr target_node = ParseExprNode(block, target_object); MS_EXCEPTION_IF_NULL(target_node); auto attr_name = target_object.attr("attr").cast(); std::string var_name = "self." + attr_name; // Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) { MS_EXCEPTION(TypeError) << "'" << var_name << "' should be initialized as a 'Parameter' in the '__init__' function before assigning.\n\n" << trace::GetDebugInfo(target_node->debug_info()); } auto obj = ast()->obj().attr(common::SafeCStr(attr_name)); auto obj_type = obj.attr("__class__").attr("__name__"); if (!py::hasattr(obj, "__parameter__")) { MS_EXCEPTION(TypeError) << "'" << var_name << "' should be initialized as a 'Parameter' type in the '__init__' function, but got '" << py::str(obj).cast() << "' with type '" << py::str(obj_type).cast() << ".\n\n" << trace::GetDebugInfo(target_node->debug_info()); } MS_EXCEPTION_IF_NULL(block); MS_LOG(DEBUG) << "SetState write " << var_name << " : " << target_node->ToString(); block->SetStateAssign(target_node, assigned_node); } void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &target_object, const AnfNodePtr &assigned_node) { MS_EXCEPTION_IF_NULL(block); AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM); py::object value_obj = python_adapter::GetPyObjAttr(target_object, "value"); py::object slice_obj = python_adapter::GetPyObjAttr(target_object, "slice"); AnfNodePtr value_node = ParseExprNode(block, value_obj); AnfNodePtr slice_node = ParseExprNode(block, slice_obj); CNodePtr setitem_app = block->func_graph()->NewCNodeInOrder({op_setitem, value_node, slice_node, assigned_node}); // Getitem apply should return the sequence data structure itself std::string var_name; if (ast_->IsClassMember(value_obj)) { auto attr_name = value_obj.attr("attr").cast(); var_name = "self." + attr_name; if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) { MS_EXCEPTION(TypeError) << "'" << var_name << "' should be initialized as a 'Parameter' in the '__init__' function before assigning.\n\n" << trace::GetDebugInfo(value_node->debug_info()); } auto obj = ast()->obj().attr(common::SafeCStr(attr_name)); auto obj_type = obj.attr("__class__").attr("__name__"); if (!py::hasattr(obj, "__parameter__")) { MS_EXCEPTION(TypeError) << "'" << var_name << "' should be initialized as a 'Parameter' in the '__init__' function, but got '" << py::str(obj).cast() << "' with type '" << py::str(obj_type).cast() << "'.\n\n" << trace::GetDebugInfo(value_node->debug_info()); } block->WriteVariable(var_name, setitem_app); return; } if (AstSubType(py::cast(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, value_obj))) == AST_SUB_TYPE_SUBSCRIPT) { HandleAssignSubscript(block, value_obj, setitem_app); return; } if (!py::hasattr(value_obj, "id")) { MS_EXCEPTION(TypeError) << "Attribute id not found in " << py::str(value_obj).cast() << "\n\n" << trace::GetDebugInfo(value_node->debug_info()); } var_name = value_obj.attr("id").cast(); block->WriteVariable(var_name, setitem_app); } void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &target_object, const AnfNodePtr &value_node) { MS_EXCEPTION_IF_NULL(value_node); MS_LOG(DEBUG) << "Process WriteAssignVars"; auto ast_type = AstSubType(py::cast(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_object))); if (ast_type == AST_SUB_TYPE_NAME) { HandleAssignName(block, target_object, value_node); } else if (ast_type == AST_SUB_TYPE_TUPLE) { HandleAssignTuple(block, target_object, value_node); } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) { HandleAssignSubscript(block, target_object, value_node); } else if (ast_->IsClassMember(target_object)) { HandleAssignClassMember(block, target_object, value_node); } else if (ast_type == AST_SUB_TYPE_ATTRIBUTE) { TraceGuard(GetLocation(target_object)); MS_EXCEPTION(TypeError) << "Only support assign to attribute of self, but got attribute of " << py::str(target_object.attr("value").attr("id")) << ".\n" << "More details please refer to syntax support at https://www.mindspore.cn"; } else { TraceGuard(GetLocation(target_object)); MS_EXCEPTION(TypeError) << "Only supported augassign to attribute of self, variable and index value, but got " << target_object.get_type() << ".\nMore details please refer to syntax support at https://www.mindspore.cn"; } } AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node, const py::object &value_object) { // The fallback feature is enabled in default. // Not support change the flag during the process is alive. static const auto use_fallback = (support_fallback() != "0"); if (!use_fallback || !value_node->interpret()) { return value_node; } const auto script_text = py::cast(ast()->GetAstNodeText(value_object)); // Prepare global parameters. py::dict global_dict = block->global_py_params(); ValuePtr globals_converted_value = nullptr; if (!ConvertData(global_dict, &globals_converted_value)) { MS_LOG(EXCEPTION) << "Convert data failed"; } auto global_dict_node = NewValueNode(globals_converted_value); // Prepare local parameters. auto [keys, values] = block->local_py_params(); auto local_dict_node = ParseDictByKeysAndValues(block, keys, values); // Update the valued node if it need interpreting. constexpr int recursive_level = 2; MS_LOG(INFO) << "[" << block->func_graph()->ToString() << "] script_text: `" << script_text << "`,\nvalue_node: " << value_node->DebugString(recursive_level) << ",\nglobal_dict_node: " << global_dict_node->ToString() << ",\nlocal_dict_node: " << local_dict_node->DebugString(recursive_level); AnfNodePtr interpreted_node = block->MakeInterpret(script_text, global_dict_node, local_dict_node, value_node); // Print a hint for user. auto line_info = trace::GetDebugInfo(value_node->debug_info()); MS_LOG(INFO) << "Found unsupported syntax in Graph mode, those codes would be fallen back to Python interpreter:" << "\n\n" << line_info; InterpretNodeRecorder::GetInstance().PushLineInfo(line_info); return interpreted_node; } // Process a assign statement, such as a = b, a, b = tuple(xx, xx) FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast assign"; py::object value_object = python_adapter::GetPyObjAttr(node, "value"); AnfNodePtr value_node = ParseExprNode(block, value_object); value_node = HandleInterpret(block, value_node, value_object); py::object targets_object = python_adapter::GetPyObjAttr(node, "targets"); py::int_ pcount = python_adapter::CallPyObjMethod(targets_object, "__len__"); size_t count = LongToSize(pcount); MS_LOG(DEBUG) << "The nodes count is " << count; for (size_t i = 0; i < count; i++) { auto target_node = py::cast(targets_object)[i]; WriteAssignVars(block, target_node, value_node); } return block; } FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::object &node) { if (loops_.empty()) { // Report error if loop context not set for the 'break' statement. MS_LOG(EXCEPTION) << "Unexpected 'break'."; } // Get current loop. Loop &loop = loops_.top(); if (loop.end == nullptr) { // Create end_block if it is not existed. TraceGuard trace_guard(std::make_shared(block->func_graph()->debug_info())); loop.end = MakeFunctionBlock(*this); } // Jump to the end_block. block->Jump(loop.end, {}); return block; } FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::object &node) { if (loops_.empty()) { // Report error if loop context not set for the 'continue' statement. MS_LOG(EXCEPTION) << "Unexpected 'continue."; } // Jump to the header of the loop with iterator called. Loop &loop = loops_.top(); std::vector args; if (loop.iterator != nullptr) { args.emplace_back(loop.iterator); } block->Jump(loop.header, args); return block; } FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::object &node) { // We just bypass 'pass' statement. return block; } AnfNodePtr FindPhis(const mindspore::HashMap &removable_phis, const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); const auto &inp = node->cast(); const auto &iter = removable_phis.find(inp); if (iter == removable_phis.end()) { return node; } return FindPhis(removable_phis, iter->second); } void Parser::RemoveUnnecessaryPhis() { // Merge all removable phis to one map; mindspore::HashMap removable_phis; std::vector phis; for (FunctionBlockPtr &block : func_block_list_) { MS_EXCEPTION_IF_NULL(block); removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end()); std::transform(block->removable_phis().begin(), block->removable_phis().end(), std::back_inserter(phis), [](const auto &pair) { return pair.first; }); } if (removable_phis.empty()) { return; } auto mng = Manage(func_graph_, false); // Replace the nodes // Remove from inside to outside for (int64_t idx = SizeToLong(phis.size() - 1); idx >= 0; idx--) { auto phi = phis[LongToSize(idx)]; auto new_node = FindPhis(removable_phis, phi); mng->Replace(phi, new_node); } // Remove the parameter for (FunctionBlockPtr &block : func_block_list_) { MS_EXCEPTION_IF_NULL(block); auto &local_removable_phis = block->removable_phis(); if (local_removable_phis.empty()) { continue; } auto func_graph = block->func_graph(); auto ¶meters = func_graph->parameters(); std::vector new_parameters(parameters.size()); auto it = std::copy_if( parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](const AnfNodePtr ¶m) { MS_EXCEPTION_IF_NULL(param); return local_removable_phis.find(param->cast()) == local_removable_phis.end(); }); // Shrink container to new size new_parameters.resize(static_cast(std::distance(new_parameters.begin(), it))); func_graph->set_parameters(new_parameters); } } // ParseFunctionAst class code bool ParseFunctionAst::InitParseAstInfo(const std::string &python_mod_get_parse_method) { // Init the type target_type_ = PARSE_TARGET_UNKNOW; // Call python parse, get the parser fn module_ = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::object parse_method = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_PARSE_METHOD); // Get the obj type auto type = data_converter::GetObjType(obj_); if (type == RESOLVE_TYPE_FUNCTION) { target_type_ = PARSE_TARGET_FUNCTION; function_ = obj_; } else if (type == RESOLVE_TYPE_METHOD) { // Process the method ,need get the method's self obj target_type_ = PARSE_TARGET_METHOD; py::object method_object = python_adapter::GetPyObjAttr(obj_, PYTHON_GET_METHOD_SELF_CLASS); if (py::isinstance(method_object)) { MS_LOG(ERROR) << "Get method's self object instance failed."; return false; } target_type_ = PARSE_TARGET_OBJECT_INSTANCE; function_ = obj_; obj_ = method_object; } else if (type == RESOLVE_TYPE_CLASS_INSTANCE) { // obj is class instance, get the method to parse. function_ = python_adapter::CallPyModFn(module_, python_mod_get_parse_method, obj_, parse_method); if (py::isinstance(function_)) { MS_LOG(ERROR) << "Get obj method function failed."; return false; } target_type_ = PARSE_TARGET_OBJECT_INSTANCE; // Check the fn is method auto obj_type = data_converter::GetObjType(function_); if (obj_type != RESOLVE_TYPE_METHOD) { MS_LOG(WARNING) << "Parse method function is invalid."; return false; } } else { MS_LOG(WARNING) << "Parse obj is invalid, only can parse function and obj, type = " << type; return false; } // Call python parse get ast tree parser_ = python_adapter::CallPyModFn(module_, PYTHON_MOD_PARSE_OBJECT_FUNCTION, function_, parse_method); py::tuple ast_info = python_adapter::CallPyObjMethod(parser_, "parse"); const size_t ast_info_size = 2; if (ast_info.size() != ast_info_size) { MS_EXCEPTION(NameError) << "ast info size is not equal to 2."; } ast_tokens_ = ast_info[0]; ast_tree_ = ast_info[1]; // Get fn name and module function_module_ = py::cast(python_adapter::GetPyObjAttr(parser_, "function_module")); function_name_ = py::cast(python_adapter::GetPyObjAttr(parser_, "function_name")); function_filename_ = py::cast(python_adapter::GetPyObjAttr(parser_, "filename")); function_line_offset_ = py::cast(python_adapter::GetPyObjAttr(parser_, "line_offset")); return true; } // Get ast tree node : is the tree bode list[0] py::object ParseFunctionAst::GetAstNode() { py::list tree_body = python_adapter::GetPyObjAttr(ast_tree_, "body"); py::object ast_node = tree_body[0]; return ast_node; } // Get ast tokens node text. py::str ParseFunctionAst::GetAstNodeText(const py::object &node_obj) { return python_adapter::CallPyObjMethod(ast_tokens_, "get_text", node_obj); } py::list ParseFunctionAst::GetArgs(const py::object &func_node) { py::list ret = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_ARGS, func_node); return ret; } py::list ParseFunctionAst::GetArgsDefaultValues(const py::object &func_node) { py::list ret = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES, func_node); return ret; } AstNodeTypePtr ParseFunctionAst::GetNodeType(const py::object &node) { py::list list_value = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_NODE_TYPE, node); const size_t list_value_size = 2; if (list_value.size() < list_value_size) { MS_LOG(EXCEPTION) << "The node of python method must has 2 values."; } auto node_name = py::cast(list_value[0]); auto type = AstMainType(py::cast(list_value[1])); return std::make_shared(node, node_name, type); } AstSubType ParseFunctionAst::GetOpType(const py::object &node) { auto op_type = AstSubType(python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_AST_TYPE, node).cast()); return op_type; } bool ParseFunctionAst::IsClassMember(const py::object &node) { py::object ret = CallParseModFunction(PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER, node); if (!py::isinstance(ret)) { MS_LOG(ERROR) << "The result of mod function parse, should be bool type."; return false; } return ret.cast(); } void SetMixedPrecisionFlag(const py::object &obj, const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); if (!py::isinstance(obj)) { return; } auto cell = py::cast(obj); MS_EXCEPTION_IF_NULL(cell); auto mixed_type = cell->GetMixedPrecisionType(); if (mixed_type != MixedPrecisionType::kNotSet) { func_graph->set_flag(GRAPH_FLAG_MIX_PRECISION_FP16, mixed_type == MixedPrecisionType::kFP16); func_graph->set_flag(GRAPH_FLAG_MIX_PRECISION_FP32, mixed_type == MixedPrecisionType::kFP32); } } bool UpdateFuncGraphFlags(const py::object &obj, const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { MS_LOG(ERROR) << "FuncGraph is null"; return false; } SetMixedPrecisionFlag(obj, func_graph); if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) { MS_LOG(DEBUG) << "No flags"; return true; } py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG); for (auto &item : flags) { if (!py::isinstance(item.first)) { MS_LOG(ERROR) << "Type error in flags dict convert"; return false; } auto name = py::cast(item.first); if (py::isinstance(item.second)) { auto value = py::cast(item.second); MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value; func_graph->set_flag(name, value); } else if (py::isinstance(item.second)) { auto value = py::cast(item.second); MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value; func_graph->set_attr(name, MakeValue(value)); } else { MS_LOG(ERROR) << "Type error in flags/attrs dict convert"; return false; } } return true; } // Generate and copy a ValueNode, or a CNode with its child nodes static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m_node) { MS_EXCEPTION_IF_NULL(param_node); if (param_node->isa()) { return std::make_shared(param_node->cast()->value()); } // Parameter default value is CNode. std::size_t index = 0; std::vector old_cnodes; old_cnodes.emplace_back(param_node); MS_EXCEPTION_IF_NULL(func_graph); auto res = func_graph->NewCNodeInOrder({}); std::vector new_cnodes; new_cnodes.emplace_back(res); while (index < old_cnodes.size()) { auto current = old_cnodes[index]; auto current_new_cnode = new_cnodes[index]; index++; if (current->isa()) { auto &inputs = current->cast()->inputs(); for (auto it = inputs.begin(); it != inputs.end(); it++) { AnfNodePtr input = *it; if (input != nullptr && input->isa()) { old_cnodes.emplace_back(input); auto new_cnode = func_graph->NewCNodeInOrder({}); new_cnodes.emplace_back(new_cnode); current_new_cnode->add_input(new_cnode); } else if (input->isa()) { current_new_cnode->add_input(std::make_shared(input->cast()->value())); } else { MS_LOG(EXCEPTION) << "Wrong type item in default parameters: " << input->ToString(); } } } } return res; } FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) { auto current_graph = dyn_cast(cell_ptr); if (current_graph == nullptr) { MS_LOG(EXCEPTION) << "Current graph cast failed from " << cell_ptr->ToString(); } auto func_graph = std::make_shared(); func_graph->debug_info()->set_name(current_graph->debug_info()->name() + "_wrapper"); func_graph->debug_info()->set_location(current_graph->debug_info()->location()); // Copy all parameters information for (auto ¶ : current_graph->parameters()) { auto param = func_graph->add_parameter(); auto orig_param = para->cast(); auto name = orig_param->name(); param->set_name(name); param->debug_info()->set_name(name); param->debug_info()->set_location(param->debug_info()->location()); param->set_is_top_graph_param(true); } func_graph->set_has_vararg(current_graph->has_vararg()); func_graph->set_has_kwarg(current_graph->has_kwarg()); func_graph->set_kwonlyargs_count(current_graph->kwonlyargs_count()); // Copy all default values for (auto &d : current_graph->parameter_default_value()) { func_graph->set_param_default_value(d.first, CopyNodesFromParamDefaultValue(func_graph, d.second)); } // cell_obj MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell)); parse::UpdateFuncGraphFlags(cell, func_graph); // Top graph's construct flag if (py::hasattr(cell, "construct")) { parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph); } auto unpacking = func_graph->has_vararg() || func_graph->has_kwarg(); if (!unpacking) { std::vector inputs; inputs.emplace_back(NewValueNode(cell_ptr)); auto ¶ms = func_graph->parameters(); (void)std::transform(params.begin(), params.end(), std::back_inserter(inputs), [](AnfNodePtr node) -> AnfNodePtr { return node; }); auto call_node = func_graph->NewCNodeInOrder(std::move(inputs)); TraceGuard guard(current_graph->get_return()->debug_info()->location()); func_graph->set_output(call_node); } else { // ret = cell_obj(*arg, *kwargs) auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), func_graph->parameters()); TraceGuard guard(current_graph->get_return()->debug_info()->location()); // Set output as ret func_graph->set_output(call_fn); } return func_graph; } } // namespace parse } // namespace mindspore