You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parse.cc 70 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/jit/parse/parse.h"
  19. #include <utility>
  20. #include <string>
  21. #include <memory>
  22. #include <sstream>
  23. #include <unordered_map>
  24. #include <algorithm>
  25. #include "pipeline/jit/parse/resolve.h"
  26. #include "frontend/operator/ops.h"
  27. #include "pipeline/jit/parse/data_converter.h"
  28. #include "frontend/operator/composite/composite.h"
  29. #include "utils/context/ms_context.h"
  30. #include "debug/trace.h"
  31. namespace mindspore {
  32. namespace parse {
  33. FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mod_get_parse_method) {
  34. (void)python_adapter::set_python_scoped();
  35. if (obj == nullptr || py::isinstance<py::none>(obj)) {
  36. MS_LOG(ERROR) << "Parse the python code failed, obj is nullptr or none";
  37. return nullptr;
  38. }
  39. auto ast = std::make_shared<ParseAst>(obj);
  40. bool success = ast->InitParseAstInfo(python_mod_get_parse_method);
  41. if (!success) {
  42. MS_LOG(ERROR) << "Parse code to ast tree failed.";
  43. return nullptr;
  44. }
  45. auto parser = std::make_shared<Parser>(ast);
  46. FuncGraphPtr func_graph = parser->ParseFuncGraph();
  47. if (func_graph == nullptr) {
  48. MS_LOG(ERROR) << "Parse python code failed, errcode = " << parser->errcode();
  49. return nullptr;
  50. }
  51. return func_graph;
  52. }
  53. // if any mixed precision flag add a cast node after the parameter node.
  54. AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
  55. TypePtr dst_type;
  56. if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
  57. dst_type = kFloat32;
  58. } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
  59. dst_type = kFloat16;
  60. } else {
  61. return param;
  62. }
  63. auto cast_helper = prim::kPrimMixedPrecisionCast;
  64. auto cast = func_graph->NewCNode({NewValueNode(cast_helper), NewValueNode(dst_type), param});
  65. return cast;
  66. }
  67. FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();
  68. Parser::Parser(const std::shared_ptr<ParseAst> &ast) : ast_(ast) {
  69. errcode_ = PARSE_SUCCESS;
  70. BuildMethodMap();
  71. }
  72. void Parser::BuildMethodMap() {
  73. stmt_method_map_["Return"] = &Parser::ParseReturn;
  74. stmt_method_map_["Expr"] = &Parser::ParseExpr;
  75. stmt_method_map_["If"] = &Parser::ParseIf;
  76. stmt_method_map_["Assign"] = &Parser::ParseAssign;
  77. stmt_method_map_["While"] = &Parser::ParseWhile;
  78. stmt_method_map_["For"] = &Parser::ParseFor;
  79. stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef;
  80. stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign;
  81. stmt_method_map_["Global"] = &Parser::ParseGlobal;
  82. stmt_method_map_["Break"] = &Parser::ParseBreak;
  83. stmt_method_map_["Continue"] = &Parser::ParseContinue;
  84. stmt_method_map_["Pass"] = &Parser::ParsePass;
  85. expr_method_map_["NoneType"] = &Parser::ParseNone;
  86. expr_method_map_["BinOp"] = &Parser::ParseBinOp;
  87. expr_method_map_["Name"] = &Parser::ParseName;
  88. expr_method_map_["Num"] = &Parser::ParseNum;
  89. expr_method_map_["Str"] = &Parser::ParseStr;
  90. expr_method_map_["NameConstant"] = &Parser::ParseNameConstant;
  91. expr_method_map_["Call"] = &Parser::ParseCall;
  92. expr_method_map_["IfExp"] = &Parser::ParseIfExp;
  93. expr_method_map_["Attribute"] = &Parser::ParseAttribute;
  94. expr_method_map_["Compare"] = &Parser::ParseCompare;
  95. expr_method_map_["BoolOp"] = &Parser::ParseBoolOp;
  96. expr_method_map_["Lambda"] = &Parser::ParseLambda;
  97. expr_method_map_["Tuple"] = &Parser::ParseTuple;
  98. expr_method_map_["List"] = &Parser::ParseList;
  99. expr_method_map_["Subscript"] = &Parser::ParseSubscript;
  100. expr_method_map_["Slice"] = &Parser::ParseSlice;
  101. expr_method_map_["ExtSlice"] = &Parser::ParseExtSlice;
  102. expr_method_map_["Index"] = &Parser::ParseIndex;
  103. expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp;
  104. expr_method_map_["Dict"] = &Parser::ParseDict;
  105. expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis;
  106. }
  107. void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); }
  108. void Parser::InitParserEnvironment(const py::object &obj) {
  109. Parser::top_func_graph_ = FuncGraphWeakPtr();
  110. ScopeManager::GetInstance().ClearScope();
  111. (void)python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GENERATE_SCOPE, obj);
  112. }
  113. void Parser::CleanParserResource() {
  114. Parser::top_func_graph_ = FuncGraphWeakPtr();
  115. ScopeManager::GetInstance().ClearScope();
  116. }
  117. FuncGraphPtr Parser::ParseFuncGraph() {
  118. // get ast FunctionDef node
  119. py::object node = ast_->GetAstNode();
  120. FunctionBlockPtr pFnBlock = ParseFunction(node);
  121. if (errcode() != PARSE_SUCCESS) {
  122. MS_LOG(ERROR) << "Parse function error, code is " << errcode();
  123. return nullptr;
  124. }
  125. RemoveUnnecessaryPhis();
  126. MS_EXCEPTION_IF_NULL(pFnBlock);
  127. return pFnBlock->func_graph();
  128. }
  129. void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
  130. py::object func_args = python_adapter::GetPyObjAttr(fn_node, "args");
  131. py::object var_arg_node = python_adapter::GetPyObjAttr(func_args, "vararg");
  132. block->func_graph()->set_has_vararg(!py::isinstance<py::none>(var_arg_node));
  133. py::object kw_arg_node = python_adapter::GetPyObjAttr(func_args, "kwarg");
  134. block->func_graph()->set_has_kwarg(!py::isinstance<py::none>(kw_arg_node));
  135. py::list kwonly_args = python_adapter::GetPyObjAttr(func_args, "kwonlyargs");
  136. block->func_graph()->set_kwonlyargs_count(SizeToInt(kwonly_args.size()));
  137. MS_EXCEPTION_IF_NULL(ast_);
  138. py::list args = ast_->GetArgs(fn_node);
  139. for (std::size_t i = 0; i < args.size(); i++) {
  140. std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
  141. if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  142. if (arg_name == "self") {
  143. continue;
  144. }
  145. }
  146. TraceManager::DebugTrace(GetLocation(args[i]));
  147. auto para_node = std::make_shared<Parameter>(block->func_graph());
  148. MS_EXCEPTION_IF_NULL(para_node);
  149. TraceManager::EndTrace();
  150. para_node->set_name(arg_name);
  151. para_node->debug_info()->set_name(arg_name);
  152. block->func_graph()->add_parameter(para_node);
  153. AnfNodePtr para_after_cast = GetMixedPrecisionCastHelp(block->func_graph(), para_node);
  154. block->WriteVariable(arg_name, para_after_cast);
  155. MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name;
  156. }
  157. }
  158. void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
  159. py::list defaults = ast_->GetArgsDefaultValues(fn_node);
  160. py::list args = ast_->GetArgs(fn_node);
  161. std::vector<std::string> namelist_for_default_value;
  162. std::vector<AnfNodePtr> default_values;
  163. for (std::size_t i = 0; i < args.size(); i++) {
  164. std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
  165. if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  166. if (arg_name == "self") {
  167. continue;
  168. }
  169. }
  170. namelist_for_default_value.push_back(arg_name);
  171. if (py::isinstance<py::none>(defaults[i])) {
  172. default_values.push_back(NewValueNode(kNull));
  173. } else {
  174. default_values.push_back(ParseExprNode(block, defaults[i]));
  175. }
  176. }
  177. block->func_graph()->SetDefaultValues(namelist_for_default_value, default_values);
  178. }
  179. ScopePtr Parser::GetScopeForParseFunction() {
  180. ScopePtr scope = ScopeManager::GetInstance().GetCurrentScope();
  181. if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  182. py::object scope_str = python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GET_SCOPE_NAME, ast_->obj());
  183. if (!py::isinstance<py::none>(scope_str)) {
  184. auto scope_name = py::cast<std::string>(scope_str);
  185. scope = std::make_shared<Scope>(scope_name);
  186. }
  187. }
  188. return scope;
  189. }
  190. FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlockPtr &block) {
  191. ScopePtr scope = GetScopeForParseFunction();
  192. // the node created in the parsefunction context, will inherit the scope created using scope_guard
  193. ScopeGuard scope_guard(scope);
  194. TraceGuard trace_guard(data_converter::GetObjKey(ast()->obj())[0], GetLocation(node));
  195. FunctionBlockPtr pFunBlock = MakeFunctionBlock(*this);
  196. if (block != nullptr) {
  197. pFunBlock->AddPrevBlock(block);
  198. } else {
  199. func_graph_ = pFunBlock->func_graph();
  200. }
  201. pFunBlock->Mature();
  202. auto current_fg = pFunBlock->func_graph();
  203. auto function_name = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "name"));
  204. MS_LOG(DEBUG) << "The function name is " << function_name;
  205. current_fg->debug_info()->set_name(function_name);
  206. MS_EXCEPTION_IF_NULL(ast_);
  207. py::list deco_list = node.attr("decorator_list");
  208. if (deco_list.size() > 0) {
  209. current_fg->debug_info()->set_deco_location(GetLocation(deco_list));
  210. }
  211. bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg);
  212. if (ast_->obj() != ast_->function()) {
  213. set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg);
  214. }
  215. if (!set_flag) {
  216. MS_LOG(ERROR) << "Set flags failed";
  217. return nullptr;
  218. }
  219. GenerateArgsNodeForFunction(pFunBlock, node);
  220. // when parsing the top graph of construct, save the top graph
  221. if (GetTopFuncGraph() == nullptr) {
  222. UpdateTopFuncGraph(pFunBlock->func_graph());
  223. }
  224. // save the function node to block
  225. pFunBlock->WriteVariable(function_name, NewValueNode(current_fg));
  226. py::object funcObj = python_adapter::GetPyObjAttr(node, "body");
  227. (void)ParseStatements(pFunBlock, funcObj);
  228. if (current_fg->get_return() == nullptr) {
  229. MS_LOG(ERROR) << "Graph return node is null, loc:" << GetLocation(node)->ToString();
  230. errcode_ = PARSE_NO_RETURN;
  231. return pFunBlock;
  232. }
  233. GenerateArgsDefaultValueForFunction(pFunBlock, node);
  234. return pFunBlock;
  235. }
  236. FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr fn_block, const py::object &nodes) {
  237. py::int_ pcount = python_adapter::CallPyObjMethod(nodes, "__len__");
  238. size_t count = IntToSize(pcount);
  239. MS_LOG(DEBUG) << "The nodes count is " << count;
  240. for (size_t i = 0; i < count; i++) {
  241. auto node = py::cast<py::list>(nodes)[i];
  242. TraceManager::DebugTrace(GetLocation(node));
  243. fn_block = ParseStatement(fn_block, node);
  244. TraceManager::EndTrace();
  245. // insert appropriate depended items for the function block if it has a return node
  246. if (fn_block->func_graph()->get_return() != nullptr) {
  247. fn_block->InsertDependItemsBeforeReturn();
  248. // Skip statements after 'return' (or 'break', 'continue').
  249. break;
  250. }
  251. }
  252. return fn_block;
  253. }
  254. FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py::object &node) {
  255. auto node_type = ast_->GetNodeType(node);
  256. // check the node type
  257. AstMainType nodeType = node_type->main_type();
  258. if (nodeType != AST_MAIN_TYPE_STMT) {
  259. MS_LOG(INFO) << "Node type is error : " << nodeType;
  260. return block;
  261. }
  262. // call the process function
  263. std::string node_name = node_type->node_name();
  264. MS_LOG(DEBUG) << "Ast node is " << node_name;
  265. if (stmt_method_map_.count(node_name)) {
  266. TraceManager::DebugTrace(GetLocation(node));
  267. auto stmt_block = (this->*stmt_method_map_[node_name])(block, node);
  268. TraceManager::EndTrace();
  269. return stmt_block;
  270. } else {
  271. errcode_ = PARSE_NODE_METHOD_UNSUPPORTED;
  272. py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  273. if (location.size() < 2) {
  274. MS_LOG(EXCEPTION) << "List size should not be less than 2.";
  275. }
  276. auto filename = location[0].cast<std::string>();
  277. auto line_no = location[1].cast<int>();
  278. MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no;
  279. }
  280. }
  281. AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object &node) {
  282. MS_LOG(DEBUG) << "Process ast expr";
  283. auto node_type = ast_->GetNodeType(node);
  284. // check the node type
  285. AstMainType node_main_type = node_type->main_type();
  286. if (node_main_type != AST_MAIN_TYPE_EXPR) {
  287. MS_LOG(ERROR) << "Node type is error : " << node_main_type;
  288. errcode_ = PARSE_NODE_TYPE_NO_MATCH;
  289. return nullptr;
  290. }
  291. // call the process function
  292. std::string node_name = node_type->node_name();
  293. MS_LOG(DEBUG) << "Ast node is " << node_name;
  294. if (expr_method_map_.count(node_name)) {
  295. TraceManager::DebugTrace(GetLocation(node));
  296. auto expr_node = (this->*expr_method_map_[node_name])(block, node);
  297. TraceManager::EndTrace();
  298. return expr_node;
  299. } else {
  300. errcode_ = PARSE_NODE_METHOD_UNSUPPORTED;
  301. py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  302. auto filename = ret[0].cast<std::string>();
  303. auto line_no = ret[1].cast<int>();
  304. MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no;
  305. }
  306. }
  307. // process the expr statement and expand it
  308. // eg: x.append(y) -> x = x.append(y)
  309. FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::object &node) {
  310. MS_LOG(DEBUG) << "Process ast Expr";
  311. // Expr only have value , no target
  312. py::tuple expand_info = ast_->CallParserObjMethod(PYTHON_PARSE_EXPAND_EXPR_STATEMENT, node);
  313. // refer python function expand_expr_statement, expand_info is one of the following:
  314. // True, expr.value, x
  315. // True, expr.value
  316. // False, None, None
  317. // check the expand info result
  318. auto is_expand = py::cast<bool>(expand_info[0]);
  319. if (is_expand) {
  320. // process the expr statement
  321. py::object value_object = expand_info[1];
  322. AnfNodePtr value_node = ParseExprNode(block, value_object);
  323. if (py::len(expand_info) == 2) {
  324. // add to depend list and insert before output
  325. block->AddAutoDepend(value_node);
  326. } else {
  327. // expand the assign statement
  328. py::object target_node = expand_info[2];
  329. WriteAssignVars(block, target_node, value_node);
  330. }
  331. }
  332. return block;
  333. }
  334. LocationPtr Parser::GetLocation(const py::object &node) const {
  335. MS_EXCEPTION_IF_NULL(ast_);
  336. py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  337. if (ret.size() < 5) {
  338. MS_LOG(EXCEPTION) << "List size should not be less than 5.";
  339. }
  340. // refer to Location::Location() for each member of ret: line, column, line_end, column_end.
  341. auto location = std::make_shared<Location>(ret[0].cast<std::string>(), ret[1].cast<int>(), ret[2].cast<int>(),
  342. ret[3].cast<int>(), ret[4].cast<int>());
  343. return location;
  344. }
  345. void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
  346. const FunctionBlockPtr &false_block) {
  347. true_block->AddPrevBlock(pre_block);
  348. true_block->Mature();
  349. false_block->AddPrevBlock(pre_block);
  350. false_block->Mature();
  351. }
  352. FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::object &node) {
  353. MS_LOG(DEBUG) << "Process ast return";
  354. MS_EXCEPTION_IF_NULL(block);
  355. // create return valuenode
  356. AnfNodePtr pReturnValueNode = NewValueNode(prim::kPrimReturn);
  357. // parse the return Statements value
  358. py::object value = python_adapter::GetPyObjAttr(node, "value");
  359. AnfNodePtr pReturnStatementNode = ParseExprNode(block, value);
  360. // Create the cnode
  361. CNodePtr pReturnCNode = block->func_graph()->NewCNode({pReturnValueNode, pReturnStatementNode});
  362. block->func_graph()->set_return(pReturnCNode);
  363. return block;
  364. }
  365. // Process binary operators,eg: `a + b`, `a | b`, etc.
  366. AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &node) {
  367. MS_LOG(DEBUG) << "Process ast BinOP";
  368. py::object left = python_adapter::GetPyObjAttr(node, "left");
  369. py::object right = python_adapter::GetPyObjAttr(node, "right");
  370. py::object op = python_adapter::GetPyObjAttr(node, "op");
  371. // create left and right ANF node
  372. AnfNodePtr left_node = ParseExprNode(block, left);
  373. if (left_node == nullptr) {
  374. MS_LOG(WARNING) << "DoBinOp process left node failed: " << errcode();
  375. return nullptr;
  376. }
  377. AnfNodePtr right_node = ParseExprNode(block, right);
  378. if (right_node == nullptr) {
  379. MS_LOG(WARNING) << "DoBinOp process right node failed:" << errcode();
  380. return nullptr;
  381. }
  382. // resolve the op
  383. AnfNodePtr op_node = block->MakeResolveAstOp(op);
  384. // create apply node
  385. return block->func_graph()->NewCNode({op_node, left_node, right_node});
  386. }
  387. AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &node) {
  388. MS_LOG(DEBUG) << "Process ast Name";
  389. auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id"));
  390. MS_LOG(DEBUG) << "The Name id is " << name_id;
  391. TraceGuard trace_guard(GetLocation(node));
  392. if (block->IsGlobalVar(name_id)) {
  393. return block->MakeResolveSymbol(name_id);
  394. }
  395. return block->ReadVariable(name_id);
  396. }
  397. AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) {
  398. MS_LOG(DEBUG) << "Process ast NoneType";
  399. return NewValueNode(kNone);
  400. }
  401. AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) {
  402. MS_LOG(DEBUG) << "Process ast Ellipsis";
  403. return NewValueNode(kEllipsis);
  404. }
  405. AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
  406. MS_LOG(DEBUG) << "Process ast Num";
  407. py::object obj = python_adapter::GetPyObjAttr(node, "n");
  408. TraceGuard trace_guard(GetLocation(node));
  409. if (py::isinstance<py::int_>(obj)) {
  410. MS_LOG(INFO) << "The Num is int:" << (std::string)py::str(obj);
  411. auto data = py::cast<int>(obj);
  412. return NewValueNode(data);
  413. } else if (py::isinstance<py::float_>(obj)) {
  414. MS_LOG(INFO) << "The Num is float:" << (std::string)py::str(obj);
  415. auto data = py::cast<float>(obj);
  416. return NewValueNode(data);
  417. } else {
  418. // no else actually
  419. MS_LOG(ERROR) << "Unsupported Num type : " << (std::string)py::str(obj) << GetLocation(node)->ToString();
  420. errcode_ = PARSE_NODE_TYPE_UNKOWN;
  421. return nullptr;
  422. }
  423. }
  424. AnfNodePtr Parser::ParseStr(const FunctionBlockPtr &, const py::object &node) {
  425. MS_LOG(DEBUG) << "Process ast Str";
  426. auto str_s = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "s"));
  427. return NewValueNode(str_s);
  428. }
  429. AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object &node) {
  430. MS_LOG(DEBUG) << "Process ast NameConstant";
  431. py::object obj = python_adapter::GetPyObjAttr(node, "value");
  432. TraceGuard trace_guard(GetLocation(node));
  433. if (py::isinstance<py::bool_>(obj)) {
  434. MS_LOG(INFO) << "The NameConstant is bool:" << (std::string)py::str(obj);
  435. auto data = py::cast<bool>(obj);
  436. return NewValueNode(data);
  437. } else if (py::isinstance<py::none>(obj)) {
  438. MS_LOG(INFO) << "The NameConstant is none:" << (std::string)py::str(obj);
  439. return NewValueNode(kNone);
  440. } else {
  441. // no else actually
  442. MS_LOG(ERROR) << "Unsupported NameConstant type: " << (std::string)py::str(obj) << GetLocation(node)->ToString();
  443. errcode_ = PARSE_NODE_TYPE_UNKOWN;
  444. return nullptr;
  445. }
  446. }
  447. AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes) {
  448. AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
  449. std::vector<AnfNodePtr> make_tuple_nodes;
  450. make_tuple_nodes.push_back(make_tuple_op);
  451. (void)std::transform(element_nodes.begin(), element_nodes.end(), std::back_inserter(make_tuple_nodes),
  452. [](AnfNodePtr arg) -> AnfNodePtr { return arg; });
  453. return block->func_graph()->NewCNode(make_tuple_nodes);
  454. }
  455. AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &args) {
  456. py::object father_class;
  457. if (args.empty()) {
  458. father_class = py::none();
  459. } else if (args.size() == 2) {
  460. father_class = args[0];
  461. auto arg_type = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, args[1])));
  462. if (arg_type != AST_SUB_TYPE_NAME || py::cast<std::string>(python_adapter::GetPyObjAttr(args[1], "id")) != "self") {
  463. MS_EXCEPTION(ArgumentError) << "When call 'super', the second arg should be 'self'.";
  464. }
  465. } else {
  466. MS_EXCEPTION(ArgumentError) << "When call 'super', the args number should be 0 or 2, but got" << args.size() << ".";
  467. }
  468. py::object target_class_instance = ast()->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast()->obj());
  469. py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance);
  470. NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
  471. SymbolPtr symbol = std::make_shared<Symbol>("namespace");
  472. return block->MakeResolve(name_space, symbol);
  473. }
  474. // process function call, eg : f1(x, y) ...
  475. AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) {
  476. MS_LOG(DEBUG) << "Process ast Call";
  477. // process function call
  478. py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func");
  479. py::list args = python_adapter::GetPyObjAttr(node, "args");
  480. auto arg_type =
  481. AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, function_ast_node)));
  482. if (arg_type == AST_SUB_TYPE_NAME) {
  483. auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(function_ast_node, "id"));
  484. if (name_id == "super") {
  485. return ParseSuper(block, args);
  486. }
  487. }
  488. AnfNodePtr call_function_anf_node = ParseExprNode(block, function_ast_node);
  489. // function call arguments should be passed in as groups and unpacked later using unpack call
  490. std::vector<AnfNodePtr> packed_arguments;
  491. std::vector<AnfNodePtr> group_arguments;
  492. bool need_unpack_args = ParseArgsInCall(block, args, &packed_arguments, &group_arguments);
  493. bool need_unpack_keywords = ParseKeywordsInCall(block, node, &packed_arguments);
  494. // if there is stared or keyword argument, unpack may be needed
  495. bool need_unpack = need_unpack_args || need_unpack_keywords;
  496. return GenerateAnfNodeForCall(block, call_function_anf_node, packed_arguments, group_arguments, need_unpack);
  497. }
  498. AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node,
  499. const std::vector<AnfNodePtr> &packed_arguments,
  500. const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const {
  501. // if there is keyword arguments or starred, using an unpack_call op to unpack the argument
  502. if (need_unpack) {
  503. std::vector<AnfNodePtr> unpack_call_nodes;
  504. auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL));
  505. unpack_call_nodes.push_back(unpack_call_op);
  506. unpack_call_nodes.push_back(call_function_anf_node);
  507. (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes),
  508. [](AnfNodePtr node) -> AnfNodePtr { return node; });
  509. CNodePtr unpack_call = block->func_graph()->NewCNode(unpack_call_nodes);
  510. return unpack_call;
  511. }
  512. // else there is no keyword arguments and starred, parsed as normal arguments without unpack
  513. std::vector<AnfNodePtr> func_call_nodes;
  514. func_call_nodes.push_back(call_function_anf_node);
  515. (void)std::transform(group_arguments.begin(), group_arguments.end(), std::back_inserter(func_call_nodes),
  516. [](AnfNodePtr node) -> AnfNodePtr { return node; });
  517. CNodePtr call_anf_node = block->func_graph()->NewCNode(func_call_nodes);
  518. return call_anf_node;
  519. }
  520. bool Parser::ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args,
  521. std::vector<AnfNodePtr> *packed_arguments, std::vector<AnfNodePtr> *group_arguments) {
  522. bool need_unpack = false;
  523. for (size_t i = 0; i < args.size(); i++) {
  524. auto arg_node = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, args[i])));
  525. if (arg_node == AST_SUB_TYPE_STARRED) {
  526. if (!group_arguments->empty()) {
  527. packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments));
  528. }
  529. packed_arguments->push_back(ParseExprNode(block, python_adapter::GetPyObjAttr(args[i], "value")));
  530. group_arguments->clear();
  531. need_unpack = true;
  532. } else {
  533. group_arguments->push_back(ParseExprNode(block, args[i]));
  534. }
  535. }
  536. if (!group_arguments->empty()) {
  537. packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments));
  538. }
  539. return need_unpack;
  540. }
  541. bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node,
  542. std::vector<AnfNodePtr> *packed_arguments) {
  543. bool need_unpack = false;
  544. py::list keywords = python_adapter::GetPyObjAttr(node, "keywords");
  545. if (!keywords.empty()) {
  546. need_unpack = true;
  547. std::vector<AnfNodePtr> keys;
  548. std::vector<AnfNodePtr> values;
  549. for (size_t index = 0; index < keywords.size(); index++) {
  550. auto kw_key = python_adapter::GetPyObjAttr(keywords[index], "arg");
  551. auto kw_value = python_adapter::GetPyObjAttr(keywords[index], "value");
  552. if (py::isinstance<py::none>(kw_key)) {
  553. packed_arguments->push_back(ParseExprNode(block, kw_value));
  554. } else {
  555. auto kw_key_c = kw_key.cast<std::string>();
  556. keys.push_back(NewValueNode(kw_key_c));
  557. values.push_back(ParseExprNode(block, kw_value));
  558. }
  559. }
  560. auto keys_tuple = GenerateMakeTuple(block, keys);
  561. auto values_tuple = GenerateMakeTuple(block, values);
  562. auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
  563. std::vector<AnfNodePtr> make_dict_nodes;
  564. make_dict_nodes.push_back(make_dict_op);
  565. make_dict_nodes.push_back(keys_tuple);
  566. make_dict_nodes.push_back(values_tuple);
  567. packed_arguments->push_back(block->func_graph()->NewCNode(make_dict_nodes));
  568. }
  569. return need_unpack;
  570. }
  571. // process call attributes of class type define, eg: x.y()
  572. AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::object &node) {
  573. MS_LOG(DEBUG) << "Process ast Attribute";
  574. // process class value,eg: self.xx
  575. if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  576. if (ast_->IsClassMember(node)) {
  577. std::string var_name = "self.";
  578. std::string attr_name = node.attr("attr").cast<std::string>();
  579. (void)var_name.append(attr_name);
  580. auto attr_obj = ast()->obj().attr(attr_name.c_str());
  581. if (py::hasattr(ast()->obj(), attr_name.c_str()) &&
  582. (py::hasattr(attr_obj, PYTHON_PRIMITIVE_FLAG) || py::isinstance<py::int_>(attr_obj) ||
  583. py::isinstance<py::float_>(attr_obj) || py::isinstance<py::bool_>(attr_obj) ||
  584. py::isinstance<py::str>(attr_obj) || data_converter::IsCellInstance(attr_obj))) {
  585. return block->MakeResolveSymbol(var_name);
  586. } else {
  587. return block->ReadVariable(var_name);
  588. }
  589. }
  590. }
  591. // process the get attr
  592. // Use the Primitive replace the operation resolve node (getattr)
  593. // because the getattr will eventually be converted to Primitive node
  594. AnfNodePtr op_node = NewValueNode(prim::kPrimGetAttr);
  595. // process the attr body
  596. py::object value_body = python_adapter::GetPyObjAttr(node, "value");
  597. AnfNodePtr value_node = ParseExprNode(block, value_body);
  598. if (value_node == nullptr) {
  599. MS_LOG(WARNING) << "Parse attribute failed";
  600. return nullptr;
  601. }
  602. // process the node attr
  603. auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast<std::string>();
  604. MS_LOG(DEBUG) << "Attr = " << attr_str;
  605. TraceManager::DebugTrace(GetLocation(python_adapter::GetPyObjAttr(node, "attr")));
  606. AnfNodePtr attr_node = NewValueNode(attr_str);
  607. TraceManager::EndTrace();
  608. // create the apply node
  609. return block->func_graph()->NewCNode({op_node, value_node, attr_node});
  610. }
  611. // Process comparison expression : a == b. a > b etc.
  612. AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object &node) {
  613. MS_LOG(DEBUG) << "Process ast Compare";
  614. // for python comparison ,there may be if x>y>5 ,
  615. // which there is two ops , but we only support one now
  616. py::list ops = python_adapter::GetPyObjAttr(node, "ops");
  617. if (ops.size() > MAX_COMPARISON_OPS_SUPPORTED) {
  618. MS_LOG(ERROR) << "MindSpore does not support comparison with operators more than one now, ops size =" << ops.size();
  619. return nullptr;
  620. }
  621. py::object left = python_adapter::GetPyObjAttr(node, "left");
  622. py::list comparators = python_adapter::GetPyObjAttr(node, "comparators");
  623. AnfNodePtr left_node = ParseExprNode(block, left);
  624. AnfNodePtr right_node = ParseExprNode(block, comparators[0]);
  625. MS_EXCEPTION_IF_NULL(block);
  626. AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]);
  627. return block->func_graph()->NewCNode({op_node, left_node, right_node});
  628. }
  629. AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list,
  630. const py::object &op) {
  631. // if there is only one bool op now
  632. if (value_list.size() == 1) {
  633. AnfNodePtr first_node = ParseExprNode(block, value_list[0]);
  634. return first_node;
  635. } else {
  636. py::object first = value_list[0];
  637. py::list rest;
  638. for (size_t i = 1; i < value_list.size(); i++) {
  639. rest.append(value_list[i]);
  640. }
  641. AnfNodePtr first_node = ParseExprNode(block, first);
  642. AnfNodePtr rest_node = ProcessBoolOpValueList(block, rest, op);
  643. auto op_node = block->MakeResolveAstOp(op);
  644. return block->func_graph()->NewCNode({op_node, first_node, rest_node});
  645. }
  646. }
  647. // Process comparison expression : a and b. a or b .
  648. AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) {
  649. MS_LOG(DEBUG) << "Process ast BoolOp";
  650. py::object op_node = python_adapter::GetPyObjAttr(node, "op");
  651. py::list op_values = python_adapter::GetPyObjAttr(node, "values");
  652. return ProcessBoolOpValueList(block, op_values, op_node);
  653. }
  654. // Process a function def
  655. FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) {
  656. MS_LOG(DEBUG) << "Process ast FunctionDef";
  657. FunctionBlockPtr function_block = ParseFunction(node, block);
  658. MS_EXCEPTION_IF_NULL(function_block);
  659. // get function name
  660. py::str name = python_adapter::GetPyObjAttr(node, "name");
  661. std::string function_name = name;
  662. ValueNodePtr valuenode_graph = NewValueNode(function_block->func_graph());
  663. block->WriteVariable(function_name, valuenode_graph);
  664. return block;
  665. }
  666. // Process a lambda expression . like lambda x,y: x + y
  667. AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) {
  668. MS_LOG(DEBUG) << "Process ast Lambda";
  669. FunctionBlockPtr func_block = MakeFunctionBlock(*this);
  670. func_block->AddPrevBlock(block);
  671. func_block->Mature();
  672. // get lambda args
  673. py::list args = ast_->GetArgs(node);
  674. for (std::size_t i = 0; i < args.size(); i++) {
  675. std::string arg = py::cast<std::string>(args[i].attr("arg"));
  676. TraceManager::DebugTrace(GetLocation(args[i]));
  677. auto para_node = std::make_shared<Parameter>(func_block->func_graph());
  678. TraceManager::EndTrace();
  679. para_node->debug_info()->set_name(arg);
  680. func_block->func_graph()->add_parameter(para_node);
  681. func_block->WriteVariable(arg, para_node);
  682. MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg;
  683. }
  684. py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  685. AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
  686. func_block->func_graph()->set_output(lambda_body_node);
  687. ValueNodePtr const_graph = NewValueNode(func_block->func_graph());
  688. return const_graph;
  689. }
  690. // process a tuple
  691. AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &node) {
  692. MS_LOG(DEBUG) << "Process ast Tuple";
  693. MS_EXCEPTION_IF_NULL(block);
  694. py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
  695. if (elts.size() == 0) {
  696. auto empty_tuple = std::vector<ValuePtr>();
  697. return NewValueNode(std::make_shared<ValueTuple>(empty_tuple));
  698. }
  699. std::vector<AnfNodePtr> tuple_vec;
  700. AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
  701. tuple_vec.emplace_back(make_tuple_op);
  702. for (size_t i = 0; i < elts.size(); i++) {
  703. AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
  704. tuple_vec.emplace_back(node_ptr);
  705. }
  706. CNodePtr tuple_app = block->func_graph()->NewCNode(tuple_vec);
  707. return tuple_app;
  708. }
  709. // process a list
  710. AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) {
  711. MS_LOG(DEBUG) << "Process ast List";
  712. MS_EXCEPTION_IF_NULL(block);
  713. py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
  714. if (elts.size() == 0) {
  715. auto empty_list = std::vector<ValuePtr>();
  716. return NewValueNode(std::make_shared<ValueList>(empty_list));
  717. }
  718. std::vector<AnfNodePtr> list_vec;
  719. AnfNodePtr make_list_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST);
  720. list_vec.emplace_back(make_list_op);
  721. for (size_t i = 0; i < elts.size(); i++) {
  722. AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
  723. list_vec.emplace_back(node_ptr);
  724. }
  725. CNodePtr list_app = block->func_graph()->NewCNode(list_vec);
  726. return list_app;
  727. }
  728. // process a subscript, such as x[y] , node expressed as value[slice]
  729. AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::object &node) {
  730. MS_LOG(DEBUG) << "Process ast Subscript";
  731. MS_EXCEPTION_IF_NULL(block);
  732. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  733. py::object value_node = python_adapter::GetPyObjAttr(node, "value");
  734. py::object slice_node = python_adapter::GetPyObjAttr(node, "slice");
  735. AnfNodePtr value = ParseExprNode(block, value_node);
  736. AnfNodePtr slice = ParseExprNode(block, slice_node);
  737. return block->func_graph()->NewCNode({op_getitem, value, slice});
  738. }
  739. // process a slice, get the slice value
  740. AnfNodePtr Parser::ParseSlice(const FunctionBlockPtr &block, const py::object &node) {
  741. MS_LOG(DEBUG) << "Process ast Slice";
  742. MS_EXCEPTION_IF_NULL(block);
  743. AnfNodePtr op_makeslice = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKESLICE);
  744. py::object start = python_adapter::GetPyObjAttr(node, "lower");
  745. py::object stop = python_adapter::GetPyObjAttr(node, "upper");
  746. py::object step = python_adapter::GetPyObjAttr(node, "step");
  747. AnfNodePtr start_node = ParseExprNode(block, start);
  748. AnfNodePtr stop_node = ParseExprNode(block, stop);
  749. AnfNodePtr step_node = ParseExprNode(block, step);
  750. return block->func_graph()->NewCNode({op_makeslice, start_node, stop_node, step_node});
  751. }
  752. // process a extslice
  753. AnfNodePtr Parser::ParseExtSlice(const FunctionBlockPtr &block, const py::object &node) {
  754. MS_LOG(DEBUG) << "Process ast ExtSlice";
  755. MS_EXCEPTION_IF_NULL(block);
  756. AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
  757. py::tuple slice_tuple = python_adapter::GetPyObjAttr(node, "dims");
  758. std::vector<AnfNodePtr> node_vec;
  759. node_vec.emplace_back(make_tuple_op);
  760. for (size_t i = 0; i < slice_tuple.size(); i++) {
  761. AnfNodePtr node_ptr = ParseExprNode(block, slice_tuple[i]);
  762. node_vec.emplace_back(node_ptr);
  763. }
  764. CNodePtr tuple_conde = block->func_graph()->NewCNode(node_vec);
  765. return tuple_conde;
  766. }
  767. // process a index, get the index number
  768. AnfNodePtr Parser::ParseIndex(const FunctionBlockPtr &block, const py::object &node) {
  769. MS_LOG(DEBUG) << "Process ast Index";
  770. py::object value_node = python_adapter::GetPyObjAttr(node, "value");
  771. return ParseExprNode(block, value_node);
  772. }
  773. // process a UnaryOp, +a, -b
  774. AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node) {
  775. MS_LOG(DEBUG) << "Process ast UnaryOp";
  776. py::object op = python_adapter::GetPyObjAttr(node, "op");
  777. MS_EXCEPTION_IF_NULL(block);
  778. // resolve the op
  779. AnfNodePtr op_node = block->MakeResolveAstOp(op);
  780. py::object operand = python_adapter::GetPyObjAttr(node, "operand");
  781. AnfNodePtr operand_node = ParseExprNode(block, operand);
  782. return block->func_graph()->NewCNode({op_node, operand_node});
  783. }
  784. // process a dict ast node expression
  785. AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) {
  786. MS_LOG(DEBUG) << "Process ast Dict";
  787. py::list keys = node.attr("keys");
  788. py::list values = node.attr("values");
  789. std::vector<AnfNodePtr> key_nodes;
  790. std::vector<AnfNodePtr> value_nodes;
  791. for (size_t i = 0; i < keys.size(); i++) {
  792. key_nodes.push_back(ParseExprNode(block, keys[i]));
  793. value_nodes.push_back(ParseExprNode(block, values[i]));
  794. }
  795. auto keys_tuple = GenerateMakeTuple(block, key_nodes);
  796. auto values_tuple = GenerateMakeTuple(block, value_nodes);
  797. MS_EXCEPTION_IF_NULL(block);
  798. auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
  799. return block->func_graph()->NewCNode({make_dict_op, keys_tuple, values_tuple});
  800. }
  801. // process a augment assign such as a += b;
  802. FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py::object &node) {
  803. MS_LOG(DEBUG) << "Process ast AugAssign";
  804. py::object op = python_adapter::GetPyObjAttr(node, "op");
  805. MS_EXCEPTION_IF_NULL(block);
  806. // resolve the op
  807. AnfNodePtr op_node = block->MakeResolveAstOp(op);
  808. py::object target_node = python_adapter::GetPyObjAttr(node, "target");
  809. MS_EXCEPTION_IF_NULL(ast_);
  810. auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, target_node)));
  811. AnfNodePtr read_node = nullptr;
  812. if (ast_type == AST_SUB_TYPE_NAME) {
  813. read_node = ParseName(block, target_node);
  814. } else if (ast_->IsClassMember(target_node)) {
  815. read_node = ParseAttribute(block, target_node);
  816. } else {
  817. MS_LOG(EXCEPTION) << "Not supported augassign";
  818. }
  819. if (read_node == nullptr) {
  820. MS_LOG(EXCEPTION) << "Can not get target node ";
  821. }
  822. py::object value = python_adapter::GetPyObjAttr(node, "value");
  823. AnfNodePtr value_node = ParseExprNode(block, value);
  824. CNodePtr augassign_app = block->func_graph()->NewCNode({op_node, read_node, value_node});
  825. WriteAssignVars(block, target_node, augassign_app);
  826. return block;
  827. }
  828. // process global declaration such as 'global x';
  829. FunctionBlockPtr Parser::ParseGlobal(const FunctionBlockPtr &block, const py::object &node) {
  830. MS_LOG(DEBUG) << "Process ast Global";
  831. MS_EXCEPTION_IF_NULL(block);
  832. py::list vars = python_adapter::GetPyObjAttr(node, "names");
  833. for (auto &item : vars) {
  834. block->AddGlobalVar(py::cast<std::string>(item));
  835. }
  836. return block;
  837. }
  838. // process a if statement
  839. FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object &node) {
  840. MS_LOG(DEBUG) << "Process ast If";
  841. py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  842. AnfNodePtr condition_node = ParseExprNode(block, test_node);
  843. MS_EXCEPTION_IF_NULL(block);
  844. CNodePtr bool_node = block->ForceToBoolNode(condition_node);
  845. TraceManager::DebugTrace(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
  846. FunctionBlockPtr true_block = MakeFunctionBlock(*this);
  847. TraceManager::EndTrace();
  848. TraceManager::DebugTrace(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
  849. FunctionBlockPtr false_block = MakeFunctionBlock(*this);
  850. TraceManager::EndTrace();
  851. MakeConditionBlocks(block, true_block, false_block);
  852. TraceManager::DebugTrace(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
  853. FunctionBlockPtr after_block = MakeFunctionBlock(*this);
  854. TraceManager::EndTrace();
  855. // process the if-true branch
  856. py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
  857. FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode);
  858. // if the return_ is set ,it has its own continuation block
  859. if (true_end->func_graph()->get_return() == nullptr) {
  860. true_end->Jump(after_block, nullptr);
  861. }
  862. // process the orelse branch
  863. py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
  864. FunctionBlockPtr false_end = ParseStatements(false_block, orelseNode);
  865. // if the return_ is set ,it has its own continuation block
  866. if (false_end->func_graph()->get_return() == nullptr) {
  867. false_end->Jump(after_block, nullptr);
  868. }
  869. block->ConditionalJump(bool_node, true_block, false_block);
  870. after_block->Mature();
  871. return after_block;
  872. }
  873. FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::object &node) {
  874. MS_LOG(DEBUG) << "Process ast While";
  875. MS_EXCEPTION_IF_NULL(block);
  876. MS_LOG(INFO) << "Parse while statement";
  877. TraceManager::DebugTrace(std::make_shared<TraceWhileHeader>(block->func_graph()->debug_info()));
  878. FunctionBlockPtr header_block = MakeFunctionBlock(*this);
  879. TraceManager::EndTrace();
  880. TraceManager::DebugTrace(std::make_shared<TraceWhileBody>(block->func_graph()->debug_info()));
  881. FunctionBlockPtr body_block = MakeFunctionBlock(*this);
  882. TraceManager::EndTrace();
  883. TraceManager::DebugTrace(std::make_shared<TraceWhileAfter>(block->func_graph()->debug_info()));
  884. FunctionBlockPtr after_block = MakeFunctionBlock(*this);
  885. TraceManager::EndTrace();
  886. body_block->AddPrevBlock(header_block);
  887. after_block->AddPrevBlock(header_block);
  888. block->Jump(header_block, nullptr);
  889. py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  890. AnfNodePtr condition_node = ParseExprNode(header_block, test_node);
  891. condition_node = header_block->ForceToWhileCond(condition_node);
  892. body_block->Mature();
  893. header_block->ConditionalJump(condition_node, body_block, after_block);
  894. // Parse loop body statements with loop context.
  895. LoopContext loop_context{&loops_, header_block, nullptr};
  896. py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  897. FunctionBlockPtr after_body = ParseStatements(body_block, body_node);
  898. if (after_body->func_graph()->get_return() == nullptr) {
  899. after_body->Jump(header_block, nullptr);
  900. }
  901. header_block->Mature();
  902. after_block->Mature();
  903. auto &end_block = loop_context.EndBlock();
  904. if (end_block) {
  905. // end_block exists if we encounter 'break' in loop body.
  906. after_block->Jump(end_block, nullptr);
  907. end_block->Mature();
  908. return end_block;
  909. }
  910. // No 'break', no end_block.
  911. return after_block;
  912. }
  913. CNodePtr Parser::GenerateIteratorInFor(const FunctionBlockPtr &block, const py::object &node,
  914. const AnfNodePtr &op_iter) {
  915. py::object iter_node = python_adapter::GetPyObjAttr(node, "iter");
  916. AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node);
  917. return block->func_graph()->NewCNode({op_iter, iter_anf_node});
  918. }
  919. CNodePtr Parser::GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block,
  920. const AnfNodePtr &op_hasnext) {
  921. MS_EXCEPTION_IF_NULL(header_block);
  922. return header_block->func_graph()->NewCNode({op_hasnext, iter_param});
  923. }
  924. FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
  925. TraceManager::DebugTrace(trace_info);
  926. FunctionBlockPtr body_block = MakeFunctionBlock(*this);
  927. TraceManager::EndTrace();
  928. return body_block;
  929. }
  930. // A for loop will generate 3 functions :the test, the body, and the continuation
  931. // for x in xs:
  932. // body
  933. // it is compiled to be following statement
  934. // if len(xs) < max_loop_cnt:
  935. // ParseForIter() // use iter to implement for loop, which always unroll loop
  936. // else:
  937. // ParseForLoop() // use loop var to implement for loop, which always sink loop
  938. FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
  939. MS_LOG(DEBUG) << "Process ast For, create an if else statement";
  940. MS_EXCEPTION_IF_NULL(block);
  941. // create statement 'len(xs) < prim::MAX_FOR_LOOP_COUNT'
  942. AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
  943. py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER);
  944. AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
  945. CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node});
  946. CNodePtr bool_node = block->func_graph()->NewCNode(
  947. {NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(prim::MAX_FOR_LOOP_COUNT)});
  948. // create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop'
  949. TraceManager::DebugTrace(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
  950. FunctionBlockPtr true_block = MakeFunctionBlock(*this);
  951. TraceManager::EndTrace();
  952. TraceManager::DebugTrace(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
  953. FunctionBlockPtr false_block = MakeFunctionBlock(*this);
  954. TraceManager::EndTrace();
  955. MakeConditionBlocks(block, true_block, false_block);
  956. TraceManager::DebugTrace(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
  957. FunctionBlockPtr after_block = MakeFunctionBlock(*this);
  958. TraceManager::EndTrace();
  959. FunctionBlockPtr true_end = ParseForIter(true_block, node);
  960. true_end->Jump(after_block, nullptr);
  961. FunctionBlockPtr false_end = ParseForLoop(false_block, node);
  962. false_end->Jump(after_block, nullptr);
  963. block->ConditionalJump(bool_node, true_block, false_block);
  964. after_block->Mature();
  965. return after_block;
  966. }
  967. // A for loop will generate 3 functions :the test, the body, and the continuation
  968. // for x in xs:
  969. // body
  970. // it is compiled to be following statement
  971. // it = iter(xs)
  972. // while hastnext(it)
  973. // x, it = next(it)
  974. // body
  975. FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::object &node) {
  976. MS_LOG(DEBUG) << "Process ast For";
  977. MS_EXCEPTION_IF_NULL(block);
  978. AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
  979. AnfNodePtr op_next = block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
  980. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  981. AnfNodePtr op_hasnext = block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
  982. // generate the iterator apply
  983. CNodePtr iter_apply = GenerateIteratorInFor(block, node, op_iter);
  984. MS_EXCEPTION_IF_NULL(iter_apply);
  985. FunctionBlockPtr header_block =
  986. GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
  987. MS_EXCEPTION_IF_NULL(header_block);
  988. // generate the hasnext apply which is a condition
  989. ParameterPtr iter_param = header_block->func_graph()->add_parameter();
  990. CNodePtr cond_apply = GenerateCondInFor(iter_param, header_block, op_hasnext);
  991. // generate the body of the for statement
  992. FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  993. MS_EXCEPTION_IF_NULL(body_block);
  994. body_block->AddPrevBlock(header_block);
  995. // generate the iterator next apply
  996. // process as following: `app = next(it); target = app[0]; it = app[1];`
  997. CNodePtr app = body_block->func_graph()->NewCNode({op_next, iter_param});
  998. CNodePtr target_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(0)});
  999. py::object target_node = python_adapter::GetPyObjAttr(node, "target");
  1000. CNodePtr iter2_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(1)});
  1001. WriteAssignVars(body_block, target_node, target_app);
  1002. // link the variable name with the target
  1003. auto it_info = std::make_shared<TraceIterator>(target_app->debug_info());
  1004. iter_param->debug_info()->set_trace_info(it_info);
  1005. iter2_app->debug_info()->set_trace_info(it_info);
  1006. iter_apply->debug_info()->set_trace_info(it_info);
  1007. TraceManager::DebugTrace(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
  1008. FunctionBlockPtr after_block = MakeFunctionBlock(*this);
  1009. MS_EXCEPTION_IF_NULL(after_block);
  1010. TraceManager::EndTrace();
  1011. after_block->AddPrevBlock(header_block);
  1012. block->Jump(header_block, iter_apply);
  1013. body_block->Mature();
  1014. header_block->ConditionalJump(cond_apply, body_block, after_block);
  1015. // Parse loop body statements with loop context.
  1016. LoopContext loop_context{&loops_, header_block, iter2_app};
  1017. py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  1018. FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
  1019. if (after_body_block->func_graph()->get_return() == nullptr) {
  1020. after_body_block->Jump(header_block, iter2_app);
  1021. }
  1022. header_block->Mature();
  1023. after_block->Mature();
  1024. auto &end_block = loop_context.EndBlock();
  1025. if (end_block) {
  1026. // end_block exists if we encounter 'break' in loop body.
  1027. after_block->Jump(end_block, nullptr);
  1028. end_block->Mature();
  1029. return end_block;
  1030. }
  1031. // No 'break', no end_block.
  1032. return after_block;
  1033. }
  1034. // A for loop will generate 3 functions :the test, the body, and the continuation
  1035. // for x in xs:
  1036. // body
  1037. // it is compiled to be following statement
  1038. // i = 0
  1039. // while i < len(xs)
  1040. // x = xs[i]
  1041. // i = i + 1
  1042. // body
  1043. FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::object &node) {
  1044. MS_LOG(DEBUG) << "Process ast For by loop variable";
  1045. MS_EXCEPTION_IF_NULL(block);
  1046. AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
  1047. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  1048. // get varibale name of 'x' in statement 'for x in xs'
  1049. py::object target_node = python_adapter::GetPyObjAttr(node, "target");
  1050. // create statement 'len(xs)'
  1051. py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
  1052. AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
  1053. MS_EXCEPTION_IF_NULL(iter_node);
  1054. CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node});
  1055. FunctionBlockPtr header_block =
  1056. GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
  1057. MS_EXCEPTION_IF_NULL(header_block);
  1058. // create loop variable 'i'
  1059. ParameterPtr loop_var = header_block->func_graph()->add_parameter();
  1060. // create loop condition 'i < len(xs)'
  1061. CNodePtr cond_node = header_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), loop_var, len_iter});
  1062. // generate the body of the for statement
  1063. FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  1064. MS_EXCEPTION_IF_NULL(body_block);
  1065. body_block->AddPrevBlock(header_block);
  1066. // create 'x = xs[i]'
  1067. CNodePtr target_var = body_block->func_graph()->NewCNode({op_getitem, iter_node, loop_var});
  1068. WriteAssignVars(body_block, target_node, target_var);
  1069. // create 'i = i + 1'
  1070. CNodePtr loop_var_inc =
  1071. body_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarAdd), loop_var, NewValueNode(1)});
  1072. body_block->WriteVariable(loop_var->name(), loop_var_inc);
  1073. // link the variable name with the target
  1074. auto it_info = std::make_shared<TraceIterator>(loop_var_inc->debug_info());
  1075. loop_var->debug_info()->set_trace_info(it_info);
  1076. len_iter->debug_info()->set_trace_info(it_info);
  1077. TraceManager::DebugTrace(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
  1078. FunctionBlockPtr after_block = MakeFunctionBlock(*this);
  1079. MS_EXCEPTION_IF_NULL(after_block);
  1080. TraceManager::EndTrace();
  1081. after_block->AddPrevBlock(header_block);
  1082. block->Jump(header_block, NewValueNode(0));
  1083. body_block->Mature();
  1084. header_block->ConditionalJump(cond_node, body_block, after_block, false);
  1085. // Parse loop body statements with loop context.
  1086. LoopContext loop_context{&loops_, header_block, loop_var_inc};
  1087. py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  1088. FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
  1089. if (after_body_block->func_graph()->get_return() == nullptr) {
  1090. after_body_block->Jump(header_block, loop_var_inc);
  1091. }
  1092. header_block->Mature();
  1093. after_block->Mature();
  1094. auto &end_block = loop_context.EndBlock();
  1095. if (end_block) {
  1096. // end_block exists if we encounter 'break' in loop body.
  1097. after_block->Jump(end_block, nullptr);
  1098. end_block->Mature();
  1099. return end_block;
  1100. }
  1101. // No 'break', no end_block.
  1102. return after_block;
  1103. }
  1104. AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) {
  1105. MS_LOG(DEBUG) << "Process ast IfExp";
  1106. MS_EXCEPTION_IF_NULL(block);
  1107. py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  1108. AnfNodePtr condition_node = ParseExprNode(block, test_node);
  1109. CNodePtr bool_node = block->ForceToBoolNode(condition_node);
  1110. TraceManager::DebugTrace(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
  1111. FunctionBlockPtr true_block = MakeFunctionBlock(*this);
  1112. TraceManager::EndTrace();
  1113. TraceManager::DebugTrace(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
  1114. FunctionBlockPtr false_block = MakeFunctionBlock(*this);
  1115. TraceManager::EndTrace();
  1116. MakeConditionBlocks(block, true_block, false_block);
  1117. // process the if-true branch
  1118. py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
  1119. true_block->func_graph()->debug_info()->set_location(GetLocation(bodyNode));
  1120. AnfNodePtr true_node = ParseExprNode(true_block, bodyNode);
  1121. // process the orelse branch
  1122. py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
  1123. false_block->func_graph()->debug_info()->set_location(GetLocation(orelseNode));
  1124. AnfNodePtr false_node = ParseExprNode(false_block, orelseNode);
  1125. true_block->func_graph()->set_output(true_node);
  1126. false_block->func_graph()->set_output(false_node);
  1127. // Use the Primitive replace the operation resolve node (switch)
  1128. // because the switch will eventually be converted to Primitive node
  1129. CNodePtr switch_app =
  1130. block->func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), bool_node, NewValueNode(true_block->func_graph()),
  1131. NewValueNode(false_block->func_graph())});
  1132. std::vector<AnfNodePtr> call_graph_nodes{switch_app};
  1133. CNodePtr switch_app_call = block->func_graph()->NewCNode(call_graph_nodes);
  1134. return switch_app_call;
  1135. }
  1136. void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) {
  1137. MS_EXCEPTION_IF_NULL(block);
  1138. MS_EXCEPTION_IF_NULL(assigned_node);
  1139. py::str name = python_adapter::GetPyObjAttr(targ, "id");
  1140. std::string name_id = name;
  1141. assigned_node->debug_info()->set_name(name_id);
  1142. // set the debug name of the constant graph
  1143. if (IsValueNode<FuncGraph>(assigned_node)) {
  1144. // the value should be graph
  1145. auto fg = GetValueNode<FuncGraphPtr>(assigned_node);
  1146. if (fg->debug_info()->name().empty()) {
  1147. fg->debug_info()->set_name(name_id);
  1148. }
  1149. }
  1150. block->WriteVariable(name_id, assigned_node);
  1151. }
  1152. void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) {
  1153. MS_EXCEPTION_IF_NULL(block);
  1154. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  1155. py::list items = python_adapter::GetPyObjAttr(targ, "elts");
  1156. for (size_t i = 0; i < items.size(); i++) {
  1157. // Use the Primitive replace the operation resolve node (getitem)
  1158. // because the getitem will eventually be converted to Primitive node
  1159. CNodePtr item_apply = block->func_graph()->NewCNode({op_getitem, assigned_node, NewValueNode(static_cast<int>(i))});
  1160. py::object elt = items[i];
  1161. WriteAssignVars(block, elt, item_apply);
  1162. }
  1163. }
  1164. void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ,
  1165. const AnfNodePtr &assigned_node) {
  1166. // Now only support the self.xx = xxxxx, can't support x.y = xxxx
  1167. AnfNodePtr target_node = ParseExprNode(block, targ);
  1168. MS_EXCEPTION_IF_NULL(target_node);
  1169. std::string attr_name = targ.attr("attr").cast<std::string>();
  1170. std::string var_name = "self.";
  1171. (void)var_name.append(attr_name);
  1172. MS_LOG(DEBUG) << "assign " << var_name;
  1173. // Get targ location info for error printing
  1174. py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, targ);
  1175. if (location.size() < 2) {
  1176. MS_LOG(EXCEPTION) << "List size should not be less than 2.";
  1177. }
  1178. auto filename = location[0].cast<std::string>();
  1179. auto line_no = location[1].cast<int>();
  1180. // Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
  1181. if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
  1182. MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but not defined, at " << filename << ":"
  1183. << line_no;
  1184. }
  1185. auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
  1186. auto obj_type = obj.attr("__class__").attr("__name__");
  1187. if (!py::hasattr(obj, "__parameter__")) {
  1188. MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '"
  1189. << py::str(obj).cast<std::string>() << "' with type '"
  1190. << py::str(obj_type).cast<std::string>() << "' at " << filename << ":" << line_no;
  1191. }
  1192. MS_EXCEPTION_IF_NULL(block);
  1193. block->WriteVariable(var_name, assigned_node);
  1194. MS_LOG(DEBUG) << "SetState write " << var_name << " : " << target_node->ToString();
  1195. block->SetStateAssgin(target_node, var_name);
  1196. }
  1197. void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ,
  1198. const AnfNodePtr &assigned_node) {
  1199. MS_EXCEPTION_IF_NULL(block);
  1200. AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM);
  1201. py::object value_obj = python_adapter::GetPyObjAttr(targ, "value");
  1202. py::object slice_obj = python_adapter::GetPyObjAttr(targ, "slice");
  1203. AnfNodePtr value_node = ParseExprNode(block, value_obj);
  1204. AnfNodePtr slice_node = ParseExprNode(block, slice_obj);
  1205. CNodePtr setitem_app = block->func_graph()->NewCNode({op_setitem, value_node, slice_node, assigned_node});
  1206. // getitem apply should return the sequence data structure itself
  1207. std::string var_name = "";
  1208. if (ast_->IsClassMember(value_obj)) {
  1209. std::string attr_name = value_obj.attr("attr").cast<std::string>();
  1210. var_name = "self." + attr_name;
  1211. if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
  1212. MS_EXCEPTION(TypeError) << "'" << var_name << "' was not defined in the class '__init__' function.";
  1213. }
  1214. auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
  1215. auto obj_type = obj.attr("__class__").attr("__name__");
  1216. if (!py::hasattr(obj, "__parameter__")) {
  1217. MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '"
  1218. << py::str(obj).cast<std::string>() << "' with type '"
  1219. << py::str(obj_type).cast<std::string>() << "'.";
  1220. }
  1221. } else {
  1222. var_name = value_obj.attr("id").cast<std::string>();
  1223. }
  1224. block->WriteVariable(var_name, setitem_app);
  1225. }
  1226. void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node) {
  1227. MS_EXCEPTION_IF_NULL(value_node);
  1228. MS_LOG(DEBUG) << "Process WriteAssignVars";
  1229. auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, targ)));
  1230. if (ast_type == AST_SUB_TYPE_NAME) {
  1231. HandleAssignName(block, targ, value_node);
  1232. } else if (ast_type == AST_SUB_TYPE_TUPLE) {
  1233. HandleAssignTuple(block, targ, value_node);
  1234. } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
  1235. HandleAssignSubscript(block, targ, value_node);
  1236. } else if (ast_->IsClassMember(targ)) {
  1237. HandleAssignClassMember(block, targ, value_node);
  1238. } else {
  1239. MS_LOG(EXCEPTION) << "Not supported assign type: " << ast_type
  1240. << " NodeInfo: " << trace::GetDebugInfo(value_node->debug_info());
  1241. }
  1242. }
  1243. // process a assign statement, such as a =b, a,b = tup
  1244. FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) {
  1245. MS_LOG(DEBUG) << "Process ast assgin";
  1246. py::object value_object = python_adapter::GetPyObjAttr(node, "value");
  1247. AnfNodePtr value_node = ParseExprNode(block, value_object);
  1248. py::object targets_object = python_adapter::GetPyObjAttr(node, "targets");
  1249. py::int_ pcount = python_adapter::CallPyObjMethod(targets_object, "__len__");
  1250. size_t count = IntToSize(pcount);
  1251. MS_LOG(DEBUG) << "The nodes count is " << count;
  1252. for (size_t i = 0; i < count; i++) {
  1253. auto target_node = py::cast<py::list>(targets_object)[i];
  1254. WriteAssignVars(block, target_node, value_node);
  1255. }
  1256. return block;
  1257. }
  1258. FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::object &node) {
  1259. if (loops_.empty()) {
  1260. // Report error if loop context not set for the 'break' statement.
  1261. py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  1262. if (location.size() < 2) {
  1263. MS_LOG(EXCEPTION) << "List size should not be less than 2.";
  1264. }
  1265. auto filename = location[0].cast<std::string>();
  1266. auto line_no = location[1].cast<int>();
  1267. MS_LOG(EXCEPTION) << "Unexpected 'break' at " << filename << ":" << line_no;
  1268. }
  1269. // Get current loop.
  1270. Loop &loop = loops_.top();
  1271. if (loop.end == nullptr) {
  1272. // Create end_block if it is not existed.
  1273. TraceManager::DebugTrace(std::make_shared<TraceLoopEnd>(block->func_graph()->debug_info()));
  1274. loop.end = MakeFunctionBlock(*this);
  1275. TraceManager::EndTrace();
  1276. }
  1277. // Jump to the end_block.
  1278. block->Jump(loop.end, nullptr);
  1279. return block;
  1280. }
  1281. FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::object &node) {
  1282. if (loops_.empty()) {
  1283. // Report error if loop context not set for the 'continue' statement.
  1284. py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  1285. if (location.size() < 2) {
  1286. MS_LOG(EXCEPTION) << "List size should not be less than 2.";
  1287. }
  1288. auto filename = location[0].cast<std::string>();
  1289. auto line_no = location[1].cast<int>();
  1290. MS_LOG(EXCEPTION) << "Unexpected 'continue' at " << filename << ":" << line_no;
  1291. }
  1292. // Jump to the header of the loop with iterator called.
  1293. Loop &loop = loops_.top();
  1294. block->Jump(loop.header, loop.iterator);
  1295. return block;
  1296. }
  1297. FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::object &node) {
  1298. // We just bypass 'pass' statement.
  1299. return block;
  1300. }
  1301. AnfNodePtr FindPhis(const std::unordered_map<ParameterPtr, AnfNodePtr> &removable_phis, const AnfNodePtr &node) {
  1302. const auto &inp = node->cast<ParameterPtr>();
  1303. const auto &iter = removable_phis.find(inp);
  1304. if (iter == removable_phis.end()) {
  1305. return node;
  1306. }
  1307. return FindPhis(removable_phis, iter->second);
  1308. }
  1309. void Parser::RemoveUnnecessaryPhis() {
  1310. // merge all removable phis to one map;
  1311. std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis;
  1312. std::vector<ParameterPtr> phis;
  1313. for (FunctionBlockPtr &block : func_block_list_) {
  1314. MS_EXCEPTION_IF_NULL(block);
  1315. removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end());
  1316. std::transform(block->removable_phis().begin(), block->removable_phis().end(), std::back_inserter(phis),
  1317. [](std::pair<ParameterPtr, AnfNodePtr> pair) { return pair.first; });
  1318. }
  1319. if (removable_phis.size() == 0) {
  1320. return;
  1321. }
  1322. auto fg_name = func_graph_->ToString();
  1323. auto mng = Manage(func_graph_, false);
  1324. // replace the nodes
  1325. // remove from inside to outside
  1326. for (int idx = SizeToInt(phis.size() - 1); idx >= 0; idx--) {
  1327. auto phi = phis[IntToSize(idx)];
  1328. auto new_node = FindPhis(removable_phis, phi);
  1329. MS_LOG(DEBUG) << "phi " << phi->DebugString() << " to " << new_node->DebugString();
  1330. mng->Replace(phi, new_node);
  1331. }
  1332. // remove the parameter
  1333. for (FunctionBlockPtr &block : func_block_list_) {
  1334. MS_EXCEPTION_IF_NULL(block);
  1335. auto &local_removable_phis = block->removable_phis();
  1336. if (local_removable_phis.size() == 0) {
  1337. continue;
  1338. }
  1339. auto func_graph = block->func_graph();
  1340. auto &parameters = func_graph->parameters();
  1341. std::vector<AnfNodePtr> new_parameters(parameters.size());
  1342. auto it = std::copy_if(
  1343. parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](AnfNodePtr param) {
  1344. return local_removable_phis.find(param->cast<ParameterPtr>()) == local_removable_phis.end();
  1345. });
  1346. // shrink container to new size
  1347. new_parameters.resize(std::distance(new_parameters.begin(), it));
  1348. func_graph->set_parameters(new_parameters);
  1349. }
  1350. for (auto fg : mng->func_graphs()) {
  1351. fg->ClearAllManagerInfo();
  1352. }
  1353. }
