You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parse.cc 79 kB

5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/jit/parse/parse.h"
  19. #include <utility>
  20. #include <string>
  21. #include <memory>
  22. #include <unordered_map>
  23. #include <sstream>
  24. #include <algorithm>
  25. #include "pipeline/jit/parse/resolve.h"
  26. #include "frontend/operator/ops.h"
  27. #include "pipeline/jit/parse/data_converter.h"
  28. #include "frontend/operator/composite/composite.h"
  29. #include "utils/ms_context.h"
  30. #include "debug/trace.h"
  31. namespace mindspore {
  32. namespace parse {
  33. FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mod_get_parse_method) {
  34. (void)python_adapter::set_python_scoped();
  35. if (!obj || py::isinstance<py::none>(obj)) {
  36. MS_LOG(ERROR) << "Parse the python code failed, obj is nullptr or none";
  37. return nullptr;
  38. }
  39. auto ast = std::make_shared<ParseAst>(obj);
  40. bool success = ast->InitParseAstInfo(python_mod_get_parse_method);
  41. if (!success) {
  42. MS_LOG(ERROR) << "Parse code to ast tree failed.";
  43. return nullptr;
  44. }
  45. auto parser = std::make_shared<Parser>(ast);
  46. FuncGraphPtr func_graph = parser->ParseFuncGraph();
  47. if (func_graph == nullptr) {
  48. MS_LOG(ERROR) << "Parse python code failed, errcode = " << parser->errcode();
  49. return nullptr;
  50. }
  51. return func_graph;
  52. }
  53. TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) {
  54. if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
  55. return kFloat32;
  56. } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
  57. return kFloat16;
  58. } else {
  59. return nullptr;
  60. }
  61. }
  62. // if any mixed precision flag add a cast node after the parameter node.
  63. AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
  64. TypePtr dst_type;
  65. if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
  66. dst_type = kFloat32;
  67. } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
  68. dst_type = kFloat16;
  69. } else {
  70. return param;
  71. }
  72. auto cast_helper = prim::kPrimMixedPrecisionCast;
  73. auto cast = func_graph->NewCNodeAfter(param, {NewValueNode(cast_helper), NewValueNode(dst_type), param});
  74. return cast;
  75. }
  76. FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();
  77. Parser::Parser(const std::shared_ptr<ParseAst> &ast) : ast_(ast) {
  78. errcode_ = PARSE_SUCCESS;
  79. BuildMethodMap();
  80. }
  81. void Parser::BuildMethodMap() {
  82. stmt_method_map_["Return"] = &Parser::ParseReturn;
  83. stmt_method_map_["Expr"] = &Parser::ParseExpr;
  84. stmt_method_map_["If"] = &Parser::ParseIf;
  85. stmt_method_map_["Assign"] = &Parser::ParseAssign;
  86. stmt_method_map_["While"] = &Parser::ParseWhile;
  87. stmt_method_map_["For"] = &Parser::ParseFor;
  88. stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef;
  89. stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign;
  90. stmt_method_map_["Global"] = &Parser::ParseGlobal;
  91. stmt_method_map_["Break"] = &Parser::ParseBreak;
  92. stmt_method_map_["Continue"] = &Parser::ParseContinue;
  93. stmt_method_map_["Pass"] = &Parser::ParsePass;
  94. expr_method_map_["NoneType"] = &Parser::ParseNone;
  95. expr_method_map_["BinOp"] = &Parser::ParseBinOp;
  96. expr_method_map_["Name"] = &Parser::ParseName;
  97. expr_method_map_["Num"] = &Parser::ParseNum;
  98. expr_method_map_["Str"] = &Parser::ParseStr;
  99. expr_method_map_["Constant"] = &Parser::ParseConstant;
  100. expr_method_map_["NameConstant"] = &Parser::ParseNameConstant;
  101. expr_method_map_["Call"] = &Parser::ParseCall;
  102. expr_method_map_["IfExp"] = &Parser::ParseIfExp;
  103. expr_method_map_["Attribute"] = &Parser::ParseAttribute;
  104. expr_method_map_["Compare"] = &Parser::ParseCompare;
  105. expr_method_map_["BoolOp"] = &Parser::ParseBoolOp;
  106. expr_method_map_["Lambda"] = &Parser::ParseLambda;
  107. expr_method_map_["Tuple"] = &Parser::ParseTuple;
  108. expr_method_map_["List"] = &Parser::ParseList;
  109. expr_method_map_["Subscript"] = &Parser::ParseSubscript;
  110. expr_method_map_["Slice"] = &Parser::ParseSlice;
  111. expr_method_map_["ExtSlice"] = &Parser::ParseExtSlice;
  112. expr_method_map_["Index"] = &Parser::ParseIndex;
  113. expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp;
  114. expr_method_map_["Dict"] = &Parser::ParseDict;
  115. expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis;
  116. }
  117. void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); }
  118. void Parser::InitParserEnvironment(const py::object &obj) {
  119. Parser::top_func_graph_ = FuncGraphWeakPtr();
  120. ScopeManager::GetInstance().ClearScope();
  121. (void)python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GENERATE_SCOPE, obj);
  122. }
  123. void Parser::CleanParserResource() {
  124. Parser::top_func_graph_ = FuncGraphWeakPtr();
  125. ScopeManager::GetInstance().ClearScope();
  126. }
  127. AnfNodePtr AppendParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) {
  128. MS_EXCEPTION_IF_NULL(func_graph);
  129. auto value = py::cast<tensor::MetaTensorPtr>(obj);
  130. // parameter object should not be none
  131. if (value == nullptr || !value->is_parameter()) {
  132. MS_LOG(EXCEPTION) << "Parameter error: because obj is not Parameter object.";
  133. }
  134. // get the parameter name from parameter object
  135. auto param_name = value->param_info()->name();
  136. auto top_graph = func_graph;
  137. // if the parameter node has been created , return it
  138. AnfNodePtr para_node = nullptr;
  139. for (auto param : top_graph->parameters()) {
  140. auto param_node = dyn_cast<Parameter>(param);
  141. if (param_node != nullptr && param_node->name() == param_name) {
  142. para_node = param;
  143. break;
  144. }
  145. }
  146. if (para_node == nullptr) {
  147. auto node = top_graph->AddWeightParameter(param_name);
  148. node->set_default_param(value);
  149. // set_abstract for parameter
  150. auto abs = value->ToAbstract();
  151. // boarden value
  152. abs = abs->Broaden();
  153. node->set_abstract(abs);
  154. para_node = node;
  155. }
  156. return para_node;
  157. }
  158. void UpdataParam(const FuncGraphPtr &top_graph, const py::object &cell) {
  159. auto params = py::list(cell.attr("get_parameters")()).cast<std::vector<py::object>>();
  160. for (size_t i = 0; i < params.size(); i++) {
  161. (void)AppendParameterObj(top_graph, params[i]);
  162. }
  163. }
  164. void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &ast) {
  165. // check whether the functions referred by this function and itself are missing 'return' statement
  166. auto mng = Manage(fn, false);
  167. for (auto func_graph : mng->func_graphs()) {
  168. if (func_graph->get_return() != nullptr) {
  169. continue;
  170. }
  171. py::object node = ast->GetAstNode();
  172. py::list ret = ast->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  173. py::str desc =
  174. python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]);
  175. MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
  176. }
  177. // clear manager info after checking missing return
  178. for (auto fg : mng->func_graphs()) {
  179. fg->ClearAllManagerInfo();
  180. }
  181. }
  182. FuncGraphPtr Parser::ParseFuncGraph() {
  183. // get ast FunctionDef node
  184. py::object node = ast_->GetAstNode();
  185. FunctionBlockPtr pFnBlock = ParseFunction(node);
  186. if (errcode() != PARSE_SUCCESS) {
  187. MS_LOG(ERROR) << "Parse function error, code is " << errcode();
  188. return nullptr;
  189. }
  190. // Add unused variables as isolate nodes.
  191. for (auto &block : func_block_list_) {
  192. block->FindIsolateVariables();
  193. }
  194. RemoveUnnecessaryPhis();
  195. MS_EXCEPTION_IF_NULL(pFnBlock);
  196. CheckFuncReturn(pFnBlock->func_graph(), ast_);
  197. return pFnBlock->func_graph();
  198. }
  199. void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
  200. py::object func_args = python_adapter::GetPyObjAttr(fn_node, "args");
  201. py::object var_arg_node = python_adapter::GetPyObjAttr(func_args, "vararg");
  202. block->func_graph()->set_has_vararg(!py::isinstance<py::none>(var_arg_node));
  203. py::object kw_arg_node = python_adapter::GetPyObjAttr(func_args, "kwarg");
  204. block->func_graph()->set_has_kwarg(!py::isinstance<py::none>(kw_arg_node));
  205. py::list kwonly_args = python_adapter::GetPyObjAttr(func_args, "kwonlyargs");
  206. block->func_graph()->set_kwonlyargs_count(SizeToLong(kwonly_args.size()));
  207. MS_EXCEPTION_IF_NULL(ast_);
  208. py::list args = ast_->GetArgs(fn_node);
  209. for (std::size_t i = 0; i < args.size(); i++) {
  210. std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
  211. if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  212. if (arg_name == "self") {
  213. continue;
  214. }
  215. }
  216. TraceGuard guard(GetLocation(args[i]));
  217. auto para_node = std::make_shared<Parameter>(block->func_graph());
  218. MS_EXCEPTION_IF_NULL(para_node);
  219. para_node->set_name(arg_name);
  220. para_node->debug_info()->set_name(arg_name);
  221. block->func_graph()->add_parameter(para_node);
  222. AnfNodePtr para_after_cast = GetMixedPrecisionCastHelp(block->func_graph(), para_node);
  223. block->WriteVariable(arg_name, para_after_cast);
  224. MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name;
  225. }
  226. }
  227. void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
  228. py::list defaults = ast_->GetArgsDefaultValues(fn_node);
  229. py::list args = ast_->GetArgs(fn_node);
  230. std::vector<std::string> namelist_for_default_value;
  231. std::vector<AnfNodePtr> default_values;
  232. for (std::size_t i = 0; i < args.size(); i++) {
  233. std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
  234. if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  235. if (arg_name == "self") {
  236. continue;
  237. }
  238. }
  239. namelist_for_default_value.push_back(arg_name);
  240. if (py::isinstance<py::none>(defaults[i])) {
  241. default_values.push_back(NewValueNode(kNull));
  242. } else {
  243. default_values.push_back(ParseExprNode(block, defaults[i]));
  244. }
  245. }
  246. block->func_graph()->SetDefaultValues(namelist_for_default_value, default_values);
  247. }
  248. ScopePtr Parser::GetScopeForParseFunction() {
  249. ScopePtr scope = ScopeManager::GetInstance().GetCurrentScope();
  250. if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  251. py::object scope_str = python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GET_SCOPE_NAME, ast_->obj());
  252. if (!py::isinstance<py::none>(scope_str)) {
  253. auto scope_name = py::cast<std::string>(scope_str);
  254. scope = std::make_shared<Scope>(scope_name);
  255. }
  256. }
  257. return scope;
  258. }
  259. FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlockPtr &block) {
  260. ScopePtr scope = GetScopeForParseFunction();
  261. // the node created in the parsefunction context, will inherit the scope created using scope_guard
  262. ScopeGuard scope_guard(scope);
  263. TraceGuard trace_guard(data_converter::GetObjKey(ast()->obj())[0], GetLocation(node));
  264. FunctionBlockPtr pFunBlock = MakeFunctionBlock(*this);
  265. if (block != nullptr) {
  266. pFunBlock->AddPrevBlock(block);
  267. } else {
  268. func_graph_ = pFunBlock->func_graph();
  269. }
  270. pFunBlock->Mature();
  271. auto current_fg = pFunBlock->func_graph();
  272. auto function_name = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "name"));
  273. MS_LOG(DEBUG) << "The function name is " << function_name;
  274. current_fg->debug_info()->set_name(function_name);
  275. MS_EXCEPTION_IF_NULL(ast_);
  276. py::list deco_list = node.attr("decorator_list");
  277. if (deco_list.size() > 0) {
  278. current_fg->debug_info()->set_deco_location(GetLocation(deco_list));
  279. }
  280. bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg);
  281. if (!ast_->obj().is(ast_->function())) {
  282. set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg);
  283. }
  284. if (!set_flag) {
  285. MS_LOG(ERROR) << "Set flags failed";
  286. return nullptr;
  287. }
  288. GenerateArgsNodeForFunction(pFunBlock, node);
  289. // when parsing the top graph of construct, save the top graph
  290. if (GetTopFuncGraph() == nullptr) {
  291. UpdateTopFuncGraph(pFunBlock->func_graph());
  292. }
  293. // save the function node to block
  294. pFunBlock->WriteVariable(function_name, NewValueNode(current_fg));
  295. py::object funcObj = python_adapter::GetPyObjAttr(node, "body");
  296. (void)ParseStatements(pFunBlock, funcObj);
  297. if (current_fg->get_return() == nullptr) {
  298. py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  299. py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, node, ret[0], ret[1]);
  300. MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
  301. }
  302. GenerateArgsDefaultValueForFunction(pFunBlock, node);
  303. return pFunBlock;
  304. }
  305. FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr fn_block, const py::object &nodes) {
  306. auto node_list = py::cast<py::list>(nodes);
  307. size_t count = py::len(node_list);
  308. MS_LOG(DEBUG) << "The nodes count is " << count;
  309. for (size_t i = 0; i < count; ++i) {
  310. auto node = node_list[i];
  311. fn_block = ParseStatement(fn_block, node);
  312. // insert appropriate depended items for the function block if it has a return node
  313. if (fn_block->func_graph()->get_return() != nullptr) {
  314. // Skip statements after 'return' (or 'break', 'continue').
  315. break;
  316. }
  317. }
  318. return fn_block;
  319. }
  320. FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py::object &node) {
  321. TraceGuard trace_guard(GetLocation(node));
  322. auto node_type = ast_->GetNodeType(node);
  323. // check the node type
  324. AstMainType nodeType = node_type->main_type();
  325. if (nodeType != AST_MAIN_TYPE_STMT) {
  326. MS_LOG(INFO) << "Node type is error : " << nodeType;
  327. return block;
  328. }
  329. // call the process function
  330. std::string node_name = node_type->node_name();
  331. MS_LOG(DEBUG) << "Ast node is " << node_name;
  332. if (stmt_method_map_.count(node_name)) {
  333. auto stmt_block = (this->*stmt_method_map_[node_name])(block, node);
  334. TraceManager::ClearParseOrResolveDebugInfo();
  335. return stmt_block;
  336. } else {
  337. errcode_ = PARSE_NODE_METHOD_UNSUPPORTED;
  338. MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "'.";
  339. }
  340. }
  341. AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object &node) {
  342. MS_LOG(DEBUG) << "Process ast expr";
  343. TraceGuard trace_guard(GetLocation(node));
  344. auto node_type = ast_->GetNodeType(node);
  345. // check the node type
  346. AstMainType node_main_type = node_type->main_type();
  347. if (node_main_type != AST_MAIN_TYPE_EXPR) {
  348. MS_LOG(ERROR) << "Node type is error : " << node_main_type;
  349. errcode_ = PARSE_NODE_TYPE_NO_MATCH;
  350. return nullptr;
  351. }
  352. // call the process function
  353. std::string node_name = node_type->node_name();
  354. MS_LOG(DEBUG) << "Ast node is " << node_name;
  355. if (expr_method_map_.count(node_name)) {
  356. auto expr_node = (this->*expr_method_map_[node_name])(block, node);
  357. TraceManager::ClearParseOrResolveDebugInfo();
  358. return expr_node;
  359. } else {
  360. errcode_ = PARSE_NODE_METHOD_UNSUPPORTED;
  361. MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "'.";
  362. }
  363. }
  364. // process the expr statement and expand it
  365. FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::object &node) {
  366. MS_LOG(DEBUG) << "Process ast Expr";
  367. // Expr only have value , no target
  368. py::tuple expand_info = ast_->CallParserObjMethod(PYTHON_PARSE_EXPAND_EXPR_STATEMENT, node);
  369. // refer python function expand_expr_statement, expand_info is one of the following:
  370. // True, expr.value, x
  371. // True, expr.value
  372. // False, None, None
  373. // check the expand info result
  374. auto is_expand = py::cast<bool>(expand_info[0]);
  375. if (is_expand) {
  376. // process the expr statement
  377. py::object value_object = expand_info[1];
  378. AnfNodePtr value_node = ParseExprNode(block, value_object);
  379. if (py::len(expand_info) == 2) {
  380. // expression that not assigned to any variable,
  381. // this is usually a call with side effects,
  382. // e.g.: print(x)
  383. // we save it as an isolate node.
  384. value_node->func_graph()->AddIsolateNode(value_node);
  385. } else {
  386. // expand the assign statement,
  387. // e.g.: x.append(y) -> x = x.append(y)
  388. py::object target_node = expand_info[2];
  389. WriteAssignVars(block, target_node, value_node);
  390. }
  391. }
  392. return block;
  393. }
  394. LocationPtr Parser::GetLocation(const py::object &node) const {
  395. MS_EXCEPTION_IF_NULL(ast_);
  396. py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  397. if (ret.size() < 5) {
  398. MS_LOG(EXCEPTION) << "List size should not be less than 5.";
  399. }
  400. // refer to Location::Location() for each member of ret: line, column, line_end, column_end.
  401. auto location = std::make_shared<Location>(ret[0].cast<std::string>(), ret[1].cast<int64_t>(), ret[2].cast<int64_t>(),
  402. ret[3].cast<int64_t>(), ret[4].cast<int64_t>());
  403. return location;
  404. }
  405. void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
  406. const FunctionBlockPtr &false_block) {
  407. true_block->AddPrevBlock(pre_block);
  408. true_block->Mature();
  409. false_block->AddPrevBlock(pre_block);
  410. false_block->Mature();
  411. }
  412. FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::object &node) {
  413. MS_LOG(DEBUG) << "Process ast return";
  414. MS_EXCEPTION_IF_NULL(block);
  415. // create return valuenode
  416. AnfNodePtr pReturnValueNode = NewValueNode(prim::kPrimReturn);
  417. // parse the return Statements value
  418. py::object value = python_adapter::GetPyObjAttr(node, "value");
  419. AnfNodePtr pReturnStatementNode = ParseExprNode(block, value);
  420. // Create the cnode
  421. CNodePtr pReturnCNode = block->func_graph()->NewCNodeInOrder({pReturnValueNode, pReturnStatementNode});
  422. block->func_graph()->set_return(pReturnCNode);
  423. return block;
  424. }
  425. // Process binary operators,eg: `a + b`, `a | b`, etc.
  426. AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &node) {
  427. MS_LOG(DEBUG) << "Process ast BinOP";
  428. py::object left = python_adapter::GetPyObjAttr(node, "left");
  429. py::object right = python_adapter::GetPyObjAttr(node, "right");
  430. py::object op = python_adapter::GetPyObjAttr(node, "op");
  431. // create left and right ANF node
  432. AnfNodePtr left_node = ParseExprNode(block, left);
  433. if (left_node == nullptr) {
  434. MS_LOG(WARNING) << "DoBinOp process left node failed: " << errcode();
  435. return nullptr;
  436. }
  437. AnfNodePtr right_node = ParseExprNode(block, right);
  438. if (right_node == nullptr) {
  439. MS_LOG(WARNING) << "DoBinOp process right node failed:" << errcode();
  440. return nullptr;
  441. }
  442. // resolve the op
  443. AnfNodePtr op_node = block->MakeResolveAstOp(op);
  444. // create apply node
  445. return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
  446. }
  447. AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &node) {
  448. MS_LOG(DEBUG) << "Process ast Name";
  449. auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id"));
  450. MS_LOG(DEBUG) << "The Name id is " << name_id;
  451. if (block->IsGlobalVar(name_id)) {
  452. return block->MakeResolveSymbol(name_id);
  453. }
  454. return block->ReadVariable(name_id);
  455. }
  456. AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) {
  457. MS_LOG(DEBUG) << "Process ast NoneType";
  458. return NewValueNode(kNone);
  459. }
  460. AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) {
  461. MS_LOG(DEBUG) << "Process ast Ellipsis";
  462. return NewValueNode(kEllipsis);
  463. }
  464. AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
  465. MS_LOG(DEBUG) << "Process ast Num";
  466. py::object obj = python_adapter::GetPyObjAttr(node, "n");
  467. if (py::isinstance<py::int_>(obj)) {
  468. MS_LOG(INFO) << "The Num is int64_t:" << (std::string)py::str(obj);
  469. auto data = py::cast<int64_t>(obj);
  470. return NewValueNode(data);
  471. } else if (py::isinstance<py::float_>(obj)) {
  472. MS_LOG(INFO) << "The Num is float:" << (std::string)py::str(obj);
  473. auto data = py::cast<float>(obj);
  474. return NewValueNode(data);
  475. } else {
  476. // no else actually
  477. MS_LOG(ERROR) << "Unsupported Num type : " << (std::string)py::str(obj);
  478. errcode_ = PARSE_NODE_TYPE_UNKOWN;
  479. return nullptr;
  480. }
  481. }
  482. AnfNodePtr Parser::ParseStr(const FunctionBlockPtr &, const py::object &node) {
  483. MS_LOG(DEBUG) << "Process ast Str";
  484. auto str_s = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "s"));
  485. return NewValueNode(str_s);
  486. }
  487. AnfNodePtr Parser::ParseConstant(const FunctionBlockPtr &, const py::object &node) {
  488. MS_LOG(DEBUG) << "Process ast Constant";
  489. py::object obj = python_adapter::GetPyObjAttr(node, "value");
  490. if (py::isinstance<py::bool_>(obj)) {
  491. MS_LOG(INFO) << "The Constant is bool:" << (std::string)py::str(obj);
  492. auto data = py::cast<bool>(obj);
  493. return NewValueNode(data);
  494. } else if (py::isinstance<py::int_>(obj)) {
  495. MS_LOG(INFO) << "The Constant is int64_t:" << (std::string)py::str(obj);
  496. auto data = py::cast<int64_t>(obj);
  497. return NewValueNode(data);
  498. } else if (py::isinstance<py::float_>(obj)) {
  499. MS_LOG(INFO) << "The Constant is float:" << (std::string)py::str(obj);
  500. auto data = py::cast<float>(obj);
  501. return NewValueNode(data);
  502. } else if (py::isinstance<py::str>(obj)) {
  503. MS_LOG(INFO) << "The Constant is string:" << (std::string)py::str(obj);
  504. auto data = py::cast<std::string>(obj);
  505. return NewValueNode(data);
  506. } else if (py::isinstance<py::none>(obj)) {
  507. MS_LOG(INFO) << "The Constant is none:" << (std::string)py::str(obj);
  508. return NewValueNode(kNone);
  509. } else {
  510. // no else actually
  511. MS_EXCEPTION(TypeError) << "Unsupported Constant type : " << (std::string)py::str(obj);
  512. }
  513. }
  514. AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object &node) {
  515. MS_LOG(DEBUG) << "Process ast NameConstant";
  516. py::object obj = python_adapter::GetPyObjAttr(node, "value");
  517. if (py::isinstance<py::bool_>(obj)) {
  518. MS_LOG(INFO) << "The NameConstant is bool:" << (std::string)py::str(obj);
  519. auto data = py::cast<bool>(obj);
  520. return NewValueNode(data);
  521. } else if (py::isinstance<py::none>(obj)) {
  522. MS_LOG(INFO) << "The NameConstant is none:" << (std::string)py::str(obj);
  523. return NewValueNode(kNone);
  524. } else {
  525. // no else actually
  526. MS_LOG(ERROR) << "Unsupported NameConstant type: " << (std::string)py::str(obj);
  527. errcode_ = PARSE_NODE_TYPE_UNKOWN;
  528. return nullptr;
  529. }
  530. }
  531. AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes) {
  532. AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
  533. std::vector<AnfNodePtr> make_tuple_nodes;
  534. make_tuple_nodes.push_back(make_tuple_op);
  535. (void)std::transform(element_nodes.begin(), element_nodes.end(), std::back_inserter(make_tuple_nodes),
  536. [](AnfNodePtr arg) -> AnfNodePtr { return arg; });
  537. return block->func_graph()->NewCNodeInOrder(make_tuple_nodes);
  538. }
  539. AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &args) {
  540. py::object father_class;
  541. if (args.empty()) {
  542. father_class = py::none();
  543. } else if (args.size() == 2) {
  544. father_class = args[0];
  545. auto arg_type = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, args[1])));
  546. if (arg_type != AST_SUB_TYPE_NAME || py::cast<std::string>(python_adapter::GetPyObjAttr(args[1], "id")) != "self") {
  547. MS_EXCEPTION(ArgumentError) << "When call 'super', the second arg should be 'self'.";
  548. }
  549. } else {
  550. MS_EXCEPTION(ArgumentError) << "When call 'super', the args number should be 0 or 2, but got" << args.size() << ".";
  551. }
  552. py::object target_class_instance = ast()->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast()->obj());
  553. py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance);
  554. NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
  555. SymbolPtr symbol = std::make_shared<Symbol>("namespace");
  556. return block->MakeResolve(name_space, symbol);
  557. }
  558. // process function call, eg : f1(x, y) ...
  559. AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) {
  560. MS_LOG(DEBUG) << "Process ast Call";
  561. // process function call
  562. py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func");
  563. py::list args = python_adapter::GetPyObjAttr(node, "args");
  564. auto arg_type =
  565. AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, function_ast_node)));
  566. if (arg_type == AST_SUB_TYPE_NAME) {
  567. auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(function_ast_node, "id"));
  568. if (name_id == "super") {
  569. return ParseSuper(block, args);
  570. }
  571. }
  572. AnfNodePtr call_function_anf_node = ParseExprNode(block, function_ast_node);
  573. // function call arguments should be passed in as groups and unpacked later using unpack call
  574. std::vector<AnfNodePtr> packed_arguments;
  575. std::vector<AnfNodePtr> group_arguments;
  576. bool need_unpack_args = ParseArgsInCall(block, args, &packed_arguments, &group_arguments);
  577. bool need_unpack_keywords = ParseKeywordsInCall(block, node, &packed_arguments);
  578. // if there is stared or keyword argument, unpack may be needed
  579. bool need_unpack = need_unpack_args || need_unpack_keywords;
  580. return GenerateAnfNodeForCall(block, call_function_anf_node, packed_arguments, group_arguments, need_unpack);
  581. }
  582. CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_anf_node,
  583. const std::vector<AnfNodePtr> &packed_arguments) {
  584. std::vector<AnfNodePtr> unpack_call_nodes;
  585. auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL));
  586. unpack_call_nodes.push_back(unpack_call_op);
  587. unpack_call_nodes.push_back(call_function_anf_node);
  588. (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes),
  589. [](AnfNodePtr node) -> AnfNodePtr { return node; });
  590. CNodePtr unpack_call = func_graph->NewCNodeInOrder(unpack_call_nodes);
  591. return unpack_call;
  592. }
  593. AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node,
  594. const std::vector<AnfNodePtr> &packed_arguments,
  595. const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const {
  596. // if there is keyword arguments or starred, using an unpack_call op to unpack the argument
  597. if (need_unpack) {
  598. return MakeUnpackCall(block->func_graph(), call_function_anf_node, packed_arguments);
  599. }
  600. // else there is no keyword arguments and starred, parsed as normal arguments without unpack
  601. std::vector<AnfNodePtr> func_call_nodes;
  602. func_call_nodes.push_back(call_function_anf_node);
  603. (void)std::transform(group_arguments.begin(), group_arguments.end(), std::back_inserter(func_call_nodes),
  604. [](AnfNodePtr node) -> AnfNodePtr { return node; });
  605. CNodePtr call_anf_node = block->func_graph()->NewCNodeInOrder(func_call_nodes);
  606. return call_anf_node;
  607. }
  608. bool Parser::ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args,
  609. std::vector<AnfNodePtr> *packed_arguments, std::vector<AnfNodePtr> *group_arguments) {
  610. bool need_unpack = false;
  611. for (size_t i = 0; i < args.size(); i++) {
  612. auto arg_node = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, args[i])));
  613. if (arg_node == AST_SUB_TYPE_STARRED) {
  614. if (!group_arguments->empty()) {
  615. packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments));
  616. }
  617. packed_arguments->push_back(ParseExprNode(block, python_adapter::GetPyObjAttr(args[i], "value")));
  618. group_arguments->clear();
  619. need_unpack = true;
  620. } else {
  621. group_arguments->push_back(ParseExprNode(block, args[i]));
  622. }
  623. }
  624. if (!group_arguments->empty()) {
  625. packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments));
  626. }
  627. return need_unpack;
  628. }
  629. bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node,
  630. std::vector<AnfNodePtr> *packed_arguments) {
  631. bool need_unpack = false;
  632. py::list keywords = python_adapter::GetPyObjAttr(node, "keywords");
  633. if (!keywords.empty()) {
  634. need_unpack = true;
  635. std::vector<AnfNodePtr> keys;
  636. std::vector<AnfNodePtr> values;
  637. for (size_t index = 0; index < keywords.size(); index++) {
  638. auto kw_key = python_adapter::GetPyObjAttr(keywords[index], "arg");
  639. auto kw_value = python_adapter::GetPyObjAttr(keywords[index], "value");
  640. if (py::isinstance<py::none>(kw_key)) {
  641. packed_arguments->push_back(ParseExprNode(block, kw_value));
  642. } else {
  643. auto kw_key_c = kw_key.cast<std::string>();
  644. keys.push_back(NewValueNode(kw_key_c));
  645. values.push_back(ParseExprNode(block, kw_value));
  646. }
  647. }
  648. auto keys_tuple = GenerateMakeTuple(block, keys);
  649. auto values_tuple = GenerateMakeTuple(block, values);
  650. auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
  651. std::vector<AnfNodePtr> make_dict_nodes;
  652. make_dict_nodes.push_back(make_dict_op);
  653. make_dict_nodes.push_back(keys_tuple);
  654. make_dict_nodes.push_back(values_tuple);
  655. packed_arguments->push_back(block->func_graph()->NewCNodeInOrder(make_dict_nodes));
  656. }
  657. return need_unpack;
  658. }
  659. // process call attributes of class type define, eg: x.y()
  660. AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::object &node) {
  661. MS_LOG(DEBUG) << "Process ast Attribute";
  662. // process class value,eg: self.xx
  663. if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  664. if (ast_->IsClassMember(node)) {
  665. std::string var_name = "self.";
  666. std::string attr_name = node.attr("attr").cast<std::string>();
  667. (void)var_name.append(attr_name);
  668. auto attr_obj = ast()->obj().attr(attr_name.c_str());
  669. if (py::hasattr(ast()->obj(), attr_name.c_str()) &&
  670. (py::hasattr(attr_obj, PYTHON_PRIMITIVE_FLAG) || py::isinstance<py::int_>(attr_obj) ||
  671. py::isinstance<py::float_>(attr_obj) || py::isinstance<py::bool_>(attr_obj) ||
  672. py::isinstance<py::str>(attr_obj) || data_converter::IsCellInstance(attr_obj))) {
  673. return block->MakeResolveSymbol(var_name);
  674. } else {
  675. return block->ReadVariable(var_name);
  676. }
  677. }
  678. }
  679. // process the get attr
  680. // Use the Primitive replace the operation resolve node (getattr)
  681. // because the getattr will eventually be converted to Primitive node
  682. AnfNodePtr op_node = NewValueNode(prim::kPrimGetAttr);
  683. // process the attr body
  684. py::object value_body = python_adapter::GetPyObjAttr(node, "value");
  685. AnfNodePtr value_node = ParseExprNode(block, value_body);
  686. if (value_node == nullptr) {
  687. MS_LOG(WARNING) << "Parse attribute failed";
  688. return nullptr;
  689. }
  690. // process the node attr
  691. auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast<std::string>();
  692. MS_LOG(DEBUG) << "Attr = " << attr_str;
  693. AnfNodePtr attr_node = nullptr;
  694. {
  695. TraceGuard guard(GetLocation(python_adapter::GetPyObjAttr(node, "attr")));
  696. attr_node = NewValueNode(attr_str);
  697. }
  698. // create the apply node
  699. return block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
  700. }
  701. // Process comparison expression : a == b. a > b etc.
  702. AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object &node) {
  703. MS_LOG(DEBUG) << "Process ast Compare";
  704. // for python comparison ,there may be if x>y>5 ,
  705. // which there is two ops , but we only support one now
  706. py::list ops = python_adapter::GetPyObjAttr(node, "ops");
  707. if (ops.size() > MAX_COMPARISON_OPS_SUPPORTED) {
  708. MS_LOG(ERROR) << "MindSpore does not support comparison with operators more than one now, ops size =" << ops.size();
  709. return nullptr;
  710. }
  711. py::object left = python_adapter::GetPyObjAttr(node, "left");
  712. py::list comparators = python_adapter::GetPyObjAttr(node, "comparators");
  713. AnfNodePtr left_node = ParseExprNode(block, left);
  714. AnfNodePtr right_node = ParseExprNode(block, comparators[0]);
  715. MS_EXCEPTION_IF_NULL(block);
  716. AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]);
  717. return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
  718. }
