You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

parse.cc 78 kB

5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/jit/parse/parse.h"
  19. #include <utility>
  20. #include <string>
  21. #include <memory>
  22. #include <unordered_map>
  23. #include <sstream>
  24. #include <algorithm>
  25. #include "pipeline/jit/parse/resolve.h"
  26. #include "frontend/operator/ops.h"
  27. #include "pipeline/jit/parse/data_converter.h"
  28. #include "frontend/operator/composite/composite.h"
  29. #include "utils/ms_context.h"
  30. #include "debug/trace.h"
  31. namespace mindspore {
  32. namespace parse {
  33. FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mod_get_parse_method) {
  34. (void)python_adapter::set_python_scoped();
  35. if (!obj || py::isinstance<py::none>(obj)) {
  36. MS_LOG(ERROR) << "Parse the python code failed, obj is nullptr or none";
  37. return nullptr;
  38. }
  39. auto ast = std::make_shared<ParseAst>(obj);
  40. bool success = ast->InitParseAstInfo(python_mod_get_parse_method);
  41. if (!success) {
  42. MS_LOG(ERROR) << "Parse code to ast tree failed.";
  43. return nullptr;
  44. }
  45. auto parser = std::make_shared<Parser>(ast);
  46. FuncGraphPtr func_graph = parser->ParseFuncGraph();
  47. if (func_graph == nullptr) {
  48. MS_LOG(ERROR) << "Parse python code failed, errcode = " << parser->errcode();
  49. return nullptr;
  50. }
  51. return func_graph;
  52. }
  53. TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) {
  54. if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
  55. return kFloat32;
  56. } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
  57. return kFloat16;
  58. } else {
  59. return nullptr;
  60. }
  61. }
  62. // If any mixed precision flag add a cast node after the parameter node.
  63. AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
  64. TypePtr dst_type;
  65. if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
  66. dst_type = kFloat32;
  67. } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
  68. dst_type = kFloat16;
  69. } else {
  70. return param;
  71. }
  72. auto cast_helper = prim::kPrimMixedPrecisionCast;
  73. auto cast = func_graph->NewCNodeAfter(param, {NewValueNode(cast_helper), NewValueNode(dst_type), param});
  74. return cast;
  75. }
  76. FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();
  77. Parser::Parser(const std::shared_ptr<ParseAst> &ast) : ast_(ast) {
  78. errcode_ = PARSE_SUCCESS;
  79. BuildMethodMap();
  80. }
  81. void Parser::BuildMethodMap() {
  82. stmt_method_map_["Return"] = &Parser::ParseReturn;
  83. stmt_method_map_["Expr"] = &Parser::ParseExpr;
  84. stmt_method_map_["If"] = &Parser::ParseIf;
  85. stmt_method_map_["Assign"] = &Parser::ParseAssign;
  86. stmt_method_map_["While"] = &Parser::ParseWhile;
  87. stmt_method_map_["For"] = &Parser::ParseFor;
  88. stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef;
  89. stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign;
  90. stmt_method_map_["Global"] = &Parser::ParseGlobal;
  91. stmt_method_map_["Break"] = &Parser::ParseBreak;
  92. stmt_method_map_["Continue"] = &Parser::ParseContinue;
  93. stmt_method_map_["Pass"] = &Parser::ParsePass;
  94. expr_method_map_["NoneType"] = &Parser::ParseNone;
  95. expr_method_map_["BinOp"] = &Parser::ParseBinOp;
  96. expr_method_map_["Name"] = &Parser::ParseName;
  97. expr_method_map_["Num"] = &Parser::ParseNum;
  98. expr_method_map_["Str"] = &Parser::ParseStr;
  99. expr_method_map_["Constant"] = &Parser::ParseConstant;
  100. expr_method_map_["NameConstant"] = &Parser::ParseNameConstant;
  101. expr_method_map_["Call"] = &Parser::ParseCall;
  102. expr_method_map_["IfExp"] = &Parser::ParseIfExp;
  103. expr_method_map_["Attribute"] = &Parser::ParseAttribute;
  104. expr_method_map_["Compare"] = &Parser::ParseCompare;
  105. expr_method_map_["BoolOp"] = &Parser::ParseBoolOp;
  106. expr_method_map_["Lambda"] = &Parser::ParseLambda;
  107. expr_method_map_["Tuple"] = &Parser::ParseTuple;
  108. expr_method_map_["List"] = &Parser::ParseList;
  109. expr_method_map_["Subscript"] = &Parser::ParseSubscript;
  110. expr_method_map_["Slice"] = &Parser::ParseSlice;
  111. expr_method_map_["ExtSlice"] = &Parser::ParseExtSlice;
  112. expr_method_map_["Index"] = &Parser::ParseIndex;
  113. expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp;
  114. expr_method_map_["Dict"] = &Parser::ParseDict;
  115. expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis;
  116. }
  117. void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); }
  118. void Parser::InitParserEnvironment(const py::object &obj) {
  119. Parser::top_func_graph_ = FuncGraphWeakPtr();
  120. ScopeManager::GetInstance().ClearScope();
  121. (void)python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GENERATE_SCOPE, obj);
  122. }
  123. void Parser::CleanParserResource() {
  124. Parser::top_func_graph_ = FuncGraphWeakPtr();
  125. ScopeManager::GetInstance().ClearScope();
  126. }
  127. void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &ast) {
  128. // Check whether the functions referred by this function and itself are missing 'return' statement
  129. auto mng = Manage(fn, false);
  130. for (const auto &func_graph : mng->func_graphs()) {
  131. if (func_graph->get_return() != nullptr) {
  132. continue;
  133. }
  134. py::object node = ast->GetAstNode();
  135. py::list ret = ast->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  136. py::str desc =
  137. python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]);
  138. MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
  139. }
  140. }
  141. FuncGraphPtr Parser::ParseFuncGraph() {
  142. // Get ast FunctionDef node
  143. py::object node = ast_->GetAstNode();
  144. FunctionBlockPtr pFnBlock = ParseFunction(node);
  145. if (errcode() != PARSE_SUCCESS) {
  146. MS_LOG(ERROR) << "Parse function error, code is " << errcode();
  147. return nullptr;
  148. }
  149. RemoveUnnecessaryPhis();
  150. MS_EXCEPTION_IF_NULL(pFnBlock);
  151. CheckFuncReturn(pFnBlock->func_graph(), ast_);
  152. return pFnBlock->func_graph();
  153. }
  154. void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
  155. py::object func_args = python_adapter::GetPyObjAttr(fn_node, "args");
  156. py::object var_arg_node = python_adapter::GetPyObjAttr(func_args, "vararg");
  157. block->func_graph()->set_has_vararg(!py::isinstance<py::none>(var_arg_node));
  158. py::object kw_arg_node = python_adapter::GetPyObjAttr(func_args, "kwarg");
  159. block->func_graph()->set_has_kwarg(!py::isinstance<py::none>(kw_arg_node));
  160. py::list kwonly_args = python_adapter::GetPyObjAttr(func_args, "kwonlyargs");
  161. block->func_graph()->set_kwonlyargs_count(SizeToLong(kwonly_args.size()));
  162. MS_EXCEPTION_IF_NULL(ast_);
  163. py::list args = ast_->GetArgs(fn_node);
  164. for (std::size_t i = 0; i < args.size(); i++) {
  165. std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
  166. if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  167. if (arg_name == "self") {
  168. continue;
  169. }
  170. }
  171. TraceGuard guard(GetLocation(args[i]));
  172. auto para_node = std::make_shared<Parameter>(block->func_graph());
  173. MS_EXCEPTION_IF_NULL(para_node);
  174. para_node->set_name(arg_name);
  175. para_node->debug_info()->set_name(arg_name);
  176. block->func_graph()->add_parameter(para_node);
  177. AnfNodePtr para_after_cast = GetMixedPrecisionCastHelp(block->func_graph(), para_node);
  178. block->WriteVariable(arg_name, para_after_cast);
  179. MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name;
  180. }
  181. }
  182. void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
  183. py::list defaults = ast_->GetArgsDefaultValues(fn_node);
  184. py::list args = ast_->GetArgs(fn_node);
  185. std::vector<std::string> namelist_for_default_value;
  186. std::vector<AnfNodePtr> default_values;
  187. for (std::size_t i = 0; i < args.size(); i++) {
  188. std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
  189. if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  190. if (arg_name == "self") {
  191. continue;
  192. }
  193. }
  194. namelist_for_default_value.push_back(arg_name);
  195. if (py::isinstance<py::none>(defaults[i])) {
  196. default_values.push_back(NewValueNode(kNull));
  197. } else {
  198. default_values.push_back(ParseExprNode(block, defaults[i]));
  199. }
  200. }
  201. block->func_graph()->SetDefaultValues(namelist_for_default_value, default_values);
  202. }
  203. ScopePtr Parser::GetScopeForParseFunction() {
  204. ScopePtr scope = ScopeManager::GetInstance().GetCurrentScope();
  205. if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  206. py::object scope_str = python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GET_SCOPE_NAME, ast_->obj());
  207. if (!py::isinstance<py::none>(scope_str)) {
  208. auto scope_name = py::cast<std::string>(scope_str);
  209. scope = std::make_shared<Scope>(scope_name);
  210. }
  211. }
  212. return scope;
  213. }
  214. FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlockPtr &block) {
  215. ScopePtr scope = GetScopeForParseFunction();
  216. // The node created in the parsefunction context, will inherit the scope created using scope_guard
  217. ScopeGuard scope_guard(scope);
  218. TraceGuard trace_guard(data_converter::GetObjKey(ast()->obj())[0], GetLocation(node));
  219. FunctionBlockPtr pFunBlock = MakeFunctionBlock(*this);
  220. if (block != nullptr) {
  221. pFunBlock->AddPrevBlock(block);
  222. } else {
  223. func_graph_ = pFunBlock->func_graph();
  224. }
  225. pFunBlock->Mature();
  226. auto current_fg = pFunBlock->func_graph();
  227. auto function_name = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "name"));
  228. MS_LOG(DEBUG) << "The function name is " << function_name;
  229. current_fg->debug_info()->set_name(function_name);
  230. MS_EXCEPTION_IF_NULL(ast_);
  231. py::list deco_list = node.attr("decorator_list");
  232. if (!deco_list.empty()) {
  233. current_fg->debug_info()->set_deco_location(GetLocation(deco_list));
  234. }
  235. bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg);
  236. if (!ast_->obj().is(ast_->function())) {
  237. set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg);
  238. }
  239. if (!set_flag) {
  240. MS_LOG(ERROR) << "Set flags failed";
  241. return nullptr;
  242. }
  243. GenerateArgsNodeForFunction(pFunBlock, node);
  244. // When parsing the top graph of construct, save the top graph
  245. if (GetTopFuncGraph() == nullptr) {
  246. UpdateTopFuncGraph(pFunBlock->func_graph());
  247. }
  248. // Save the function node to block
  249. pFunBlock->WriteVariable(function_name, NewValueNode(current_fg));
  250. py::object funcObj = python_adapter::GetPyObjAttr(node, "body");
  251. (void)ParseStatements(pFunBlock, funcObj);
  252. // Add unused variables as isolate nodes.
  253. for (auto &func_block : func_block_list_) {
  254. if (func_block->func_graph()->get_return() != nullptr) {
  255. // Find unused variables.
  256. func_block->FindIsolatedNodes();
  257. // Attach all isolated nodes.
  258. func_block->AttachIsolatedNodesBeforeReturn();
  259. }
  260. }
  261. if (current_fg->get_return() == nullptr) {
  262. py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  263. py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, node, ret[0], ret[1]);
  264. MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
  265. }
  266. GenerateArgsDefaultValueForFunction(pFunBlock, node);
  267. return pFunBlock;
  268. }
  269. FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr block, const py::object &nodes) {
  270. auto node_list = py::cast<py::list>(nodes);
  271. size_t count = py::len(node_list);
  272. MS_LOG(DEBUG) << "The nodes count is " << count;
  273. for (size_t i = 0; i < count; ++i) {
  274. auto node = node_list[i];
  275. block = ParseStatement(block, node);
  276. // Insert appropriate depended items for the function block if it has a return node
  277. if (block->func_graph()->get_return() != nullptr) {
  278. // Skip statements after 'return' (or 'break', 'continue').
  279. break;
  280. }
  281. }
  282. return block;
  283. }
  284. FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py::object &node) {
  285. TraceGuard trace_guard(GetLocation(node));
  286. auto node_type = ast_->GetNodeType(node);
  287. // Check the node type
  288. AstMainType nodeType = node_type->main_type();
  289. if (nodeType != AST_MAIN_TYPE_STMT) {
  290. MS_LOG(INFO) << "Node type is error : " << nodeType;
  291. return block;
  292. }
  293. // Call the process function
  294. std::string node_name = node_type->node_name();
  295. MS_LOG(DEBUG) << "Ast node is " << node_name;
  296. if (stmt_method_map_.count(node_name)) {
  297. auto stmt_block = (this->*stmt_method_map_[node_name])(block, node);
  298. TraceManager::ClearParseOrResolveDebugInfo();
  299. return stmt_block;
  300. } else {
  301. errcode_ = PARSE_NODE_METHOD_UNSUPPORTED;
  302. MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "'.";
  303. }
  304. }
  305. AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object &node) {
  306. MS_LOG(DEBUG) << "Process ast expr";
  307. TraceGuard trace_guard(GetLocation(node));
  308. auto node_type = ast_->GetNodeType(node);
  309. // Check the node type
  310. AstMainType node_main_type = node_type->main_type();
  311. if (node_main_type != AST_MAIN_TYPE_EXPR) {
  312. MS_LOG(ERROR) << "Node type is error : " << node_main_type;
  313. errcode_ = PARSE_NODE_TYPE_NO_MATCH;
  314. return nullptr;
  315. }
  316. // Call the process function
  317. std::string node_name = node_type->node_name();
  318. MS_LOG(DEBUG) << "Ast node is " << node_name;
  319. if (expr_method_map_.count(node_name)) {
  320. auto expr_node = (this->*expr_method_map_[node_name])(block, node);
  321. TraceManager::ClearParseOrResolveDebugInfo();
  322. return expr_node;
  323. } else {
  324. errcode_ = PARSE_NODE_METHOD_UNSUPPORTED;
  325. MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "'.";
  326. }
  327. }
  328. // Process the expr statement and expand it
  329. FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::object &node) {
  330. MS_LOG(DEBUG) << "Process ast Expr";
  331. // Expr only have value, no target
  332. py::tuple expand_info = ast_->CallParseModFunction(PYTHON_PARSE_EXPAND_EXPR_STATEMENT, node);
  333. // Refer python function expand_expr_statement, expand_info is one of the following:
  334. // True, expr.value, x
  335. // True, expr.value
  336. // False, None, None
  337. //
  338. // Check the expand info result
  339. auto is_expand = py::cast<bool>(expand_info[0]);
  340. if (is_expand) {
  341. // Process the expr statement
  342. py::object value_object = expand_info[1];
  343. // Make a Expr CNode.
  344. AnfNodePtr call_node = ParseExprNode(block, value_object);
  345. if (py::len(expand_info) == 2) {
  346. // Expression that not assigned to any variable.
  347. // This is usually a call with side effects.
  348. // e.g.: print(x)
  349. // We save it as an isolated node.
  350. auto &no_return_node = call_node;
  351. MS_LOG(INFO) << "Isolated node found(NoReturn), no_return_node: " << no_return_node->DebugString(2)
  352. << ", block: " << block << "/"
  353. << (block->func_graph() ? block->func_graph()->ToString() : "FG(Null)")
  354. << ", Line: " << trace::GetDebugInfo(no_return_node->debug_info(), "", kSourceLineTipDiscard);
  355. block->AddIsolatedNode(no_return_node);
  356. } else {
  357. // Expand the assign statement,
  358. // e.g.: x.append(y) -> x = x.append(y)
  359. py::object target_node = expand_info[2];
  360. WriteAssignVars(block, target_node, call_node);
  361. }
  362. }
  363. return block;
  364. }
  365. LocationPtr Parser::GetLocation(const py::object &node) const {
  366. MS_EXCEPTION_IF_NULL(ast_);
  367. py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  368. if (ret.size() < 5) {
  369. MS_LOG(EXCEPTION) << "List size should not be less than 5.";
  370. }
  371. // Refer to Location::Location() for each member of ret: line, column, line_end, column_end.
  372. auto location = std::make_shared<Location>(ret[0].cast<std::string>(), ret[1].cast<int64_t>(), ret[2].cast<int64_t>(),
  373. ret[3].cast<int64_t>(), ret[4].cast<int64_t>());
  374. return location;
  375. }
  376. void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
  377. const FunctionBlockPtr &false_block) {
  378. true_block->AddPrevBlock(pre_block);
  379. true_block->Mature();
  380. false_block->AddPrevBlock(pre_block);
  381. false_block->Mature();
  382. }
  383. FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::object &node) {
  384. MS_LOG(DEBUG) << "Process ast return";
  385. MS_EXCEPTION_IF_NULL(block);
  386. // Create return valuenode
  387. AnfNodePtr pReturnValueNode = NewValueNode(prim::kPrimReturn);
  388. // Parse the return Statements value
  389. py::object value = python_adapter::GetPyObjAttr(node, "value");
  390. AnfNodePtr pReturnStatementNode = ParseExprNode(block, value);
  391. // Create the cnode
  392. CNodePtr pReturnCNode = block->func_graph()->NewCNodeInOrder({pReturnValueNode, pReturnStatementNode});
  393. block->func_graph()->set_return(pReturnCNode);
  394. return block;
  395. }
  396. // Process binary operators,eg: `a + b`, `a | b`, etc.
  397. AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &node) {
  398. MS_LOG(DEBUG) << "Process ast BinOP";
  399. py::object left = python_adapter::GetPyObjAttr(node, "left");
  400. py::object right = python_adapter::GetPyObjAttr(node, "right");
  401. py::object op = python_adapter::GetPyObjAttr(node, "op");
  402. // Create left and right ANF node
  403. AnfNodePtr left_node = ParseExprNode(block, left);
  404. if (left_node == nullptr) {
  405. MS_LOG(WARNING) << "DoBinOp process left node failed: " << errcode();
  406. return nullptr;
  407. }
  408. AnfNodePtr right_node = ParseExprNode(block, right);
  409. if (right_node == nullptr) {
  410. MS_LOG(WARNING) << "DoBinOp process right node failed:" << errcode();
  411. return nullptr;
  412. }
  413. // Resolve the op
  414. AnfNodePtr op_node = block->MakeResolveAstOp(op);
  415. // Create apply node
  416. return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
  417. }
  418. AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &node) {
  419. MS_LOG(DEBUG) << "Process ast Name";
  420. auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id"));
  421. MS_LOG(DEBUG) << "The Name id is " << name_id;
  422. if (block->IsGlobalVar(name_id)) {
  423. return block->MakeResolveSymbol(name_id);
  424. }
  425. return block->ReadVariable(name_id);
  426. }
  427. AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) {
  428. MS_LOG(DEBUG) << "Process ast NoneType";
  429. return NewValueNode(kNone);
  430. }
  431. AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) {
  432. MS_LOG(DEBUG) << "Process ast Ellipsis";
  433. return NewValueNode(kEllipsis);
  434. }
  435. AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
  436. MS_LOG(DEBUG) << "Process ast Num";
  437. py::object obj = python_adapter::GetPyObjAttr(node, "n");
  438. if (py::isinstance<py::int_>(obj)) {
  439. MS_LOG(INFO) << "The Num is int64_t:" << (std::string)py::str(obj);
  440. auto data = py::cast<int64_t>(obj);
  441. return NewValueNode(data);
  442. } else if (py::isinstance<py::float_>(obj)) {
  443. MS_LOG(INFO) << "The Num is float:" << (std::string)py::str(obj);
  444. auto data = py::cast<float>(obj);
  445. return NewValueNode(data);
  446. } else {
  447. // no else actually
  448. MS_LOG(ERROR) << "Unsupported Num type : " << (std::string)py::str(obj);
  449. errcode_ = PARSE_NODE_TYPE_UNKNOWN;
  450. return nullptr;
  451. }
  452. }
  453. AnfNodePtr Parser::ParseStr(const FunctionBlockPtr &, const py::object &node) {
  454. MS_LOG(DEBUG) << "Process ast Str";
  455. auto str_s = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "s"));
  456. return NewValueNode(str_s);
  457. }
  458. AnfNodePtr Parser::ParseConstant(const FunctionBlockPtr &, const py::object &node) {
  459. MS_LOG(DEBUG) << "Process ast Constant";
  460. py::object obj = python_adapter::GetPyObjAttr(node, "value");
  461. if (py::isinstance<py::bool_>(obj)) {
  462. MS_LOG(INFO) << "The Constant is bool:" << (std::string)py::str(obj);
  463. auto data = py::cast<bool>(obj);
  464. return NewValueNode(data);
  465. } else if (py::isinstance<py::int_>(obj)) {
  466. MS_LOG(INFO) << "The Constant is int64_t:" << (std::string)py::str(obj);
  467. auto data = py::cast<int64_t>(obj);
  468. return NewValueNode(data);
  469. } else if (py::isinstance<py::float_>(obj)) {
  470. MS_LOG(INFO) << "The Constant is float:" << (std::string)py::str(obj);
  471. auto data = py::cast<float>(obj);
  472. return NewValueNode(data);
  473. } else if (py::isinstance<py::str>(obj)) {
  474. MS_LOG(INFO) << "The Constant is string:" << (std::string)py::str(obj);
  475. auto data = py::cast<std::string>(obj);
  476. return NewValueNode(data);
  477. } else if (py::isinstance<py::none>(obj)) {
  478. MS_LOG(INFO) << "The Constant is none:" << (std::string)py::str(obj);
  479. return NewValueNode(kNone);
  480. } else if (py::isinstance<py::ellipsis>(obj)) {
  481. MS_LOG(INFO) << "The Constance is ellipsis:" << (std::string)py::str(obj);
  482. return NewValueNode(kEllipsis);
  483. } else {
  484. // no else actually
  485. MS_EXCEPTION(TypeError) << "Unsupported Constant type : " << (std::string)py::str(obj);
  486. }
  487. }
  488. AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object &node) {
  489. MS_LOG(DEBUG) << "Process ast NameConstant";
  490. py::object obj = python_adapter::GetPyObjAttr(node, "value");
  491. if (py::isinstance<py::bool_>(obj)) {
  492. MS_LOG(INFO) << "The NameConstant is bool:" << (std::string)py::str(obj);
  493. auto data = py::cast<bool>(obj);
  494. return NewValueNode(data);
  495. } else if (py::isinstance<py::none>(obj)) {
  496. MS_LOG(INFO) << "The NameConstant is none:" << (std::string)py::str(obj);
  497. return NewValueNode(kNone);
  498. } else {
  499. // no else actually
  500. MS_LOG(ERROR) << "Unsupported NameConstant type: " << (std::string)py::str(obj);
  501. errcode_ = PARSE_NODE_TYPE_UNKNOWN;
  502. return nullptr;
  503. }
  504. }
  505. AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes) {
  506. AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
  507. std::vector<AnfNodePtr> make_tuple_nodes;
  508. make_tuple_nodes.push_back(make_tuple_op);
  509. (void)std::transform(element_nodes.begin(), element_nodes.end(), std::back_inserter(make_tuple_nodes),
  510. [](AnfNodePtr arg) -> AnfNodePtr { return arg; });
  511. return block->func_graph()->NewCNodeInOrder(make_tuple_nodes);
  512. }
  513. AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &args) {
  514. py::object father_class;
  515. if (args.empty()) {
  516. father_class = py::none();
  517. } else if (args.size() == 2) {
  518. father_class = args[0];
  519. auto arg_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, args[1])));
  520. if (arg_type != AST_SUB_TYPE_NAME || py::cast<std::string>(python_adapter::GetPyObjAttr(args[1], "id")) != "self") {
  521. MS_EXCEPTION(ArgumentError) << "When call 'super', the second arg should be 'self'.";
  522. }
  523. } else {
  524. MS_EXCEPTION(ArgumentError) << "When call 'super', the args number should be 0 or 2, but got" << args.size() << ".";
  525. }
  526. py::object target_class_instance = ast()->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast()->obj());
  527. py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance);
  528. NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
  529. SymbolPtr symbol = std::make_shared<Symbol>("namespace");
  530. return block->MakeResolve(name_space, symbol);
  531. }
  532. // Process function call, eg : f1(x, y) ...
  533. AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) {
  534. MS_LOG(DEBUG) << "Process ast Call";
  535. // Process function call
  536. py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func");
  537. py::list args = python_adapter::GetPyObjAttr(node, "args");
  538. auto arg_type =
  539. AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, function_ast_node)));
  540. if (arg_type == AST_SUB_TYPE_NAME) {
  541. auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(function_ast_node, "id"));
  542. if (name_id == "super") {
  543. return ParseSuper(block, args);
  544. }
  545. }
  546. AnfNodePtr call_function_anf_node = ParseExprNode(block, function_ast_node);
  547. // Function call arguments should be passed in as groups and unpacked later using unpack call
  548. std::vector<AnfNodePtr> packed_arguments;
  549. std::vector<AnfNodePtr> group_arguments;
  550. bool need_unpack_args = ParseArgsInCall(block, args, &packed_arguments, &group_arguments);
  551. bool need_unpack_keywords = ParseKeywordsInCall(block, node, &packed_arguments);
  552. // If there is stared or keyword argument, unpack may be needed
  553. bool need_unpack = need_unpack_args || need_unpack_keywords;
  554. return GenerateAnfNodeForCall(block, call_function_anf_node, packed_arguments, group_arguments, need_unpack);
  555. }
  556. CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_anf_node,
  557. const std::vector<AnfNodePtr> &packed_arguments) {
  558. std::vector<AnfNodePtr> unpack_call_nodes;
  559. auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL));
  560. unpack_call_nodes.push_back(unpack_call_op);
  561. unpack_call_nodes.push_back(call_function_anf_node);
  562. (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes),
  563. [](AnfNodePtr node) -> AnfNodePtr { return node; });
  564. CNodePtr unpack_call = func_graph->NewCNodeInOrder(unpack_call_nodes);
  565. return unpack_call;
  566. }
  567. AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node,
  568. const std::vector<AnfNodePtr> &packed_arguments,
  569. const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const {
  570. // If there is keyword arguments or starred, using an unpack_call op to unpack the argument
  571. if (need_unpack) {
  572. return MakeUnpackCall(block->func_graph(), call_function_anf_node, packed_arguments);
  573. }
  574. // else there is no keyword arguments and starred, parsed as normal arguments without unpack
  575. std::vector<AnfNodePtr> func_call_nodes;
  576. func_call_nodes.push_back(call_function_anf_node);
  577. (void)std::transform(group_arguments.begin(), group_arguments.end(), std::back_inserter(func_call_nodes),
  578. [](AnfNodePtr node) -> AnfNodePtr { return node; });
  579. CNodePtr call_anf_node = block->func_graph()->NewCNodeInOrder(func_call_nodes);
  580. return call_anf_node;
  581. }
  582. bool Parser::ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args,
  583. std::vector<AnfNodePtr> *packed_arguments, std::vector<AnfNodePtr> *group_arguments) {
  584. bool need_unpack = false;
  585. for (size_t i = 0; i < args.size(); i++) {
  586. auto arg_node = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, args[i])));
  587. if (arg_node == AST_SUB_TYPE_STARRED) {
  588. if (!group_arguments->empty()) {
  589. packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments));
  590. }
  591. packed_arguments->push_back(ParseExprNode(block, python_adapter::GetPyObjAttr(args[i], "value")));
  592. group_arguments->clear();
  593. need_unpack = true;
  594. } else {
  595. group_arguments->push_back(ParseExprNode(block, args[i]));
  596. }
  597. }
  598. if (!group_arguments->empty()) {
  599. packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments));
  600. }
  601. return need_unpack;
  602. }
  603. bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node,
  604. std::vector<AnfNodePtr> *packed_arguments) {
  605. bool need_unpack = false;
  606. py::list keywords = python_adapter::GetPyObjAttr(node, "keywords");
  607. if (!keywords.empty()) {
  608. need_unpack = true;
  609. std::vector<AnfNodePtr> keys;
  610. std::vector<AnfNodePtr> values;
  611. for (size_t index = 0; index < keywords.size(); index++) {
  612. auto kw_key = python_adapter::GetPyObjAttr(keywords[index], "arg");
  613. auto kw_value = python_adapter::GetPyObjAttr(keywords[index], "value");
  614. if (py::isinstance<py::none>(kw_key)) {
  615. packed_arguments->push_back(ParseExprNode(block, kw_value));
  616. } else {
  617. auto kw_key_c = kw_key.cast<std::string>();
  618. keys.push_back(NewValueNode(kw_key_c));
  619. values.push_back(ParseExprNode(block, kw_value));
  620. }
  621. }
  622. auto keys_tuple = GenerateMakeTuple(block, keys);
  623. auto values_tuple = GenerateMakeTuple(block, values);
  624. auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
  625. std::vector<AnfNodePtr> make_dict_nodes;
  626. make_dict_nodes.push_back(make_dict_op);
  627. make_dict_nodes.push_back(keys_tuple);
  628. make_dict_nodes.push_back(values_tuple);
  629. packed_arguments->push_back(block->func_graph()->NewCNodeInOrder(make_dict_nodes));
  630. }
  631. return need_unpack;
  632. }
  633. // Process call attributes of class type define, eg: x.y()
  634. AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::object &node) {
  635. MS_LOG(DEBUG) << "Process ast Attribute";
  636. // Process class value,eg: self.xx
  637. if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  638. if (ast_->IsClassMember(node)) {
  639. std::string var_name = "self.";
  640. std::string attr_name = node.attr("attr").cast<std::string>();
  641. (void)var_name.append(attr_name);
  642. auto attr_obj = ast()->obj().attr(attr_name.c_str());
  643. if (py::hasattr(ast()->obj(), attr_name.c_str()) &&
  644. (py::hasattr(attr_obj, PYTHON_PRIMITIVE_FLAG) || py::isinstance<py::int_>(attr_obj) ||
  645. py::isinstance<py::float_>(attr_obj) || py::isinstance<py::bool_>(attr_obj) ||
  646. py::isinstance<py::str>(attr_obj) || data_converter::IsCellInstance(attr_obj))) {
  647. return block->MakeResolveSymbol(var_name);
  648. } else {
  649. return block->ReadVariable(var_name);
  650. }
  651. }
  652. }
  653. // Process the get attr
  654. // Use the Primitive replace the operation resolve node (getattr),
  655. // because the getattr will eventually be converted to Primitive node
  656. AnfNodePtr op_node = NewValueNode(prim::kPrimGetAttr);
  657. // Process the attr body
  658. py::object value_body = python_adapter::GetPyObjAttr(node, "value");
  659. AnfNodePtr value_node = ParseExprNode(block, value_body);
  660. if (value_node == nullptr) {
  661. MS_LOG(WARNING) << "Parse attribute failed";
  662. return nullptr;
  663. }
  664. // Process the node attr
  665. auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast<std::string>();
  666. MS_LOG(DEBUG) << "Attr = " << attr_str;
  667. AnfNodePtr attr_node = nullptr;
  668. {
  669. TraceGuard guard(GetLocation(python_adapter::GetPyObjAttr(node, "attr")));
  670. attr_node = NewValueNode(attr_str);
  671. }
  672. // Create the apply node
  673. return block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
  674. }
  675. // Process comparison expression : a == b. a > b etc.
  676. AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object &node) {
  677. MS_LOG(DEBUG) << "Process ast Compare";
  678. // For python comparison ,there may be if x>y>5 ,
  679. // Which there is two ops , but we only support one now
  680. py::list ops = python_adapter::GetPyObjAttr(node, "ops");
  681. if (ops.size() > MAX_COMPARISON_OPS_SUPPORTED) {
  682. MS_LOG(ERROR) << "MindSpore does not support comparison with operators more than one now, ops size =" << ops.size();
  683. return nullptr;
  684. }
  685. py::object left = python_adapter::GetPyObjAttr(node, "left");
  686. py::list comparators = python_adapter::GetPyObjAttr(node, "comparators");
  687. AnfNodePtr left_node = ParseExprNode(block, left);
  688. AnfNodePtr right_node = ParseExprNode(block, comparators[0]);
  689. MS_EXCEPTION_IF_NULL(block);
  690. AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]);
  691. return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
  692. }
  693. AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode) {
  694. // If there is only one bool op now
  695. if (value_list.size() == 1) {
  696. AnfNodePtr first_node = ParseExprNode(block, value_list[0]);
  697. return first_node;
  698. } else {
  699. py::object first = value_list[0];
  700. py::list rest;
  701. for (size_t i = 1; i < value_list.size(); i++) {
  702. rest.append(value_list[i]);
  703. }
  704. MS_EXCEPTION_IF_NULL(block);
  705. FunctionBlockPtr true_block = nullptr;
  706. FunctionBlockPtr false_block = nullptr;
  707. {
  708. TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
  709. true_block = MakeFunctionBlock(*this);
  710. }
  711. {
  712. TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
  713. false_block = MakeFunctionBlock(*this);
  714. }
  715. MakeConditionBlocks(block, true_block, false_block);
  716. FunctionBlockPtr b1, b2;
  717. // If it is and, we need to process the rest nodes;
  718. // If it is or, we continue to next
  719. if (mode == AST_SUB_TYPE_AND) {
  720. b1 = true_block;
  721. b2 = false_block;
  722. } else if (mode == AST_SUB_TYPE_OR) {
  723. b2 = true_block;
  724. b1 = false_block;
  725. } else {
  726. MS_LOG(ERROR) << "Not supported mode: " << mode;
  727. return nullptr;
  728. }
  729. AnfNodePtr test_node = ParseExprNode(block, first);
  730. AnfNodePtr rest_node = ProcessBoolOpValueList(b1, rest, mode);
  731. b1->func_graph()->set_output(rest_node);
  732. b2->func_graph()->set_output(test_node);
  733. auto cond_node = block->ForceToBoolNode(test_node);
  734. auto switch_app = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), cond_node,
  735. NewValueNode(true_block->func_graph()),
  736. NewValueNode(false_block->func_graph())});
  737. std::vector<AnfNodePtr> call_graph_nodes{switch_app};
  738. auto switch_app_call = block->func_graph()->NewCNodeInOrder(call_graph_nodes);
  739. return switch_app_call;
  740. }
  741. }
  742. // Process comparison expression : a and b. a or b .
  743. AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) {
  744. MS_LOG(DEBUG) << "Process ast BoolOp";
  745. py::object op_node = python_adapter::GetPyObjAttr(node, "op");
  746. AstSubType op_type = ast_->GetOpType(op_node);
  747. if (op_type == AST_SUB_TYPE_UNKNOWN) {
  748. MS_LOG(WARNING) << "ProcessBoolOp, got unknown op type";
  749. return nullptr;
  750. }
  751. py::list op_values = python_adapter::GetPyObjAttr(node, "values");
  752. return ProcessBoolOpValueList(block, op_values, op_type);
  753. }
  754. // Process a function def
  755. FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) {
  756. MS_LOG(DEBUG) << "Process ast FunctionDef";
  757. FunctionBlockPtr function_block = ParseFunction(node, block);
  758. MS_EXCEPTION_IF_NULL(function_block);
  759. // Get function name
  760. py::str name = python_adapter::GetPyObjAttr(node, "name");
  761. std::string function_name = name;
  762. ValueNodePtr valuenode_graph = NewValueNode(function_block->func_graph());
  763. block->WriteVariable(function_name, valuenode_graph);
  764. return block;
  765. }
  766. // Process a lambda expression . like lambda x,y: x + y
  767. AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) {
  768. MS_LOG(DEBUG) << "Process ast Lambda";
  769. FunctionBlockPtr func_block = MakeFunctionBlock(*this);
  770. func_block->AddPrevBlock(block);
  771. func_block->Mature();
  772. // Get lambda args
  773. py::list args = ast_->GetArgs(node);
  774. for (std::size_t i = 0; i < args.size(); i++) {
  775. std::string arg = py::cast<std::string>(args[i].attr("arg"));
  776. TraceGuard guard(GetLocation(args[i]));
  777. auto para_node = std::make_shared<Parameter>(func_block->func_graph());
  778. para_node->debug_info()->set_name(arg);
  779. func_block->func_graph()->add_parameter(para_node);
  780. func_block->WriteVariable(arg, para_node);
  781. MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg;
  782. }
  783. py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  784. AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
  785. func_block->func_graph()->set_output(lambda_body_node);
  786. ValueNodePtr const_graph = NewValueNode(func_block->func_graph());
  787. return const_graph;
  788. }
  789. // Process a tuple
  790. AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &node) {
  791. MS_LOG(DEBUG) << "Process ast Tuple";
  792. MS_EXCEPTION_IF_NULL(block);
  793. py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
  794. if (elts.size() == 0) {
  795. auto empty_tuple = std::vector<ValuePtr>();
  796. return NewValueNode(std::make_shared<ValueTuple>(empty_tuple));
  797. }
  798. std::vector<AnfNodePtr> tuple_vec;
  799. AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
  800. tuple_vec.emplace_back(make_tuple_op);
  801. for (size_t i = 0; i < elts.size(); i++) {
  802. AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
  803. tuple_vec.emplace_back(node_ptr);
  804. }
  805. CNodePtr tuple_app = block->func_graph()->NewCNodeInOrder(tuple_vec);
  806. return tuple_app;
  807. }
  808. // Process a list
  809. AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) {
  810. MS_LOG(DEBUG) << "Process ast List";
  811. MS_EXCEPTION_IF_NULL(block);
  812. py::list elts = python_adapter::GetPyObjAttr(node, "elts");
  813. if (elts.size() == 0) {
  814. auto empty_list = std::vector<ValuePtr>();
  815. return NewValueNode(std::make_shared<ValueList>(empty_list));
  816. }
  817. std::vector<AnfNodePtr> list_vec;
  818. AnfNodePtr make_list_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST);
  819. list_vec.emplace_back(make_list_op);
  820. for (size_t i = 0; i < elts.size(); i++) {
  821. AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
  822. list_vec.emplace_back(node_ptr);
  823. }
  824. CNodePtr list_app = block->func_graph()->NewCNodeInOrder(list_vec);
  825. return list_app;
  826. }
  827. // Process a subscript, such as x[y] , node expressed as value[slice]
  828. AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::object &node) {
  829. MS_LOG(DEBUG) << "Process ast Subscript";
  830. MS_EXCEPTION_IF_NULL(block);
  831. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  832. py::object value_node = python_adapter::GetPyObjAttr(node, "value");
  833. py::object slice_node = python_adapter::GetPyObjAttr(node, "slice");
  834. AnfNodePtr value = ParseExprNode(block, value_node);
  835. AnfNodePtr slice = ParseExprNode(block, slice_node);
  836. return block->func_graph()->NewCNodeInOrder({op_getitem, value, slice});
  837. }
  838. // Process a slice, get the slice value
  839. AnfNodePtr Parser::ParseSlice(const FunctionBlockPtr &block, const py::object &node) {
  840. MS_LOG(DEBUG) << "Process ast Slice";
  841. MS_EXCEPTION_IF_NULL(block);
  842. AnfNodePtr op_makeslice = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKESLICE);
  843. py::object start = python_adapter::GetPyObjAttr(node, "lower");
  844. py::object stop = python_adapter::GetPyObjAttr(node, "upper");
  845. py::object step = python_adapter::GetPyObjAttr(node, "step");
  846. AnfNodePtr start_node = ParseExprNode(block, start);
  847. AnfNodePtr stop_node = ParseExprNode(block, stop);
  848. AnfNodePtr step_node = ParseExprNode(block, step);
  849. return block->func_graph()->NewCNodeInOrder({op_makeslice, start_node, stop_node, step_node});
  850. }
  851. // Process a extslice
  852. AnfNodePtr Parser::ParseExtSlice(const FunctionBlockPtr &block, const py::object &node) {
  853. MS_LOG(DEBUG) << "Process ast ExtSlice";
  854. MS_EXCEPTION_IF_NULL(block);
  855. AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
  856. py::tuple slice_tuple = python_adapter::GetPyObjAttr(node, "dims");
  857. std::vector<AnfNodePtr> node_vec;
  858. node_vec.emplace_back(make_tuple_op);
  859. for (size_t i = 0; i < slice_tuple.size(); i++) {
  860. AnfNodePtr node_ptr = ParseExprNode(block, slice_tuple[i]);
  861. node_vec.emplace_back(node_ptr);
  862. }
  863. CNodePtr tuple_conde = block->func_graph()->NewCNodeInOrder(node_vec);
  864. return tuple_conde;
  865. }
  866. // Process a index, get the index number
  867. AnfNodePtr Parser::ParseIndex(const FunctionBlockPtr &block, const py::object &node) {
  868. MS_LOG(DEBUG) << "Process ast Index";
  869. py::object value_node = python_adapter::GetPyObjAttr(node, "value");
  870. return ParseExprNode(block, value_node);
  871. }
  872. // Process a UnaryOp, +a, -b
  873. AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node) {
  874. MS_LOG(DEBUG) << "Process ast UnaryOp";
  875. py::object op = python_adapter::GetPyObjAttr(node, "op");
  876. MS_EXCEPTION_IF_NULL(block);
  877. // Resolve the op
  878. AnfNodePtr op_node = block->MakeResolveAstOp(op);
  879. py::object operand = python_adapter::GetPyObjAttr(node, "operand");
  880. AnfNodePtr operand_node = ParseExprNode(block, operand);
  881. return block->func_graph()->NewCNodeInOrder({op_node, operand_node});
  882. }
  883. // Process a dict ast node expression
  884. AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) {
  885. MS_LOG(DEBUG) << "Process ast Dict";
  886. py::list keys = node.attr("keys");
  887. py::list values = node.attr("values");
  888. std::vector<AnfNodePtr> key_nodes;
  889. std::vector<AnfNodePtr> value_nodes;
  890. for (size_t i = 0; i < keys.size(); i++) {
  891. key_nodes.push_back(ParseExprNode(block, keys[i]));
  892. value_nodes.push_back(ParseExprNode(block, values[i]));
  893. }
  894. auto keys_tuple = GenerateMakeTuple(block, key_nodes);
  895. auto values_tuple = GenerateMakeTuple(block, value_nodes);
  896. MS_EXCEPTION_IF_NULL(block);
  897. auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
  898. return block->func_graph()->NewCNodeInOrder({make_dict_op, keys_tuple, values_tuple});
  899. }
  900. // Process a augment assign such as a += b or mat[stride_slice] += b.
  901. FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py::object &node) {
  902. MS_LOG(DEBUG) << "Process ast AugAssign";
  903. MS_EXCEPTION_IF_NULL(block);
  904. MS_EXCEPTION_IF_NULL(ast_);
  905. py::object target_obj = python_adapter::GetPyObjAttr(node, "target");
  906. py::object op_obj = python_adapter::GetPyObjAttr(node, "op");
  907. py::object value_obj = python_adapter::GetPyObjAttr(node, "value");
  908. AnfNodePtr target_node = nullptr;
  909. AnfNodePtr op_node = block->MakeResolveAstOp(op_obj);
  910. AnfNodePtr value_node = ParseExprNode(block, value_obj);
  911. auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_obj)));
  912. if (ast_type == AST_SUB_TYPE_NAME) {
  913. target_node = ParseName(block, target_obj);
  914. } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
  915. target_node = ParseSubscript(block, target_obj);
  916. } else if (ast_->IsClassMember(target_obj)) {
  917. target_node = ParseAttribute(block, target_obj);
  918. } else {
  919. MS_LOG(EXCEPTION) << "Not supported augassign";
  920. }
  921. if (target_node == nullptr) {
  922. MS_LOG(EXCEPTION) << "Can not get target node ";
  923. }
  924. CNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, value_node});
  925. WriteAssignVars(block, target_obj, augassign_app);
  926. return block;
  927. }
  928. // Process global declaration such as 'global x';
  929. FunctionBlockPtr Parser::ParseGlobal(const FunctionBlockPtr &block, const py::object &node) {
  930. MS_LOG(DEBUG) << "Process ast Global";
  931. MS_EXCEPTION_IF_NULL(block);
  932. py::list vars = python_adapter::GetPyObjAttr(node, "names");
  933. for (auto &item : vars) {
  934. block->AddGlobalVar(py::cast<std::string>(item));
  935. }
  936. return block;
  937. }
  938. // Process a if statement
  939. FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object &node) {
  940. MS_LOG(DEBUG) << "Process ast If";
  941. py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  942. AnfNodePtr condition_node = ParseExprNode(block, test_node);
  943. MS_EXCEPTION_IF_NULL(block);
  944. CNodePtr bool_node = block->ForceToBoolNode(condition_node);
  945. FunctionBlockPtr true_block = nullptr;
  946. FunctionBlockPtr false_block = nullptr;
  947. {
  948. TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
  949. true_block = MakeFunctionBlock(*this);
  950. }
  951. {
  952. TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
  953. false_block = MakeFunctionBlock(*this);
  954. }
  955. MakeConditionBlocks(block, true_block, false_block);
  956. FunctionBlockPtr after_block = nullptr;
  957. {
  958. TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
  959. after_block = MakeFunctionBlock(*this);
  960. }
  961. if (MsContext::GetInstance()->backend_policy() != "ge") {
  962. // For backends excludes 'ge', it can handle multi graph call, use this flag to
  963. // generate call not inline `after_block` graph to reduce if by if switch expansion.
  964. after_block->func_graph()->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true);
  965. }
  966. // Process the if-true branch
  967. py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
  968. FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode);
  969. // If the return_ is set ,it has its own continuation block
  970. if (true_end->func_graph()->get_return() == nullptr) {
  971. true_end->Jump(after_block, nullptr);
  972. }
  973. // Process the orelse branch
  974. py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
  975. FunctionBlockPtr false_end = ParseStatements(false_block, orelseNode);
  976. // If the return_ is set ,it has its own continuation block
  977. if (false_end->func_graph()->get_return() == nullptr) {
  978. false_end->Jump(after_block, nullptr);
  979. }
  980. block->ConditionalJump(bool_node, true_block, false_block);
  981. after_block->Mature();
  982. return after_block;
  983. }
  984. FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::object &node) {
  985. MS_LOG(DEBUG) << "Process ast While";
  986. MS_EXCEPTION_IF_NULL(block);
  987. MS_LOG(INFO) << "Parse while statement";
  988. FunctionBlockPtr header_block = nullptr;
  989. FunctionBlockPtr body_block = nullptr;
  990. FunctionBlockPtr after_block = nullptr;
  991. {
  992. TraceGuard guard(std::make_shared<TraceWhileHeader>(block->func_graph()->debug_info()));
  993. header_block = MakeFunctionBlock(*this);
  994. }
  995. {
  996. TraceGuard guard(std::make_shared<TraceWhileBody>(block->func_graph()->debug_info()));
  997. body_block = MakeFunctionBlock(*this);
  998. }
  999. {
  1000. TraceGuard guard(std::make_shared<TraceWhileAfter>(block->func_graph()->debug_info()));
  1001. after_block = MakeFunctionBlock(*this);
  1002. }
  1003. body_block->AddPrevBlock(header_block);
  1004. after_block->AddPrevBlock(header_block);
  1005. block->Jump(header_block, nullptr);
  1006. py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  1007. AnfNodePtr condition_node = ParseExprNode(header_block, test_node);
  1008. condition_node = header_block->ForceToWhileCond(condition_node);
  1009. body_block->Mature();
  1010. header_block->ConditionalJump(condition_node, body_block, after_block);
  1011. // Parse loop body statements with loop context.
  1012. LoopContext loop_context{&loops_, header_block, nullptr};
  1013. py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  1014. FunctionBlockPtr after_body = ParseStatements(body_block, body_node);
  1015. if (after_body->func_graph()->get_return() == nullptr) {
  1016. after_body->Jump(header_block, nullptr);
  1017. }
  1018. header_block->Mature();
  1019. after_block->Mature();
  1020. auto &end_block = loop_context.EndBlock();
  1021. if (end_block) {
  1022. // end_block exists if we encounter 'break' in loop body.
  1023. after_block->Jump(end_block, nullptr);
  1024. end_block->Mature();
  1025. return end_block;
  1026. }
  1027. // No 'break', no end_block.
  1028. return after_block;
  1029. }
  1030. CNodePtr Parser::GenerateIteratorInFor(const FunctionBlockPtr &block, const py::object &node,
  1031. const AnfNodePtr &op_iter) {
  1032. py::object iter_node = python_adapter::GetPyObjAttr(node, "iter");
  1033. AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node);
  1034. return block->func_graph()->NewCNodeInOrder({op_iter, iter_anf_node});
  1035. }
  1036. CNodePtr Parser::GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block,
  1037. const AnfNodePtr &op_hasnext) {
  1038. MS_EXCEPTION_IF_NULL(header_block);
  1039. return header_block->func_graph()->NewCNodeInOrder({op_hasnext, iter_param});
  1040. }
  1041. FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
  1042. TraceGuard trace_guard(trace_info);
  1043. FunctionBlockPtr body_block = MakeFunctionBlock(*this);
  1044. return body_block;
  1045. }
  1046. int64_t GetForTransToWhileLoop() {
  1047. static const auto loop_str = common::GetEnv("ENV_FOR_TO_WHILE_LOOP");
  1048. // int64 support 63bits positive num mostly.
  1049. if (loop_str.size() > 63 || loop_str.empty()) {
  1050. return MAX_FOR_LOOP_COUNT;
  1051. }
  1052. if (std::any_of(loop_str.begin(), loop_str.end(), [](char c) { return c < '0' || c > '9'; })) {
  1053. return MAX_FOR_LOOP_COUNT;
  1054. }
  1055. int64_t loop_count;
  1056. std::stringstream ss;
  1057. ss << loop_str;
  1058. ss >> loop_count;
  1059. return loop_count;
  1060. }
  1061. // A for loop will generate 3 functions :the test, the body, and the continuation
  1062. // for x in xs:
  1063. // body
  1064. // It is compiled to be following statement
  1065. // if len(xs) < max_loop_cnt:
  1066. // ParseForIter() // use iter to implement for loop, which always unroll loop
  1067. // else:
  1068. // ParseForLoop() // use loop var to implement for loop, which always sink loop
  1069. FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
  1070. MS_LOG(DEBUG) << "Process ast For, create an if else statement";
  1071. MS_EXCEPTION_IF_NULL(block);
  1072. // Create statement 'len(xs) < MAX_FOR_LOOP_COUNT'
  1073. AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
  1074. py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER);
  1075. AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
  1076. CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
  1077. CNodePtr bool_node = block->func_graph()->NewCNodeInOrder(
  1078. {NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(GetForTransToWhileLoop())});
  1079. // Create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop'
  1080. FunctionBlockPtr true_block = nullptr;
  1081. FunctionBlockPtr false_block = nullptr;
  1082. {
  1083. TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
  1084. true_block = MakeFunctionBlock(*this);
  1085. }
  1086. {
  1087. TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
  1088. false_block = MakeFunctionBlock(*this);
  1089. }
  1090. MakeConditionBlocks(block, true_block, false_block);
  1091. FunctionBlockPtr after_block = nullptr;
  1092. {
  1093. TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
  1094. after_block = MakeFunctionBlock(*this);
  1095. }
  1096. FunctionBlockPtr true_end = ParseForIter(true_block, node);
  1097. true_end->Jump(after_block, nullptr);
  1098. FunctionBlockPtr false_end = ParseForLoop(false_block, node);
  1099. false_end->Jump(after_block, nullptr);
  1100. block->ConditionalJump(bool_node, true_block, false_block);
  1101. after_block->Mature();
  1102. return after_block;
  1103. }
  1104. // A for loop will generate 3 functions :the test, the body, and the continuation
  1105. // for x in xs:
  1106. // body
  1107. // It is compiled to be following statement
  1108. // it = iter(xs)
  1109. // while hastnext(it)
  1110. // x, it = next(it)
  1111. // body
  1112. FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::object &node) {
  1113. MS_LOG(DEBUG) << "Process ast For";
  1114. MS_EXCEPTION_IF_NULL(block);
  1115. AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
  1116. AnfNodePtr op_next = block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
  1117. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  1118. AnfNodePtr op_hasnext = block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
  1119. // Generate the iterator apply
  1120. CNodePtr iter_apply = GenerateIteratorInFor(block, node, op_iter);
  1121. MS_EXCEPTION_IF_NULL(iter_apply);
  1122. FunctionBlockPtr header_block =
  1123. GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
  1124. MS_EXCEPTION_IF_NULL(header_block);
  1125. // Generate the hasnext apply which is a condition
  1126. ParameterPtr iter_param = header_block->func_graph()->add_parameter();
  1127. CNodePtr cond_apply = GenerateCondInFor(iter_param, header_block, op_hasnext);
  1128. // Generate the body of the for statement
  1129. FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  1130. MS_EXCEPTION_IF_NULL(body_block);
  1131. body_block->AddPrevBlock(header_block);
  1132. // Generate the iterator next apply
  1133. // Process as following: `app = next(it); target = app[0]; it = app[1];`
  1134. CNodePtr app = body_block->func_graph()->NewCNodeInOrder({op_next, iter_param});
  1135. CNodePtr target_app =
  1136. body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast<int64_t>(0))});
  1137. py::object target_node = python_adapter::GetPyObjAttr(node, "target");
  1138. CNodePtr iter2_app =
  1139. body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast<int64_t>(1))});
  1140. WriteAssignVars(body_block, target_node, target_app);
  1141. // Link the variable name with the target
  1142. auto it_info = std::make_shared<TraceIterator>(target_app->debug_info());
  1143. iter_param->debug_info()->set_trace_info(it_info);
  1144. iter2_app->debug_info()->set_trace_info(it_info);
  1145. iter_apply->debug_info()->set_trace_info(it_info);
  1146. FunctionBlockPtr after_block = nullptr;
  1147. {
  1148. TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
  1149. after_block = MakeFunctionBlock(*this);
  1150. }
  1151. MS_EXCEPTION_IF_NULL(after_block);
  1152. after_block->AddPrevBlock(header_block);
  1153. block->Jump(header_block, iter_apply);
  1154. body_block->Mature();
  1155. header_block->ConditionalJump(cond_apply, body_block, after_block);
  1156. // Parse loop body statements with loop context.
  1157. LoopContext loop_context{&loops_, header_block, iter2_app};
  1158. py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  1159. FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
  1160. if (after_body_block->func_graph()->get_return() == nullptr) {
  1161. after_body_block->Jump(header_block, iter2_app);
  1162. }
  1163. header_block->Mature();
  1164. after_block->Mature();
  1165. auto &end_block = loop_context.EndBlock();
  1166. if (end_block) {
  1167. // end_block exists if we encounter 'break' in loop body.
  1168. after_block->Jump(end_block, nullptr);
  1169. end_block->Mature();
  1170. return end_block;
  1171. }
  1172. // No 'break', no end_block.
  1173. return after_block;
  1174. }