// ParseAst class code
// Initialize the parse state for obj_ and obtain its Python AST.
// The Python callable to parse depends on what kind of object obj_ is:
//   - function: parse obj_ itself (target PARSE_TARGET_FUNCTION);
//   - method: parse the method itself. obj_ is replaced by the object read from the
//     method's PYTHON_GET_METHOD_SELF_CLASS attribute (presumably its `__self__`;
//     TODO confirm). Target is PARSE_TARGET_OBJECT_INSTANCE;
//   - class instance: call `python_mod_get_parse_method` from the parse module with
//     (obj_, parse_method) to get the callable. The result must be a method.
//     Target is PARSE_TARGET_OBJECT_INSTANCE.
// On success: creates parser_, parses into ast_tree_, fills function_module_,
// function_name_, function_filename_ and function_line_offset_, and returns true.
// Returns false (after logging) for unsupported object kinds or when the callable
// cannot be resolved. Members assigned before the failure (e.g. module_,
// target_type_) keep the values they were given.
// @param python_mod_get_parse_method Name of the parse-module function used to pick
//        the method to parse when obj_ is a class instance.
bool ParseAst::InitParseAstInfo(const std::string &python_mod_get_parse_method) {
  // Start from an unknown target type; the branches below refine it.
  target_type_ = PARSE_TARGET_UNKNOW;
  // Load the Python parse module and read the object's parse-method attribute.
  module_ = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  py::object parse_method = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_PARSE_METHOD);
  // Classify obj_ to decide which callable to parse.
  auto type = data_converter::GetObjType(obj_);
  if (type == RESOLVE_TYPE_FUNCTION) {
    target_type_ = PARSE_TARGET_FUNCTION;
    function_ = obj_;
  } else if (type == RESOLVE_TYPE_METHOD) {
    // A method: fetch the object it belongs to ("self").
    // If that lookup fails, target_type_ is left as PARSE_TARGET_METHOD.
    target_type_ = PARSE_TARGET_METHOD;
    py::object method_object = python_adapter::GetPyObjAttr(obj_, PYTHON_GET_METHOD_SELF_CLASS);
    if (py::isinstance<py::none>(method_object)) {
      MS_LOG(ERROR) << "Get method's self object instance failed.";
      return false;
    }
    // Parse the method itself, but use its owner object as obj_ from here on.
    target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
    function_ = obj_;
    obj_ = method_object;
  } else if (type == RESOLVE_TYPE_CLASS_INSTANCE) {
    // obj_ is a class instance: ask the parse module which method to parse.
    function_ = python_adapter::CallPyModFn(module_, python_mod_get_parse_method, obj_, parse_method);
    if (py::isinstance<py::none>(function_)) {
      MS_LOG(ERROR) << "Get obj method function failed.";
      return false;
    }
    target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
    // The resolved callable must be a method.
    auto obj_type = data_converter::GetObjType(function_);
    if (obj_type != RESOLVE_TYPE_METHOD) {
      MS_LOG(WARNING) << "Parse method function is invalid.";
      return false;
    }
  } else {
    MS_LOG(WARNING) << "Parse obj is invalid, only can parse function and obj, type = " << type;
    return false;
  }
  // Build the Python-side parser for function_ and parse it into an AST.
  parser_ = python_adapter::CallPyModFn(module_, PYTHON_MOD_PARSE_OBJECT_FUNCTION, function_, parse_method);
  ast_tree_ = python_adapter::CallPyObjMethod(parser_, "parse");
  // Cache the function's module, name, source file name and line offset from the parser.
  function_module_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "function_module"));
  function_name_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "function_name"));
  function_filename_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "filename"));
  function_line_offset_ = py::cast<int>(python_adapter::GetPyObjAttr(parser_, "line_offset"));
  return true;
}
  1405. // Get ast tree node : is the tree bode list[0]
  1406. py::object ParseAst::GetAstNode() {
  1407. py::list tree_body = python_adapter::GetPyObjAttr(ast_tree_, "body");
  1408. py::object ast_node = tree_body[0];
  1409. return ast_node;
  1410. }
  1411. py::list ParseAst::GetArgs(const py::object &func_node) {
  1412. py::list ret = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_ARGS, func_node);
  1413. return ret;
  1414. }
  1415. py::list ParseAst::GetArgsDefaultValues(const py::object &func_node) {
  1416. py::list ret = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES, func_node);
  1417. return ret;
  1418. }
  1419. AstNodeTypePtr ParseAst::GetNodeType(const py::object &node) {
  1420. py::list list_value = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_NODE_TYPE, node);
  1421. if (list_value.size() < 2) {
  1422. MS_LOG(ERROR) << "The node of python method must has 2 values.";
  1423. return nullptr;
  1424. }
  1425. auto node_name = py::cast<std::string>(list_value[0]);
  1426. auto type = AstMainType(py::cast<int32_t>(list_value[1]));
  1427. return std::make_shared<AstNodeType>(node, node_name, type);
  1428. }
  1429. AstSubType ParseAst::GetOpType(const py::object &node) {
  1430. auto op_type = AstSubType(python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_AST_TYPE, node).cast<int32_t>());
  1431. return op_type;
  1432. }
  1433. bool ParseAst::IsClassMember(const py::object &node) {
  1434. py::object ret = CallParseModFunction(PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER, node);
  1435. if (!py::isinstance<py::bool_>(ret)) {
  1436. MS_LOG(ERROR) << "The result of mod function parse, should be bool type.";
  1437. return false;
  1438. }
  1439. return ret.cast<bool>();
  1440. }
  1441. bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
  1442. if (func_graph == nullptr) {
  1443. MS_LOG(ERROR) << "FuncGraph is null";
  1444. return false;
  1445. }
  1446. if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) {
  1447. MS_LOG(DEBUG) << "No flags";
  1448. return true;
  1449. }
  1450. py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG);
  1451. for (auto &item : flags) {
  1452. if (!py::isinstance<py::str>(item.first)) {
  1453. MS_LOG(ERROR) << "Type error in flags dict convert";
  1454. return false;
  1455. }
  1456. auto name = py::cast<std::string>(item.first);
  1457. if (py::isinstance<py::bool_>(item.second)) {
  1458. auto value = py::cast<bool>(item.second);
  1459. MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
  1460. func_graph->set_flag(name, value);
  1461. } else if (py::isinstance<py::str>(item.second)) {
  1462. auto value = py::cast<std::string>(item.second);
  1463. MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
  1464. func_graph->set_attr(name, MakeValue(value));
  1465. } else {
  1466. MS_LOG(ERROR) << "Type error in flags/attrs dict convert";
  1467. return false;
  1468. }
  1469. }
  1470. return true;
  1471. }
  1472. } // namespace parse
  1473. } // namespace mindspore