// Lowers a short-circuit boolean chain `v0 op v1 op ... vn` (op is `and`/`or`, given by
// `mode`) into a switch on the truthiness of v0 between two branch graphs:
//   - b1 (the branch that keeps evaluating) outputs the recursively lowered rest v1..vn;
//   - b2 (the short-circuit branch) outputs v0 itself.
// For `and` b1 is the true branch; for `or` b1 is the false branch. The selected branch
// graph is then called with no arguments, so the expression yields an operand value rather
// than a plain bool. Returns nullptr (after logging) for any other mode.
// NOTE(review): value_list[0] is read unconditionally, so this assumes value_list is non-empty.
AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode) {
  // Only one value left: it is simply the result of the chain.
  if (value_list.size() == 1) {
    AnfNodePtr first_node = ParseExprNode(block, value_list[0]);
    return first_node;
  } else {
    // Split into the test value v0 and the remaining values v1..vn.
    py::object first = value_list[0];
    py::list rest;
    for (size_t i = 1; i < value_list.size(); i++) {
      rest.append(value_list[i]);
    }
    MS_EXCEPTION_IF_NULL(block);
    FunctionBlockPtr true_block = nullptr;
    FunctionBlockPtr false_block = nullptr;
    // Branch blocks are created under TraceGuards so their debug info traces back to `block`.
    {
      TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
      true_block = MakeFunctionBlock(*this);
    }
    {
      TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
      false_block = MakeFunctionBlock(*this);
    }
    MakeConditionBlocks(block, true_block, false_block);
    FunctionBlockPtr b1, b2;
    // For `and`, the rest is evaluated only when v0 is truthy (b1 = true branch).
    // For `or`, the rest is evaluated only when v0 is falsy (b1 = false branch).
    if (mode == AST_SUB_TYPE_AND) {
      b1 = true_block;
      b2 = false_block;
    } else if (mode == AST_SUB_TYPE_OR) {
      b2 = true_block;
      b1 = false_block;
    } else {
      MS_LOG(ERROR) << "Not supported mode: " << mode;
      return nullptr;
    }
    // v0 is parsed in the enclosing block; the rest is parsed inside b1, so it is only part
    // of the graph taken when the chain does not short-circuit.
    AnfNodePtr test_node = ParseExprNode(block, first);
    AnfNodePtr rest_node = ProcessBoolOpValueList(b1, rest, mode);
    b1->func_graph()->set_output(rest_node);
    // The short-circuit branch returns v0 itself.
    b2->func_graph()->set_output(test_node);
    auto cond_node = block->ForceToBoolNode(test_node);
    // switch(bool(v0), true_graph, false_graph) selects a graph; the CNode below calls it.
    auto switch_app = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), cond_node,
                                                            NewValueNode(true_block->func_graph()),
                                                            NewValueNode(false_block->func_graph())});
    std::vector<AnfNodePtr> call_graph_nodes{switch_app};
    auto switch_app_call = block->func_graph()->NewCNodeInOrder(call_graph_nodes);
    return switch_app_call;
  }
}
  768. // Process comparison expression : a and b. a or b .
  769. AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) {
  770. MS_LOG(DEBUG) << "Process ast BoolOp";
  771. py::object op_node = python_adapter::GetPyObjAttr(node, "op");
  772. AstSubType op_type = ast_->GetOpType(op_node);
  773. if (op_type == AST_SUB_TYPE_UNKNOWN) {
  774. MS_LOG(WARNING) << "ProcessBoolOp, got unknown op type";
  775. return nullptr;
  776. }
  777. py::list op_values = python_adapter::GetPyObjAttr(node, "values");
  778. return ProcessBoolOpValueList(block, op_values, op_type);
  779. }
  780. // Process a function def
  781. FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) {
  782. MS_LOG(DEBUG) << "Process ast FunctionDef";
  783. FunctionBlockPtr function_block = ParseFunction(node, block);
  784. MS_EXCEPTION_IF_NULL(function_block);
  785. // get function name
  786. py::str name = python_adapter::GetPyObjAttr(node, "name");
  787. std::string function_name = name;
  788. ValueNodePtr valuenode_graph = NewValueNode(function_block->func_graph());
  789. block->WriteVariable(function_name, valuenode_graph);
  790. return block;
  791. }
  792. // Process a lambda expression . like lambda x,y: x + y
  793. AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) {
  794. MS_LOG(DEBUG) << "Process ast Lambda";
  795. FunctionBlockPtr func_block = MakeFunctionBlock(*this);
  796. func_block->AddPrevBlock(block);
  797. func_block->Mature();
  798. // get lambda args
  799. py::list args = ast_->GetArgs(node);
  800. for (std::size_t i = 0; i < args.size(); i++) {
  801. std::string arg = py::cast<std::string>(args[i].attr("arg"));
  802. TraceGuard guard(GetLocation(args[i]));
  803. auto para_node = std::make_shared<Parameter>(func_block->func_graph());
  804. para_node->debug_info()->set_name(arg);
  805. func_block->func_graph()->add_parameter(para_node);
  806. func_block->WriteVariable(arg, para_node);
  807. MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg;
  808. }
  809. py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  810. AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
  811. func_block->func_graph()->set_output(lambda_body_node);
  812. ValueNodePtr const_graph = NewValueNode(func_block->func_graph());
  813. return const_graph;
  814. }
  815. // process a tuple
  816. AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &node) {
  817. MS_LOG(DEBUG) << "Process ast Tuple";
  818. MS_EXCEPTION_IF_NULL(block);
  819. py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
  820. if (elts.size() == 0) {
  821. auto empty_tuple = std::vector<ValuePtr>();
  822. return NewValueNode(std::make_shared<ValueTuple>(empty_tuple));
  823. }
  824. std::vector<AnfNodePtr> tuple_vec;
  825. AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
  826. tuple_vec.emplace_back(make_tuple_op);
  827. for (size_t i = 0; i < elts.size(); i++) {
  828. AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
  829. tuple_vec.emplace_back(node_ptr);
  830. }
  831. CNodePtr tuple_app = block->func_graph()->NewCNodeInOrder(tuple_vec);
  832. return tuple_app;
  833. }
  834. // process a list
  835. AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) {
  836. MS_LOG(DEBUG) << "Process ast List";
  837. MS_EXCEPTION_IF_NULL(block);
  838. py::list elts = python_adapter::GetPyObjAttr(node, "elts");
  839. if (elts.size() == 0) {
  840. auto empty_list = std::vector<ValuePtr>();
  841. return NewValueNode(std::make_shared<ValueList>(empty_list));
  842. }
  843. std::vector<AnfNodePtr> list_vec;
  844. AnfNodePtr make_list_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST);
  845. list_vec.emplace_back(make_list_op);
  846. for (size_t i = 0; i < elts.size(); i++) {
  847. AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
  848. list_vec.emplace_back(node_ptr);
  849. }
  850. CNodePtr list_app = block->func_graph()->NewCNodeInOrder(list_vec);
  851. return list_app;
  852. }
  853. // process a subscript, such as x[y] , node expressed as value[slice]
  854. AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::object &node) {
  855. MS_LOG(DEBUG) << "Process ast Subscript";
  856. MS_EXCEPTION_IF_NULL(block);
  857. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  858. py::object value_node = python_adapter::GetPyObjAttr(node, "value");
  859. py::object slice_node = python_adapter::GetPyObjAttr(node, "slice");
  860. AnfNodePtr value = ParseExprNode(block, value_node);
  861. AnfNodePtr slice = ParseExprNode(block, slice_node);
  862. return block->func_graph()->NewCNodeInOrder({op_getitem, value, slice});
  863. }
  864. // process a slice, get the slice value
  865. AnfNodePtr Parser::ParseSlice(const FunctionBlockPtr &block, const py::object &node) {
  866. MS_LOG(DEBUG) << "Process ast Slice";
  867. MS_EXCEPTION_IF_NULL(block);
  868. AnfNodePtr op_makeslice = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKESLICE);
  869. py::object start = python_adapter::GetPyObjAttr(node, "lower");
  870. py::object stop = python_adapter::GetPyObjAttr(node, "upper");
  871. py::object step = python_adapter::GetPyObjAttr(node, "step");
  872. AnfNodePtr start_node = ParseExprNode(block, start);
  873. AnfNodePtr stop_node = ParseExprNode(block, stop);
  874. AnfNodePtr step_node = ParseExprNode(block, step);
  875. return block->func_graph()->NewCNodeInOrder({op_makeslice, start_node, stop_node, step_node});
  876. }
  877. // process a extslice
  878. AnfNodePtr Parser::ParseExtSlice(const FunctionBlockPtr &block, const py::object &node) {
  879. MS_LOG(DEBUG) << "Process ast ExtSlice";
  880. MS_EXCEPTION_IF_NULL(block);
  881. AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
  882. py::tuple slice_tuple = python_adapter::GetPyObjAttr(node, "dims");
  883. std::vector<AnfNodePtr> node_vec;
  884. node_vec.emplace_back(make_tuple_op);
  885. for (size_t i = 0; i < slice_tuple.size(); i++) {
  886. AnfNodePtr node_ptr = ParseExprNode(block, slice_tuple[i]);
  887. node_vec.emplace_back(node_ptr);
  888. }
  889. CNodePtr tuple_conde = block->func_graph()->NewCNodeInOrder(node_vec);
  890. return tuple_conde;
  891. }
  892. // process a index, get the index number
  893. AnfNodePtr Parser::ParseIndex(const FunctionBlockPtr &block, const py::object &node) {
  894. MS_LOG(DEBUG) << "Process ast Index";
  895. py::object value_node = python_adapter::GetPyObjAttr(node, "value");
  896. return ParseExprNode(block, value_node);
  897. }
  898. // process a UnaryOp, +a, -b
  899. AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node) {
  900. MS_LOG(DEBUG) << "Process ast UnaryOp";
  901. py::object op = python_adapter::GetPyObjAttr(node, "op");
  902. MS_EXCEPTION_IF_NULL(block);
  903. // resolve the op
  904. AnfNodePtr op_node = block->MakeResolveAstOp(op);
  905. py::object operand = python_adapter::GetPyObjAttr(node, "operand");
  906. AnfNodePtr operand_node = ParseExprNode(block, operand);
  907. return block->func_graph()->NewCNodeInOrder({op_node, operand_node});
  908. }
  909. // process a dict ast node expression
  910. AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) {
  911. MS_LOG(DEBUG) << "Process ast Dict";
  912. py::list keys = node.attr("keys");
  913. py::list values = node.attr("values");
  914. std::vector<AnfNodePtr> key_nodes;
  915. std::vector<AnfNodePtr> value_nodes;
  916. for (size_t i = 0; i < keys.size(); i++) {
  917. key_nodes.push_back(ParseExprNode(block, keys[i]));
  918. value_nodes.push_back(ParseExprNode(block, values[i]));
  919. }
  920. auto keys_tuple = GenerateMakeTuple(block, key_nodes);
  921. auto values_tuple = GenerateMakeTuple(block, value_nodes);
  922. MS_EXCEPTION_IF_NULL(block);
  923. auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
  924. return block->func_graph()->NewCNodeInOrder({make_dict_op, keys_tuple, values_tuple});
  925. }
  926. // process a augment assign such as a += b or mat[stride_slice] += b.
  927. FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py::object &node) {
  928. MS_LOG(DEBUG) << "Process ast AugAssign";
  929. MS_EXCEPTION_IF_NULL(block);
  930. MS_EXCEPTION_IF_NULL(ast_);
  931. py::object target_obj = python_adapter::GetPyObjAttr(node, "target");
  932. py::object op_obj = python_adapter::GetPyObjAttr(node, "op");
  933. py::object value_obj = python_adapter::GetPyObjAttr(node, "value");
  934. AnfNodePtr target_node = nullptr;
  935. AnfNodePtr op_node = block->MakeResolveAstOp(op_obj);
  936. AnfNodePtr value_node = ParseExprNode(block, value_obj);
  937. auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, target_obj)));
  938. if (ast_type == AST_SUB_TYPE_NAME) {
  939. target_node = ParseName(block, target_obj);
  940. } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
  941. target_node = ParseSubscript(block, target_obj);
  942. } else if (ast_->IsClassMember(target_obj)) {
  943. target_node = ParseAttribute(block, target_obj);
  944. } else {
  945. MS_LOG(EXCEPTION) << "Not supported augassign";
  946. }
  947. if (target_node == nullptr) {
  948. MS_LOG(EXCEPTION) << "Can not get target node ";
  949. }
  950. CNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, value_node});
  951. WriteAssignVars(block, target_obj, augassign_app);
  952. return block;
  953. }
  954. // process global declaration such as 'global x';
  955. FunctionBlockPtr Parser::ParseGlobal(const FunctionBlockPtr &block, const py::object &node) {
  956. MS_LOG(DEBUG) << "Process ast Global";
  957. MS_EXCEPTION_IF_NULL(block);
  958. py::list vars = python_adapter::GetPyObjAttr(node, "names");
  959. for (auto &item : vars) {
  960. block->AddGlobalVar(py::cast<std::string>(item));
  961. }
  962. return block;
  963. }
  964. // process a if statement
  965. FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object &node) {
  966. MS_LOG(DEBUG) << "Process ast If";
  967. py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  968. AnfNodePtr condition_node = ParseExprNode(block, test_node);
  969. MS_EXCEPTION_IF_NULL(block);
  970. CNodePtr bool_node = block->ForceToBoolNode(condition_node);
  971. FunctionBlockPtr true_block = nullptr;
  972. FunctionBlockPtr false_block = nullptr;
  973. {
  974. TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
  975. true_block = MakeFunctionBlock(*this);
  976. }
  977. {
  978. TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
  979. false_block = MakeFunctionBlock(*this);
  980. }
  981. MakeConditionBlocks(block, true_block, false_block);
  982. FunctionBlockPtr after_block = nullptr;
  983. {
  984. TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
  985. after_block = MakeFunctionBlock(*this);
  986. }
  987. if (MsContext::GetInstance()->backend_policy() != "ge") {
  988. // for backends excludes 'ge', it can handle multi graph call, use this flag to
  989. // generate call not inline `after_block` graph to reduce if by if switch expansion.
  990. after_block->func_graph()->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true);
  991. }
  992. // process the if-true branch
  993. py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
  994. FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode);
  995. // if the return_ is set ,it has its own continuation block
  996. if (true_end->func_graph()->get_return() == nullptr) {
  997. true_end->Jump(after_block, nullptr);
  998. }
  999. // process the orelse branch
  1000. py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
  1001. FunctionBlockPtr false_end = ParseStatements(false_block, orelseNode);
  1002. // if the return_ is set ,it has its own continuation block
  1003. if (false_end->func_graph()->get_return() == nullptr) {
  1004. false_end->Jump(after_block, nullptr);
  1005. }
  1006. block->ConditionalJump(bool_node, true_block, false_block);
  1007. after_block->Mature();
  1008. return after_block;
  1009. }