// Parse `for x in xs: body` as an index-driven loop instead of an iterator loop.
// Three function blocks are generated: the header (loop test), the body, and the after
// block (continuation). The statement is compiled as if it were:
//   i = 0
//   while i < len(xs):
//     x = xs[i]
//     i = i + 1
//     body
// len(xs) is turned into a tensor with ScalarToTensor and compared with the `Less` op, which
// (per the original comment) keeps the loop from being unrolled.
// If the body contains 'break', the loop context holds an end block; the after block jumps to
// it and the end block is returned. Otherwise the after block is returned.
FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast For by loop variable";
  MS_EXCEPTION_IF_NULL(block);
  AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
  AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  // Get variable name of 'x' in statement 'for x in xs'
  py::object target_node = python_adapter::GetPyObjAttr(node, "target");
  // Parse the iterated expression 'xs'
  py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
  AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
  MS_EXCEPTION_IF_NULL(iter_node);
  // Create 'len(xs)' and convert the loop count to a tensor, to make the loop not unroll
  CNodePtr scalar_len = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
  auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations");
  // A zero-argument application of the ScalarToTensor op value; its result is then applied
  // to len(xs) (presumably this constructs the operator instance first — TODO confirm).
  auto scalar_to_tensor_node = block->func_graph()->NewCNodeInOrder({NewValueNode(scalar_to_tensor)});
  CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, scalar_len});
  FunctionBlockPtr header_block =
    GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
  MS_EXCEPTION_IF_NULL(header_block);
  // Create loop variable 'i' as the header graph's parameter; its initial value (0) is passed
  // by the jump from `block` below, and later values by the jumps back from the body.
  ParameterPtr loop_var = header_block->func_graph()->add_parameter();
  // Create loop condition 'i < len(xs)'
  auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations");
  auto less_node = header_block->func_graph()->NewCNodeInOrder({NewValueNode(prim_less)});
  CNodePtr cond_node = header_block->func_graph()->NewCNodeInOrder({less_node, loop_var, len_iter});
  // Generate the body of the for statement
  FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  MS_EXCEPTION_IF_NULL(body_block);
  body_block->AddPrevBlock(header_block);
  // Create 'x = xs[i]'
  CNodePtr target_var = body_block->func_graph()->NewCNodeInOrder({op_getitem, iter_node, loop_var});
  WriteAssignVars(body_block, target_node, target_var);
  // Create 'i = i + 1'
  CNodePtr loop_var_inc = body_block->func_graph()->NewCNodeInOrder(
    {NewValueNode(prim::kPrimScalarAdd), loop_var, NewValueNode(static_cast<int64_t>(1))});
  body_block->WriteVariable(loop_var->name(), loop_var_inc);
  // Link the variable name with the target
  auto it_info = std::make_shared<TraceIterator>(loop_var_inc->debug_info());
  loop_var->debug_info()->set_trace_info(it_info);
  len_iter->debug_info()->set_trace_info(it_info);
  FunctionBlockPtr after_block = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
    after_block = MakeFunctionBlock(*this);
  }
  MS_EXCEPTION_IF_NULL(after_block);
  after_block->AddPrevBlock(header_block);
  // Enter the loop with i = 0.
  block->Jump(header_block, NewValueNode(static_cast<int64_t>(0)));
  body_block->Mature();
  // NOTE(review): unlike ParseFor, this passes a trailing `false`; its meaning is defined by
  // FunctionBlock::ConditionalJump — confirm there.
  header_block->ConditionalJump(cond_node, body_block, after_block, false);
  // Parse loop body statements with loop context. 'continue' jumps to the header with
  // loop_var_inc (i + 1), as done in ParseContinue.
  LoopContext loop_context{&loops_, header_block, loop_var_inc};
  py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
  // If the body did not end in a return, fall through to the next iteration with i + 1.
  if (after_body_block->func_graph()->get_return() == nullptr) {
    after_body_block->Jump(header_block, loop_var_inc);
  }
  header_block->Mature();
  after_block->Mature();
  auto &end_block = loop_context.EndBlock();
  if (end_block) {
    // end_block exists if we encounter 'break' in loop body.
    after_block->Jump(end_block, nullptr);
    end_block->Mature();
    return end_block;
  }
  // No 'break', no end_block.
  return after_block;
}
  1253. AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) {
  1254. MS_LOG(DEBUG) << "Process ast IfExp";
  1255. MS_EXCEPTION_IF_NULL(block);
  1256. py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  1257. AnfNodePtr condition_node = ParseExprNode(block, test_node);
  1258. CNodePtr bool_node = block->ForceToBoolNode(condition_node);
  1259. FunctionBlockPtr true_block = nullptr;
  1260. FunctionBlockPtr false_block = nullptr;
  1261. {
  1262. TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
  1263. true_block = MakeFunctionBlock(*this);
  1264. }
  1265. {
  1266. TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
  1267. false_block = MakeFunctionBlock(*this);
  1268. }
  1269. MakeConditionBlocks(block, true_block, false_block);
  1270. // Process the if-true branch
  1271. py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
  1272. true_block->func_graph()->debug_info()->set_location(GetLocation(bodyNode));
  1273. AnfNodePtr true_node = ParseExprNode(true_block, bodyNode);
  1274. // Process the orelse branch
  1275. py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
  1276. false_block->func_graph()->debug_info()->set_location(GetLocation(orelseNode));
  1277. AnfNodePtr false_node = ParseExprNode(false_block, orelseNode);
  1278. true_block->func_graph()->set_output(true_node);
  1279. false_block->func_graph()->set_output(false_node);
  1280. // Use the Primitive replace the operation resolve node (switch),
  1281. // because the switch will eventually be converted to Primitive node
  1282. CNodePtr switch_app = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), bool_node,
  1283. NewValueNode(true_block->func_graph()),
  1284. NewValueNode(false_block->func_graph())});
  1285. std::vector<AnfNodePtr> call_graph_nodes{switch_app};
  1286. CNodePtr switch_app_call = block->func_graph()->NewCNodeInOrder(call_graph_nodes);
  1287. return switch_app_call;
  1288. }
  1289. void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) {
  1290. MS_EXCEPTION_IF_NULL(block);
  1291. MS_EXCEPTION_IF_NULL(assigned_node);
  1292. py::str name = python_adapter::GetPyObjAttr(targ, "id");
  1293. std::string name_id = name;
  1294. assigned_node->debug_info()->set_name(name_id);
  1295. // Set the debug name of the constant graph
  1296. if (IsValueNode<FuncGraph>(assigned_node)) {
  1297. // The value should be graph
  1298. auto fg = GetValueNode<FuncGraphPtr>(assigned_node);
  1299. if (fg->debug_info()->name().empty()) {
  1300. fg->debug_info()->set_name(name_id);
  1301. }
  1302. }
  1303. block->WriteVariable(name_id, assigned_node);
  1304. }
  1305. void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) {
  1306. MS_EXCEPTION_IF_NULL(block);
  1307. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  1308. py::list items = python_adapter::GetPyObjAttr(targ, "elts");
  1309. for (size_t i = 0; i < items.size(); i++) {
  1310. // Use the Primitive replace the operation resolve node (getitem),
  1311. // because the getitem will eventually be converted to Primitive node
  1312. CNodePtr item_apply =
  1313. block->func_graph()->NewCNodeInOrder({op_getitem, assigned_node, NewValueNode(static_cast<int64_t>(i))});
  1314. py::object elt = items[i];
  1315. WriteAssignVars(block, elt, item_apply);
  1316. }
  1317. }
  1318. void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ,
  1319. const AnfNodePtr &assigned_node) {
  1320. // Now only support the self.xx = xxxxx, can't support x.y = xxxx
  1321. AnfNodePtr target_node = ParseExprNode(block, targ);
  1322. MS_EXCEPTION_IF_NULL(target_node);
  1323. auto attr_name = targ.attr("attr").cast<std::string>();
  1324. std::string var_name = "self." + attr_name;
  1325. // Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
  1326. if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
  1327. MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but not defined.";
  1328. }
  1329. auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
  1330. auto obj_type = obj.attr("__class__").attr("__name__");
  1331. if (!py::hasattr(obj, "__parameter__")) {
  1332. MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '"
  1333. << py::str(obj).cast<std::string>() << "' with type '"
  1334. << py::str(obj_type).cast<std::string>() << ".";
  1335. }
  1336. MS_EXCEPTION_IF_NULL(block);
  1337. MS_LOG(DEBUG) << "SetState write " << var_name << " : " << target_node->ToString();
  1338. block->SetStateAssign(target_node, assigned_node);
  1339. }
  1340. void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ,
  1341. const AnfNodePtr &assigned_node) {
  1342. MS_EXCEPTION_IF_NULL(block);
  1343. AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM);
  1344. py::object value_obj = python_adapter::GetPyObjAttr(targ, "value");
  1345. py::object slice_obj = python_adapter::GetPyObjAttr(targ, "slice");
  1346. AnfNodePtr value_node = ParseExprNode(block, value_obj);
  1347. AnfNodePtr slice_node = ParseExprNode(block, slice_obj);
  1348. CNodePtr setitem_app = block->func_graph()->NewCNodeInOrder({op_setitem, value_node, slice_node, assigned_node});
  1349. // Getitem apply should return the sequence data structure itself
  1350. std::string var_name;
  1351. if (ast_->IsClassMember(value_obj)) {
  1352. auto attr_name = value_obj.attr("attr").cast<std::string>();
  1353. var_name = "self." + attr_name;
  1354. if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
  1355. MS_EXCEPTION(TypeError) << "'" << var_name << "' was not defined in the class '__init__' function.";
  1356. }
  1357. auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
  1358. auto obj_type = obj.attr("__class__").attr("__name__");
  1359. if (!py::hasattr(obj, "__parameter__")) {
  1360. MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '"
  1361. << py::str(obj).cast<std::string>() << "' with type '"
  1362. << py::str(obj_type).cast<std::string>() << "'.";
  1363. }
  1364. block->WriteVariable(var_name, setitem_app);
  1365. return;
  1366. }
  1367. if (AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, value_obj))) ==
  1368. AST_SUB_TYPE_SUBSCRIPT) {
  1369. HandleAssignSubscript(block, value_obj, setitem_app);
  1370. return;
  1371. }
  1372. if (!py::hasattr(value_obj, "id")) {
  1373. MS_EXCEPTION(TypeError) << "Attribute id not found in " << py::str(value_obj).cast<std::string>();
  1374. }
  1375. var_name = value_obj.attr("id").cast<std::string>();
  1376. block->WriteVariable(var_name, setitem_app);
  1377. }
  1378. void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node) {
  1379. MS_EXCEPTION_IF_NULL(value_node);
  1380. MS_LOG(DEBUG) << "Process WriteAssignVars";
  1381. auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, targ)));
  1382. if (ast_type == AST_SUB_TYPE_NAME) {
  1383. HandleAssignName(block, targ, value_node);
  1384. } else if (ast_type == AST_SUB_TYPE_TUPLE) {
  1385. HandleAssignTuple(block, targ, value_node);
  1386. } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
  1387. HandleAssignSubscript(block, targ, value_node);
  1388. } else if (ast_->IsClassMember(targ)) {
  1389. HandleAssignClassMember(block, targ, value_node);
  1390. } else if (ast_type == AST_SUB_TYPE_ATTRIBUTE) {
  1391. MS_LOG(EXCEPTION) << "The subnet attributes cannot be changed in the network"
  1392. << " NodeInfo: " << trace::GetDebugInfo(value_node->debug_info());
  1393. } else {
  1394. MS_LOG(EXCEPTION) << "Not supported assign type: " << ast_type
  1395. << " NodeInfo: " << trace::GetDebugInfo(value_node->debug_info());
  1396. }
  1397. }
  1398. // Process a assign statement, such as a =b, a,b = tup
  1399. FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) {
  1400. MS_LOG(DEBUG) << "Process ast assign";
  1401. py::object value_object = python_adapter::GetPyObjAttr(node, "value");
  1402. AnfNodePtr value_node = ParseExprNode(block, value_object);
  1403. py::object targets_object = python_adapter::GetPyObjAttr(node, "targets");
  1404. py::int_ pcount = python_adapter::CallPyObjMethod(targets_object, "__len__");
  1405. size_t count = LongToSize(pcount);
  1406. MS_LOG(DEBUG) << "The nodes count is " << count;
  1407. for (size_t i = 0; i < count; i++) {
  1408. auto target_node = py::cast<py::list>(targets_object)[i];
  1409. WriteAssignVars(block, target_node, value_node);
  1410. }
  1411. return block;
  1412. }
  1413. FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::object &node) {
  1414. if (loops_.empty()) {
  1415. // Report error if loop context not set for the 'break' statement.
  1416. MS_LOG(EXCEPTION) << "Unexpected 'break'.";
  1417. }
  1418. // Get current loop.
  1419. Loop &loop = loops_.top();
  1420. if (loop.end == nullptr) {
  1421. // Create end_block if it is not existed.
  1422. TraceGuard trace_guard(std::make_shared<TraceLoopEnd>(block->func_graph()->debug_info()));
  1423. loop.end = MakeFunctionBlock(*this);
  1424. }
  1425. // Jump to the end_block.
  1426. block->Jump(loop.end, nullptr);
  1427. return block;
  1428. }
  1429. FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::object &node) {
  1430. if (loops_.empty()) {
  1431. // Report error if loop context not set for the 'continue' statement.
  1432. MS_LOG(EXCEPTION) << "Unexpected 'continue.";
  1433. }
  1434. // Jump to the header of the loop with iterator called.
  1435. Loop &loop = loops_.top();
  1436. block->Jump(loop.header, loop.iterator);
  1437. return block;
  1438. }
  1439. FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::object &node) {
  1440. // We just bypass 'pass' statement.
  1441. return block;
  1442. }
  1443. AnfNodePtr FindPhis(const std::unordered_map<ParameterPtr, AnfNodePtr> &removable_phis, const AnfNodePtr &node) {
  1444. const auto &inp = node->cast<ParameterPtr>();
  1445. const auto &iter = removable_phis.find(inp);
  1446. if (iter == removable_phis.end()) {
  1447. return node;
  1448. }
  1449. return FindPhis(removable_phis, iter->second);
  1450. }
  1451. void Parser::RemoveUnnecessaryPhis() {
  1452. // Merge all removable phis to one map;
  1453. std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis;
  1454. std::vector<ParameterPtr> phis;
  1455. for (FunctionBlockPtr &block : func_block_list_) {
  1456. MS_EXCEPTION_IF_NULL(block);
  1457. removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end());
  1458. std::transform(block->removable_phis().begin(), block->removable_phis().end(), std::back_inserter(phis),
  1459. [](const std::pair<ParameterPtr, AnfNodePtr> &pair) { return pair.first; });
  1460. }
  1461. if (removable_phis.empty()) {
  1462. return;
  1463. }
  1464. auto fg_name = func_graph_->ToString();
  1465. auto mng = Manage(func_graph_, false);
  1466. // Replace the nodes
  1467. // Remove from inside to outside
  1468. for (int64_t idx = SizeToLong(phis.size() - 1); idx >= 0; idx--) {
  1469. auto phi = phis[LongToSize(idx)];
  1470. auto new_node = FindPhis(removable_phis, phi);
  1471. mng->Replace(phi, new_node);
  1472. }
  1473. // Remove the parameter
  1474. for (FunctionBlockPtr &block : func_block_list_) {
  1475. MS_EXCEPTION_IF_NULL(block);
  1476. auto &local_removable_phis = block->removable_phis();
  1477. if (local_removable_phis.empty()) {
  1478. continue;
  1479. }
  1480. auto func_graph = block->func_graph();
  1481. auto &parameters = func_graph->parameters();
  1482. std::vector<AnfNodePtr> new_parameters(parameters.size());
  1483. auto it = std::copy_if(
  1484. parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](const AnfNodePtr &param) {
  1485. return local_removable_phis.find(param->cast<ParameterPtr>()) == local_removable_phis.end();
  1486. });
  1487. // Shrink container to new size
  1488. new_parameters.resize(std::distance(new_parameters.begin(), it));
  1489. func_graph->set_parameters(new_parameters);
  1490. }
  1491. }