// Lowers `while test: body` into three function blocks:
//   header_block - parses `test` and conditionally jumps to body_block or after_block;
//   body_block   - parses the loop body and jumps back to header_block unless it returned;
//   after_block  - reached from the header when the condition is false.
// If the body contains `break`, the LoopContext supplies an end_block; after_block jumps
// to it, and end_block is returned as the continuation instead of after_block.
FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast While";
  MS_EXCEPTION_IF_NULL(block);
  MS_LOG(INFO) << "Parse while statement";
  FunctionBlockPtr header_block = nullptr;
  FunctionBlockPtr body_block = nullptr;
  FunctionBlockPtr after_block = nullptr;
  // Each block is created under a TraceGuard so its debug info traces back to `block`.
  {
    TraceGuard guard(std::make_shared<TraceWhileHeader>(block->func_graph()->debug_info()));
    header_block = MakeFunctionBlock(*this);
  }
  {
    TraceGuard guard(std::make_shared<TraceWhileBody>(block->func_graph()->debug_info()));
    body_block = MakeFunctionBlock(*this);
  }
  {
    TraceGuard guard(std::make_shared<TraceWhileAfter>(block->func_graph()->debug_info()));
    after_block = MakeFunctionBlock(*this);
  }
  // Both the body and the loop exit are entered from the header.
  body_block->AddPrevBlock(header_block);
  after_block->AddPrevBlock(header_block);
  block->Jump(header_block, nullptr);
  // The condition lives in header_block, which the body jumps back to after each pass.
  py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  AnfNodePtr condition_node = ParseExprNode(header_block, test_node);
  condition_node = header_block->ForceToWhileCond(condition_node);
  body_block->Mature();
  header_block->ConditionalJump(condition_node, body_block, after_block);
  // Parse loop body statements with loop context.
  // NOTE(review): LoopContext presumably registers this loop in loops_ for the duration of
  // this function (so break/continue in the body can find header_block) — confirm.
  LoopContext loop_context{&loops_, header_block, nullptr};
  py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  FunctionBlockPtr after_body = ParseStatements(body_block, body_node);
  // Back edge: fall through to the next iteration unless the body returned.
  if (after_body->func_graph()->get_return() == nullptr) {
    after_body->Jump(header_block, nullptr);
  }
  // The header is matured only after the body's back edge has been added.
  header_block->Mature();
  after_block->Mature();
  auto &end_block = loop_context.EndBlock();
  if (end_block) {
    // end_block exists if we encounter 'break' in loop body.
    after_block->Jump(end_block, nullptr);
    end_block->Mature();
    return end_block;
  }
  // No 'break', no end_block.
  return after_block;
}
  1056. CNodePtr Parser::GenerateIteratorInFor(const FunctionBlockPtr &block, const py::object &node,
  1057. const AnfNodePtr &op_iter) {
  1058. py::object iter_node = python_adapter::GetPyObjAttr(node, "iter");
  1059. AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node);
  1060. return block->func_graph()->NewCNodeInOrder({op_iter, iter_anf_node});
  1061. }
  1062. CNodePtr Parser::GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block,
  1063. const AnfNodePtr &op_hasnext) {
  1064. MS_EXCEPTION_IF_NULL(header_block);
  1065. return header_block->func_graph()->NewCNodeInOrder({op_hasnext, iter_param});
  1066. }
  1067. FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
  1068. TraceGuard trace_guard(trace_info);
  1069. FunctionBlockPtr body_block = MakeFunctionBlock(*this);
  1070. return body_block;
  1071. }
  1072. int64_t GetForTransToWhileLoop() {
  1073. static const auto loop_str = common::GetEnv("ENV_FOR_TO_WHILE_LOOP");
  1074. // int64 support 63bits positive num mostly.
  1075. if (loop_str.size() > 63 || loop_str.empty()) {
  1076. return MAX_FOR_LOOP_COUNT;
  1077. }
  1078. if (std::any_of(loop_str.begin(), loop_str.end(), [](char c) { return c < '0' || c > '9'; })) {
  1079. return MAX_FOR_LOOP_COUNT;
  1080. }
  1081. int64_t loop_count;
  1082. std::stringstream ss;
  1083. ss << loop_str;
  1084. ss >> loop_count;
  1085. return loop_count;
  1086. }