// ParseAst class code
// Select what to parse from obj_ and run the Python-side parser on it.
// obj_ is classified with data_converter::GetObjType:
//   - function: function_ is obj_ itself (target PARSE_TARGET_FUNCTION);
//   - method: function_ is the method and obj_ is replaced by the method's
//     PYTHON_GET_METHOD_SELF_CLASS attribute (its self object, per the error log);
//   - class instance: function_ is looked up through the Python parse module function named
//     `python_mod_get_parse_method`, and must itself be a method.
// The chosen function_ is then handed to the Python parser object (parser_), whose parsed AST
// and metadata (module, name, filename, line offset) are stored in members.
// Returns false (after logging) when the object kind is unsupported or a lookup fails.
bool ParseAst::InitParseAstInfo(const std::string &python_mod_get_parse_method) {
  // Init the type
  target_type_ = PARSE_TARGET_UNKNOW;
  // Call python parse, get the parser fn
  module_ = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  // Read from the original obj_, before the method branch below may reassign obj_.
  py::object parse_method = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_PARSE_METHOD);
  // Get the obj type
  auto type = data_converter::GetObjType(obj_);
  if (type == RESOLVE_TYPE_FUNCTION) {
    target_type_ = PARSE_TARGET_FUNCTION;
    function_ = obj_;
  } else if (type == RESOLVE_TYPE_METHOD) {
    // Process the method: need to get the method's self object.
    // NOTE(review): PARSE_TARGET_METHOD is overwritten with PARSE_TARGET_OBJECT_INSTANCE
    // below before any successful return.
    target_type_ = PARSE_TARGET_METHOD;
    py::object method_object = python_adapter::GetPyObjAttr(obj_, PYTHON_GET_METHOD_SELF_CLASS);
    if (py::isinstance<py::none>(method_object)) {
      MS_LOG(ERROR) << "Get method's self object instance failed.";
      return false;
    }
    target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
    // Parse the method itself; from here on obj_ refers to its self object.
    function_ = obj_;
    obj_ = method_object;
  } else if (type == RESOLVE_TYPE_CLASS_INSTANCE) {
    // obj is class instance, get the method to parse.
    function_ = python_adapter::CallPyModFn(module_, python_mod_get_parse_method, obj_, parse_method);
    if (py::isinstance<py::none>(function_)) {
      MS_LOG(ERROR) << "Get obj method function failed.";
      return false;
    }
    target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
    // Check the fn is method
    auto obj_type = data_converter::GetObjType(function_);
    if (obj_type != RESOLVE_TYPE_METHOD) {
      MS_LOG(WARNING) << "Parse method function is invalid.";
      return false;
    }
  } else {
    MS_LOG(WARNING) << "Parse obj is invalid, only can parse function and obj, type = " << type;
    return false;
  }
  // Build the Python parser for function_ and parse it into an AST tree.
  parser_ = python_adapter::CallPyModFn(module_, PYTHON_MOD_PARSE_OBJECT_FUNCTION, function_, parse_method);
  ast_tree_ = python_adapter::CallPyObjMethod(parser_, "parse");
  // Record the function's module, name, source filename and line offset as reported by the parser.
  function_module_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "function_module"));
  function_name_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "function_name"));
  function_filename_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "filename"));
  function_line_offset_ = py::cast<int64_t>(python_adapter::GetPyObjAttr(parser_, "line_offset"));
  return true;
}
  1543. // Get ast tree node : is the tree bode list[0]
  1544. py::object ParseAst::GetAstNode() {
  1545. py::list tree_body = python_adapter::GetPyObjAttr(ast_tree_, "body");
  1546. py::object ast_node = tree_body[0];
  1547. return ast_node;
  1548. }
  1549. py::list ParseAst::GetArgs(const py::object &func_node) {
  1550. py::list ret = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_ARGS, func_node);
  1551. return ret;
  1552. }
  1553. py::list ParseAst::GetArgsDefaultValues(const py::object &func_node) {
  1554. py::list ret = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES, func_node);
  1555. return ret;
  1556. }
  1557. AstNodeTypePtr ParseAst::GetNodeType(const py::object &node) {
  1558. py::list list_value = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_NODE_TYPE, node);
  1559. if (list_value.size() < 2) {
  1560. MS_LOG(ERROR) << "The node of python method must has 2 values.";
  1561. return nullptr;
  1562. }
  1563. auto node_name = py::cast<std::string>(list_value[0]);
  1564. auto type = AstMainType(py::cast<int32_t>(list_value[1]));
  1565. return std::make_shared<AstNodeType>(node, node_name, type);
  1566. }
  1567. AstSubType ParseAst::GetOpType(const py::object &node) {
  1568. auto op_type = AstSubType(python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_AST_TYPE, node).cast<int32_t>());
  1569. return op_type;
  1570. }
  1571. bool ParseAst::IsClassMember(const py::object &node) {
  1572. py::object ret = CallParseModFunction(PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER, node);
  1573. if (!py::isinstance<py::bool_>(ret)) {
  1574. MS_LOG(ERROR) << "The result of mod function parse, should be bool type.";
  1575. return false;
  1576. }
  1577. return ret.cast<bool>();
  1578. }
  1579. bool UpdateFuncGraphFlags(const py::object &obj, const FuncGraphPtr &func_graph) {
  1580. if (func_graph == nullptr) {
  1581. MS_LOG(ERROR) << "FuncGraph is null";
  1582. return false;
  1583. }
  1584. if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) {
  1585. MS_LOG(DEBUG) << "No flags";
  1586. return true;
  1587. }
  1588. py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG);
  1589. for (auto &item : flags) {
  1590. if (!py::isinstance<py::str>(item.first)) {
  1591. MS_LOG(ERROR) << "Type error in flags dict convert";
  1592. return false;
  1593. }
  1594. auto name = py::cast<std::string>(item.first);
  1595. if (py::isinstance<py::bool_>(item.second)) {
  1596. auto value = py::cast<bool>(item.second);
  1597. MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
  1598. func_graph->set_flag(name, value);
  1599. } else if (py::isinstance<py::str>(item.second)) {
  1600. auto value = py::cast<std::string>(item.second);
  1601. MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
  1602. func_graph->set_attr(name, MakeValue(value));
  1603. } else {
  1604. MS_LOG(ERROR) << "Type error in flags/attrs dict convert";
  1605. return false;
  1606. }
  1607. }
  1608. return true;
  1609. }
  1610. // Generate and copy a ValueNode, or a CNode with its child nodes
  1611. static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr &func_graph, const AnfNodePtr &param_node) {
  1612. MS_EXCEPTION_IF_NULL(param_node);
  1613. if (param_node->isa<ValueNode>()) {
  1614. return std::make_shared<ValueNode>(param_node->cast<ValueNodePtr>()->value());
  1615. }
  1616. // Parameter default value is CNode.
  1617. std::size_t index = 0;
  1618. std::vector<AnfNodePtr> old_cnodes;
  1619. old_cnodes.emplace_back(param_node);
  1620. auto res = func_graph->NewCNodeInOrder({});
  1621. std::vector<CNodePtr> new_cnodes;
  1622. new_cnodes.emplace_back(res);
  1623. while (index < old_cnodes.size()) {
  1624. auto current = old_cnodes[index];
  1625. auto current_new_cnode = new_cnodes[index];
  1626. index++;
  1627. MS_EXCEPTION_IF_NULL(current);
  1628. if (current->isa<CNode>()) {
  1629. auto &inputs = current->cast<CNodePtr>()->inputs();
  1630. for (auto it = inputs.begin(); it != inputs.end(); it++) {
  1631. AnfNodePtr input = *it;
  1632. if (input != nullptr && input->isa<CNode>()) {
  1633. old_cnodes.emplace_back(input);
  1634. auto new_cnode = func_graph->NewCNodeInOrder({});
  1635. new_cnodes.emplace_back(new_cnode);
  1636. current_new_cnode->add_input(new_cnode);
  1637. } else if (input->isa<ValueNode>()) {
  1638. current_new_cnode->add_input(std::make_shared<ValueNode>(input->cast<ValueNodePtr>()->value()));
  1639. } else {
  1640. MS_LOG(EXCEPTION) << "Wrong type item in default parameters: " << input->ToString();
  1641. }
  1642. }
  1643. }
  1644. }
  1645. return res;
  1646. }
  1647. FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
  1648. auto current_graph = dyn_cast<FuncGraph>(cell_ptr);
  1649. if (current_graph == nullptr) {
  1650. MS_LOG(EXCEPTION) << "Current graph cast failed from " << cell_ptr->ToString();
  1651. }
  1652. auto func_graph = std::make_shared<FuncGraph>();
  1653. func_graph->debug_info()->set_name(current_graph->debug_info()->name() + "_wrapper");
  1654. // Copy all parameters information
  1655. for (auto &para : current_graph->parameters()) {
  1656. auto param = func_graph->add_parameter();
  1657. auto orig_param = para->cast<ParameterPtr>();
  1658. auto name = orig_param->name();
  1659. param->set_name(name);
  1660. param->debug_info()->set_name(name);
  1661. }
  1662. func_graph->set_has_vararg(current_graph->has_vararg());
  1663. func_graph->set_has_kwarg(current_graph->has_kwarg());
  1664. func_graph->set_kwonlyargs_count(current_graph->kwonlyargs_count());
  1665. // Copy all default values
  1666. for (auto &d : current_graph->parameter_default_value()) {
  1667. func_graph->set_param_default_value(d.first, CopyNodesFromParamDefaultValue(func_graph, d.second));
  1668. }
  1669. // cell_obj
  1670. MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell));
  1671. parse::UpdateFuncGraphFlags(cell, func_graph);
  1672. // Top graph's construct flag
  1673. if (py::hasattr(cell, "construct")) {
  1674. parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph);
  1675. }
  1676. auto unpacking = func_graph->has_vararg() || func_graph->has_kwarg();
  1677. if (!unpacking) {
  1678. std::vector<AnfNodePtr> inputs;
  1679. inputs.emplace_back(NewValueNode(cell_ptr));
  1680. auto &params = func_graph->parameters();
  1681. (void)std::transform(params.begin(), params.end(), std::back_inserter(inputs),
  1682. [](AnfNodePtr node) -> AnfNodePtr { return node; });
  1683. func_graph->set_output(func_graph->NewCNodeInOrder(inputs));
  1684. } else {
  1685. // ret = cell_obj(*arg, *kwargs)
  1686. auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), func_graph->parameters());
  1687. // Set output as ret
  1688. func_graph->set_output(call_fn);
  1689. }
  1690. return func_graph;
  1691. }
  1692. } // namespace parse
  1693. } // namespace mindspore