// Lowers a for statement
//   for x in xs:
//       body
// into an if/else on the iterable's length:
//   if len(xs) < GetForTransToWhileLoop():
//       ParseForIter()   // iterator-based loop, which is always unrolled
//   else:
//       ParseForLoop()   // loop-variable-based loop, which is always sunk as a real loop
// Both branches jump to a shared after block, which is matured and returned.
// NOTE(review): the `iter` expression is parsed here for len() and parsed again inside
// ParseForIter (via GenerateIteratorInFor); confirm that duplicate parsing is intended.
// NOTE(review): unlike ParseIf, the branch ends jump to after_block unconditionally,
// without checking get_return() — confirm the ParseForIter/ParseForLoop results never return.
FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast For, create an if else statement";
  MS_EXCEPTION_IF_NULL(block);
  // create statement 'len(xs) < MAX_FOR_LOOP_COUNT'
  AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
  // NOTE(review): the attribute name comes from NAMED_PRIMITIVE_ITER (presumably "iter").
  py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER);
  AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
  CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
  CNodePtr bool_node = block->func_graph()->NewCNodeInOrder(
    {NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(GetForTransToWhileLoop())});
  // create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop'
  FunctionBlockPtr true_block = nullptr;
  FunctionBlockPtr false_block = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
    true_block = MakeFunctionBlock(*this);
  }
  {
    TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
    false_block = MakeFunctionBlock(*this);
  }
  MakeConditionBlocks(block, true_block, false_block);
  FunctionBlockPtr after_block = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
    after_block = MakeFunctionBlock(*this);
  }
  // Short iterables: iterator-based (unrolled) loop in the true branch.
  FunctionBlockPtr true_end = ParseForIter(true_block, node);
  true_end->Jump(after_block, nullptr);
  // Long iterables: loop-variable-based loop in the false branch.
  FunctionBlockPtr false_end = ParseForLoop(false_block, node);
  false_end->Jump(after_block, nullptr);
  block->ConditionalJump(bool_node, true_block, false_block);
  after_block->Mature();
  return after_block;
}
  1130. // A for loop will generate 3 functions :the test, the body, and the continuation
  1131. // for x in xs:
  1132. // body
  1133. // it is compiled to be following statement
  1134. // it = iter(xs)
  1135. // while hastnext(it)
  1136. // x, it = next(it)
  1137. // body
  1138. FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::object &node) {
  1139. MS_LOG(DEBUG) << "Process ast For";
  1140. MS_EXCEPTION_IF_NULL(block);
  1141. AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
  1142. AnfNodePtr op_next = block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
  1143. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  1144. AnfNodePtr op_hasnext = block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
  1145. // generate the iterator apply
  1146. CNodePtr iter_apply = GenerateIteratorInFor(block, node, op_iter);
  1147. MS_EXCEPTION_IF_NULL(iter_apply);
  1148. FunctionBlockPtr header_block =
  1149. GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
  1150. MS_EXCEPTION_IF_NULL(header_block);
  1151. // generate the hasnext apply which is a condition
  1152. ParameterPtr iter_param = header_block->func_graph()->add_parameter();
  1153. CNodePtr cond_apply = GenerateCondInFor(iter_param, header_block, op_hasnext);
  1154. // generate the body of the for statement
  1155. FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  1156. MS_EXCEPTION_IF_NULL(body_block);
  1157. body_block->AddPrevBlock(header_block);
  1158. // generate the iterator next apply
  1159. // process as following: `app = next(it); target = app[0]; it = app[1];`
  1160. CNodePtr app = body_block->func_graph()->NewCNodeInOrder({op_next, iter_param});
  1161. CNodePtr target_app =
  1162. body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast<int64_t>(0))});
  1163. py::object target_node = python_adapter::GetPyObjAttr(node, "target");
  1164. CNodePtr iter2_app =
  1165. body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast<int64_t>(1))});
  1166. WriteAssignVars(body_block, target_node, target_app);
  1167. // link the variable name with the target
  1168. auto it_info = std::make_shared<TraceIterator>(target_app->debug_info());
  1169. iter_param->debug_info()->set_trace_info(it_info);
  1170. iter2_app->debug_info()->set_trace_info(it_info);
  1171. iter_apply->debug_info()->set_trace_info(it_info);
  1172. FunctionBlockPtr after_block = nullptr;
  1173. {
  1174. TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
  1175. after_block = MakeFunctionBlock(*this);
  1176. }
  1177. MS_EXCEPTION_IF_NULL(after_block);
  1178. after_block->AddPrevBlock(header_block);
  1179. block->Jump(header_block, iter_apply);
  1180. body_block->Mature();
  1181. header_block->ConditionalJump(cond_apply, body_block, after_block);
  1182. // Parse loop body statements with loop context.
  1183. LoopContext loop_context{&loops_, header_block, iter2_app};
  1184. py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  1185. FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
  1186. if (after_body_block->func_graph()->get_return() == nullptr) {
  1187. after_body_block->Jump(header_block, iter2_app);
  1188. }
  1189. header_block->Mature();
  1190. after_block->Mature();
  1191. auto &end_block = loop_context.EndBlock();
  1192. if (end_block) {
  1193. // end_block exists if we encounter 'break' in loop body.
  1194. after_block->Jump(end_block, nullptr);
  1195. end_block->Mature();
  1196. return end_block;
  1197. }
  1198. // No 'break', no end_block.
  1199. return after_block;
  1200. }
  1201. // A for loop will generate 3 functions :the test, the body, and the continuation
  1202. // for x in xs:
  1203. // body
  1204. // it is compiled to be following statement
  1205. // i = 0
  1206. // while i < len(xs)
  1207. // x = xs[i]
  1208. // i = i + 1
  1209. // body
  1210. FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::object &node) {
  1211. MS_LOG(DEBUG) << "Process ast For by loop variable";
  1212. MS_EXCEPTION_IF_NULL(block);
  1213. AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
  1214. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  1215. // get variable name of 'x' in statement 'for x in xs'
  1216. py::object target_node = python_adapter::GetPyObjAttr(node, "target");
  1217. // create statement 'len(xs)'
  1218. py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
  1219. AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
  1220. MS_EXCEPTION_IF_NULL(iter_node);
  1221. // Generate node for loop count and convert it to tensor, to make the loop not unroll
  1222. CNodePtr scalar_len = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
  1223. auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations");
  1224. auto scalar_to_tensor_node = block->func_graph()->NewCNodeInOrder({NewValueNode(scalar_to_tensor)});
  1225. CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, scalar_len});
  1226. FunctionBlockPtr header_block =
  1227. GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
  1228. MS_EXCEPTION_IF_NULL(header_block);
  1229. // create loop variable 'i'
  1230. ParameterPtr loop_var = header_block->func_graph()->add_parameter();
  1231. // create loop condition 'i < len(xs)'
  1232. auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations");
  1233. auto less_node = header_block->func_graph()->NewCNodeInOrder({NewValueNode(prim_less)});
  1234. CNodePtr cond_node = header_block->func_graph()->NewCNodeInOrder({less_node, loop_var, len_iter});
  1235. // generate the body of the for statement
  1236. FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  1237. MS_EXCEPTION_IF_NULL(body_block);
  1238. body_block->AddPrevBlock(header_block);
  1239. // create 'x = xs[i]'
  1240. CNodePtr target_var = body_block->func_graph()->NewCNodeInOrder({op_getitem, iter_node, loop_var});
  1241. WriteAssignVars(body_block, target_node, target_var);
  1242. // create 'i = i + 1'
  1243. CNodePtr loop_var_inc = body_block->func_graph()->NewCNodeInOrder(
  1244. {NewValueNode(prim::kPrimScalarAdd), loop_var, NewValueNode(static_cast<int64_t>(1))});
  1245. body_block->WriteVariable(loop_var->name(), loop_var_inc);
  1246. // link the variable name with the target
  1247. auto it_info = std::make_shared<TraceIterator>(loop_var_inc->debug_info());
  1248. loop_var->debug_info()->set_trace_info(it_info);
  1249. len_iter->debug_info()->set_trace_info(it_info);
  1250. FunctionBlockPtr after_block = nullptr;
  1251. {
  1252. TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
  1253. after_block = MakeFunctionBlock(*this);
  1254. }
  1255. MS_EXCEPTION_IF_NULL(after_block);
  1256. after_block->AddPrevBlock(header_block);
  1257. block->Jump(header_block, NewValueNode(static_cast<int64_t>(0)));
  1258. body_block->Mature();
  1259. header_block->ConditionalJump(cond_node, body_block, after_block, false);
  1260. // Parse loop body statements with loop context.
  1261. LoopContext loop_context{&loops_, header_block, loop_var_inc};
  1262. py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  1263. FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
  1264. if (after_body_block->func_graph()->get_return() == nullptr) {
  1265. after_body_block->Jump(header_block, loop_var_inc);
  1266. }
  1267. header_block->Mature();
  1268. after_block->Mature();
  1269. auto &end_block = loop_context.EndBlock();
  1270. if (end_block) {
  1271. // end_block exists if we encounter 'break' in loop body.
  1272. after_block->Jump(end_block, nullptr);
  1273. end_block->Mature();
  1274. return end_block;
  1275. }
  1276. // No 'break', no end_block.
  1277. return after_block;
  1278. }
  1279. AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) {
  1280. MS_LOG(DEBUG) << "Process ast IfExp";
  1281. MS_EXCEPTION_IF_NULL(block);
  1282. py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  1283. AnfNodePtr condition_node = ParseExprNode(block, test_node);
  1284. CNodePtr bool_node = block->ForceToBoolNode(condition_node);
  1285. FunctionBlockPtr true_block = nullptr;
  1286. FunctionBlockPtr false_block = nullptr;
  1287. {
  1288. TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
  1289. true_block = MakeFunctionBlock(*this);
  1290. }
  1291. {
  1292. TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
  1293. false_block = MakeFunctionBlock(*this);
  1294. }
  1295. MakeConditionBlocks(block, true_block, false_block);
  1296. // process the if-true branch
  1297. py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
  1298. true_block->func_graph()->debug_info()->set_location(GetLocation(bodyNode));
  1299. AnfNodePtr true_node = ParseExprNode(true_block, bodyNode);
  1300. // process the orelse branch
  1301. py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
  1302. false_block->func_graph()->debug_info()->set_location(GetLocation(orelseNode));
  1303. AnfNodePtr false_node = ParseExprNode(false_block, orelseNode);
  1304. true_block->func_graph()->set_output(true_node);
  1305. false_block->func_graph()->set_output(false_node);
  1306. // Use the Primitive replace the operation resolve node (switch)
  1307. // because the switch will eventually be converted to Primitive node
  1308. CNodePtr switch_app = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), bool_node,
  1309. NewValueNode(true_block->func_graph()),
  1310. NewValueNode(false_block->func_graph())});
  1311. std::vector<AnfNodePtr> call_graph_nodes{switch_app};
  1312. CNodePtr switch_app_call = block->func_graph()->NewCNodeInOrder(call_graph_nodes);
  1313. return switch_app_call;
  1314. }
  1315. void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) {
  1316. MS_EXCEPTION_IF_NULL(block);
  1317. MS_EXCEPTION_IF_NULL(assigned_node);
  1318. py::str name = python_adapter::GetPyObjAttr(targ, "id");
  1319. std::string name_id = name;
  1320. assigned_node->debug_info()->set_name(name_id);
  1321. // set the debug name of the constant graph
  1322. if (IsValueNode<FuncGraph>(assigned_node)) {
  1323. // the value should be graph
  1324. auto fg = GetValueNode<FuncGraphPtr>(assigned_node);
  1325. if (fg->debug_info()->name().empty()) {
  1326. fg->debug_info()->set_name(name_id);
  1327. }
  1328. }
  1329. block->WriteVariable(name_id, assigned_node);
  1330. }
  1331. void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) {
  1332. MS_EXCEPTION_IF_NULL(block);
  1333. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  1334. py::list items = python_adapter::GetPyObjAttr(targ, "elts");
  1335. for (size_t i = 0; i < items.size(); i++) {
  1336. // Use the Primitive replace the operation resolve node (getitem)
  1337. // because the getitem will eventually be converted to Primitive node
  1338. CNodePtr item_apply =
  1339. block->func_graph()->NewCNodeInOrder({op_getitem, assigned_node, NewValueNode(static_cast<int64_t>(i))});
  1340. py::object elt = items[i];
  1341. WriteAssignVars(block, elt, item_apply);
  1342. }
  1343. }
  1344. void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ,
  1345. const AnfNodePtr &assigned_node) {
  1346. // Now only support the self.xx = xxxxx, can't support x.y = xxxx
  1347. AnfNodePtr target_node = ParseExprNode(block, targ);
  1348. MS_EXCEPTION_IF_NULL(target_node);
  1349. std::string attr_name = targ.attr("attr").cast<std::string>();
  1350. std::string var_name = "self." + attr_name;
  1351. // Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
  1352. if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
  1353. MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but not defined.";
  1354. }
  1355. auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
  1356. auto obj_type = obj.attr("__class__").attr("__name__");
  1357. if (!py::hasattr(obj, "__parameter__")) {
  1358. MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '"
  1359. << py::str(obj).cast<std::string>() << "' with type '"
  1360. << py::str(obj_type).cast<std::string>() << ".";
  1361. }
  1362. MS_EXCEPTION_IF_NULL(block);
  1363. MS_LOG(DEBUG) << "SetState write " << var_name << " : " << target_node->ToString();
  1364. block->SetStateAssign(target_node, assigned_node);
  1365. }
  1366. void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ,
  1367. const AnfNodePtr &assigned_node) {
  1368. MS_EXCEPTION_IF_NULL(block);
  1369. AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM);
  1370. py::object value_obj = python_adapter::GetPyObjAttr(targ, "value");
  1371. py::object slice_obj = python_adapter::GetPyObjAttr(targ, "slice");
  1372. AnfNodePtr value_node = ParseExprNode(block, value_obj);
  1373. AnfNodePtr slice_node = ParseExprNode(block, slice_obj);
  1374. CNodePtr setitem_app = block->func_graph()->NewCNodeInOrder({op_setitem, value_node, slice_node, assigned_node});
  1375. // getitem apply should return the sequence data structure itself
  1376. std::string var_name;
  1377. if (ast_->IsClassMember(value_obj)) {
  1378. std::string attr_name = value_obj.attr("attr").cast<std::string>();
  1379. var_name = "self." + attr_name;
  1380. if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
  1381. MS_EXCEPTION(TypeError) << "'" << var_name << "' was not defined in the class '__init__' function.";
  1382. }
  1383. auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
  1384. auto obj_type = obj.attr("__class__").attr("__name__");
  1385. if (!py::hasattr(obj, "__parameter__")) {
  1386. MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '"
  1387. << py::str(obj).cast<std::string>() << "' with type '"
  1388. << py::str(obj_type).cast<std::string>() << "'.";
  1389. }
  1390. block->WriteVariable(var_name, setitem_app);
  1391. return;
  1392. }
  1393. if (AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, value_obj))) ==
  1394. AST_SUB_TYPE_SUBSCRIPT) {
  1395. HandleAssignSubscript(block, value_obj, setitem_app);
  1396. return;
  1397. }
  1398. if (!py::hasattr(value_obj, "id")) {
  1399. MS_EXCEPTION(TypeError) << "Attribute id not found in " << py::str(value_obj).cast<std::string>();
  1400. }
  1401. var_name = value_obj.attr("id").cast<std::string>();
  1402. block->WriteVariable(var_name, setitem_app);
  1403. }
  1404. void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node) {
  1405. MS_EXCEPTION_IF_NULL(value_node);
  1406. MS_LOG(DEBUG) << "Process WriteAssignVars";
  1407. auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, targ)));
  1408. if (ast_type == AST_SUB_TYPE_NAME) {
  1409. HandleAssignName(block, targ, value_node);
  1410. } else if (ast_type == AST_SUB_TYPE_TUPLE) {
  1411. HandleAssignTuple(block, targ, value_node);
  1412. } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
  1413. HandleAssignSubscript(block, targ, value_node);
  1414. } else if (ast_->IsClassMember(targ)) {
  1415. HandleAssignClassMember(block, targ, value_node);
  1416. } else {
  1417. MS_LOG(EXCEPTION) << "Not supported assign type: " << ast_type
  1418. << " NodeInfo: " << trace::GetDebugInfo(value_node->debug_info());
  1419. }
  1420. }
  1421. // process a assign statement, such as a =b, a,b = tup
  1422. FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) {
  1423. MS_LOG(DEBUG) << "Process ast assign";
  1424. py::object value_object = python_adapter::GetPyObjAttr(node, "value");
  1425. AnfNodePtr value_node = ParseExprNode(block, value_object);
  1426. py::object targets_object = python_adapter::GetPyObjAttr(node, "targets");
  1427. py::int_ pcount = python_adapter::CallPyObjMethod(targets_object, "__len__");
  1428. size_t count = LongToSize(pcount);
  1429. MS_LOG(DEBUG) << "The nodes count is " << count;
  1430. for (size_t i = 0; i < count; i++) {
  1431. auto target_node = py::cast<py::list>(targets_object)[i];
  1432. WriteAssignVars(block, target_node, value_node);
  1433. }
  1434. return block;
  1435. }
  1436. FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::object &node) {
  1437. if (loops_.empty()) {
  1438. // Report error if loop context not set for the 'break' statement.
  1439. MS_LOG(EXCEPTION) << "Unexpected 'break'.";
  1440. }
  1441. // Get current loop.
  1442. Loop &loop = loops_.top();
  1443. if (loop.end == nullptr) {
  1444. // Create end_block if it is not existed.
  1445. TraceGuard trace_guard(std::make_shared<TraceLoopEnd>(block->func_graph()->debug_info()));
  1446. loop.end = MakeFunctionBlock(*this);
  1447. }
  1448. // Jump to the end_block.
  1449. block->Jump(loop.end, nullptr);
  1450. return block;
  1451. }
  1452. FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::object &node) {
  1453. if (loops_.empty()) {
  1454. // Report error if loop context not set for the 'continue' statement.
  1455. MS_LOG(EXCEPTION) << "Unexpected 'continue.";
  1456. }
  1457. // Jump to the header of the loop with iterator called.
  1458. Loop &loop = loops_.top();
  1459. block->Jump(loop.header, loop.iterator);
  1460. return block;
  1461. }
  1462. FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::object &node) {
  1463. // We just bypass 'pass' statement.
  1464. return block;
  1465. }
  1466. AnfNodePtr FindPhis(const std::unordered_map<ParameterPtr, AnfNodePtr> &removable_phis, const AnfNodePtr &node) {
  1467. const auto &inp = node->cast<ParameterPtr>();
  1468. const auto &iter = removable_phis.find(inp);
  1469. if (iter == removable_phis.end()) {
  1470. return node;
  1471. }
  1472. return FindPhis(removable_phis, iter->second);
  1473. }
  1474. void Parser::RemoveUnnecessaryPhis() {
  1475. // merge all removable phis to one map;
  1476. std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis;
  1477. std::vector<ParameterPtr> phis;
  1478. for (FunctionBlockPtr &block : func_block_list_) {
  1479. MS_EXCEPTION_IF_NULL(block);
  1480. removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end());
  1481. std::transform(block->removable_phis().begin(), block->removable_phis().end(), std::back_inserter(phis),
  1482. [](std::pair<ParameterPtr, AnfNodePtr> pair) { return pair.first; });
  1483. }
  1484. if (removable_phis.size() == 0) {
  1485. return;
  1486. }
  1487. auto fg_name = func_graph_->ToString();
  1488. auto mng = Manage(func_graph_, false);
  1489. // replace the nodes
  1490. // remove from inside to outside
  1491. for (int64_t idx = SizeToLong(phis.size() - 1); idx >= 0; idx--) {
  1492. auto phi = phis[LongToSize(idx)];
  1493. auto new_node = FindPhis(removable_phis, phi);
  1494. mng->Replace(phi, new_node);
  1495. }
  1496. // remove the parameter
  1497. for (FunctionBlockPtr &block : func_block_list_) {
  1498. MS_EXCEPTION_IF_NULL(block);
  1499. auto &local_removable_phis = block->removable_phis();
  1500. if (local_removable_phis.size() == 0) {
  1501. continue;
  1502. }
  1503. auto func_graph = block->func_graph();
  1504. auto &parameters = func_graph->parameters();
  1505. std::vector<AnfNodePtr> new_parameters(parameters.size());
  1506. auto it = std::copy_if(
  1507. parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](AnfNodePtr param) {
  1508. return local_removable_phis.find(param->cast<ParameterPtr>()) == local_removable_phis.end();
  1509. });
  1510. // shrink container to new size
  1511. new_parameters.resize(std::distance(new_parameters.begin(), it));
  1512. func_graph->set_parameters(new_parameters);
  1513. }
  1514. for (auto fg : mng->func_graphs()) {
  1515. fg->ClearAllManagerInfo();
  1516. }
  1517. }
  1518. // ParseAst class code
  1519. bool ParseAst::InitParseAstInfo(const std::string &python_mod_get_parse_method) {
  1520. // init the type
  1521. target_type_ = PARSE_TARGET_UNKNOW;
  1522. // call python parse, get the parser fn
  1523. module_ = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  1524. py::object parse_method = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_PARSE_METHOD);
  1525. // get the obj type
  1526. auto type = data_converter::GetObjType(obj_);
  1527. if (type == RESOLVE_TYPE_FUNCTION) {
  1528. target_type_ = PARSE_TARGET_FUNCTION;
  1529. function_ = obj_;
  1530. } else if (type == RESOLVE_TYPE_METHOD) {
  1531. // process the method ,need get the method's self obj
  1532. target_type_ = PARSE_TARGET_METHOD;
  1533. py::object method_object = python_adapter::GetPyObjAttr(obj_, PYTHON_GET_METHOD_SELF_CLASS);
  1534. if (py::isinstance<py::none>(method_object)) {
  1535. MS_LOG(ERROR) << "Get method's self object instance failed.";
  1536. return false;
  1537. }
  1538. target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
  1539. function_ = obj_;
  1540. obj_ = method_object;
  1541. } else if (type == RESOLVE_TYPE_CLASS_INSTANCE) {
  1542. // obj is class instance, get the method to parse.
  1543. function_ = python_adapter::CallPyModFn(module_, python_mod_get_parse_method, obj_, parse_method);
  1544. if (py::isinstance<py::none>(function_)) {
  1545. MS_LOG(ERROR) << "Get obj method function failed.";
  1546. return false;
  1547. }
  1548. target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
  1549. // check the fn is method
  1550. auto obj_type = data_converter::GetObjType(function_);
  1551. if (obj_type != RESOLVE_TYPE_METHOD) {
  1552. MS_LOG(WARNING) << "Parse method function is invalid.";
  1553. return false;
  1554. }
  1555. } else {
  1556. MS_LOG(WARNING) << "Parse obj is invalid, only can parse function and obj, type = " << type;
  1557. return false;
  1558. }
  1559. // call python parse get ast tree
  1560. parser_ = python_adapter::CallPyModFn(module_, PYTHON_MOD_PARSE_OBJECT_FUNCTION, function_, parse_method);
  1561. ast_tree_ = python_adapter::CallPyObjMethod(parser_, "parse");
  1562. // get fn name and module
  1563. function_module_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "function_module"));
  1564. function_name_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "function_name"));
  1565. function_filename_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "filename"));
  1566. function_line_offset_ = py::cast<int64_t>(python_adapter::GetPyObjAttr(parser_, "line_offset"));
  1567. return true;
  1568. }
  1569. // Get ast tree node : is the tree bode list[0]
  1570. py::object ParseAst::GetAstNode() {
  1571. py::list tree_body = python_adapter::GetPyObjAttr(ast_tree_, "body");
  1572. py::object ast_node = tree_body[0];
  1573. return ast_node;
  1574. }
  1575. py::list ParseAst::GetArgs(const py::object &func_node) {
  1576. py::list ret = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_ARGS, func_node);
  1577. return ret;
  1578. }
  1579. py::list ParseAst::GetArgsDefaultValues(const py::object &func_node) {
  1580. py::list ret = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES, func_node);
  1581. return ret;
  1582. }
  1583. AstNodeTypePtr ParseAst::GetNodeType(const py::object &node) {
  1584. py::list list_value = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_NODE_TYPE, node);
  1585. if (list_value.size() < 2) {
  1586. MS_LOG(ERROR) << "The node of python method must has 2 values.";
  1587. return nullptr;
  1588. }
  1589. auto node_name = py::cast<std::string>(list_value[0]);
  1590. auto type = AstMainType(py::cast<int32_t>(list_value[1]));
  1591. return std::make_shared<AstNodeType>(node, node_name, type);
  1592. }
  1593. AstSubType ParseAst::GetOpType(const py::object &node) {
  1594. auto op_type = AstSubType(python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_AST_TYPE, node).cast<int32_t>());
  1595. return op_type;
  1596. }
  1597. bool ParseAst::IsClassMember(const py::object &node) {
  1598. py::object ret = CallParseModFunction(PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER, node);
  1599. if (!py::isinstance<py::bool_>(ret)) {
  1600. MS_LOG(ERROR) << "The result of mod function parse, should be bool type.";
  1601. return false;
  1602. }
  1603. return ret.cast<bool>();
  1604. }
  1605. bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
  1606. if (func_graph == nullptr) {
  1607. MS_LOG(ERROR) << "FuncGraph is null";
  1608. return false;
  1609. }
  1610. if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) {
  1611. MS_LOG(DEBUG) << "No flags";
  1612. return true;
  1613. }
  1614. py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG);
  1615. for (auto &item : flags) {
  1616. if (!py::isinstance<py::str>(item.first)) {
  1617. MS_LOG(ERROR) << "Type error in flags dict convert";
  1618. return false;
  1619. }
  1620. auto name = py::cast<std::string>(item.first);
  1621. if (py::isinstance<py::bool_>(item.second)) {
  1622. auto value = py::cast<bool>(item.second);
  1623. MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
  1624. func_graph->set_flag(name, value);
  1625. } else if (py::isinstance<py::str>(item.second)) {
  1626. auto value = py::cast<std::string>(item.second);
  1627. MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
  1628. func_graph->set_attr(name, MakeValue(value));
  1629. } else {
  1630. MS_LOG(ERROR) << "Type error in flags/attrs dict convert";
  1631. return false;
  1632. }
  1633. }
  1634. return true;
  1635. }
  1636. // Generate and copy a ValueNode, or a CNode with its child nodes
  1637. static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr func_graph, const AnfNodePtr &param_node) {
  1638. MS_EXCEPTION_IF_NULL(param_node);
  1639. if (param_node->isa<ValueNode>()) {
  1640. return std::make_shared<ValueNode>(param_node->cast<ValueNodePtr>()->value());
  1641. }
  1642. // Parameter default value is CNode.
  1643. std::size_t index = 0;
  1644. std::vector<AnfNodePtr> old_cnodes;
  1645. old_cnodes.emplace_back(param_node);
  1646. auto res = func_graph->NewCNodeInOrder({});
  1647. std::vector<CNodePtr> new_cnodes;
  1648. new_cnodes.emplace_back(res);
  1649. while (index < old_cnodes.size()) {
  1650. auto current = old_cnodes[index];
  1651. auto current_new_cnode = new_cnodes[index];
  1652. index++;
  1653. MS_EXCEPTION_IF_NULL(current);
  1654. if (current->isa<CNode>()) {
  1655. auto &inputs = current->cast<CNodePtr>()->inputs();
  1656. for (auto it = inputs.begin(); it != inputs.end(); it++) {
  1657. AnfNodePtr input = *it;
  1658. if (input != nullptr && input->isa<CNode>()) {
  1659. old_cnodes.emplace_back(input);
  1660. auto new_cnode = func_graph->NewCNodeInOrder({});
  1661. new_cnodes.emplace_back(new_cnode);
  1662. current_new_cnode->add_input(new_cnode);
  1663. } else if (input->isa<ValueNode>()) {
  1664. current_new_cnode->add_input(std::make_shared<ValueNode>(input->cast<ValueNodePtr>()->value()));
  1665. } else {
  1666. MS_LOG(EXCEPTION) << "Wrong type item in default parameters: " << input->ToString();
  1667. }
  1668. }
  1669. }
  1670. }
  1671. return res;
  1672. }
  1673. FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
  1674. auto current_graph = dyn_cast<FuncGraph>(cell_ptr);
  1675. if (current_graph == nullptr) {
  1676. MS_LOG(EXCEPTION) << "Current graph cast failed from " << cell_ptr->ToString();
  1677. }
  1678. auto func_graph = std::make_shared<FuncGraph>();
  1679. func_graph->debug_info()->set_name(current_graph->debug_info()->name() + "_wrapper");
  1680. // Copy all parameters information
  1681. for (auto &para : current_graph->parameters()) {
  1682. auto param = func_graph->add_parameter();
  1683. auto orig_param = para->cast<ParameterPtr>();
  1684. auto name = orig_param->name();
  1685. param->set_name(name);
  1686. param->debug_info()->set_name(name);
  1687. }
  1688. func_graph->set_has_vararg(current_graph->has_vararg());
  1689. func_graph->set_has_kwarg(current_graph->has_kwarg());
  1690. func_graph->set_kwonlyargs_count(current_graph->kwonlyargs_count());
  1691. // Copy all default values
  1692. for (auto &d : current_graph->parameter_default_value()) {
  1693. func_graph->set_param_default_value(d.first, CopyNodesFromParamDefaultValue(func_graph, d.second));
  1694. }
  1695. // cell_obj
  1696. MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell));
  1697. parse::UpdateFuncGraphFlags(cell, func_graph);
  1698. // top graph's construct flag
  1699. if (py::hasattr(cell, "construct")) {
  1700. parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph);
  1701. }
  1702. auto unpacking = func_graph->has_vararg() || func_graph->has_kwarg();
  1703. if (!unpacking) {
  1704. std::vector<AnfNodePtr> inputs;
  1705. inputs.emplace_back(NewValueNode(cell_ptr));
  1706. auto &params = func_graph->parameters();
  1707. (void)std::transform(params.begin(), params.end(), std::back_inserter(inputs),
  1708. [](AnfNodePtr node) -> AnfNodePtr { return node; });
  1709. func_graph->set_output(func_graph->NewCNodeInOrder(inputs));
  1710. } else {
  1711. // ret = cell_obj(*arg, *kwargs)
  1712. auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), func_graph->parameters());
  1713. // return ret
  1714. func_graph->set_output(call_fn);
  1715. }
  1716. return func_graph;
  1717. }
  1718. } // namespace parse
  1719. } // namespace mindspore