You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parse.cc 97 kB

5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236223722382239224022412242224322442245224622472248224922502251225222532254
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/jit/parse/parse.h"
  19. #include <utility>
  20. #include <string>
  21. #include <memory>
  22. #include <sstream>
  23. #include <algorithm>
  24. #include "utils/hash_map.h"
  25. #include "pybind_api/pybind_patch.h"
  26. #include "pipeline/jit/parse/resolve.h"
  27. #include "pipeline/jit/parse/data_converter.h"
  28. #include "frontend/operator/ops.h"
  29. #include "frontend/operator/composite/composite.h"
  30. #include "utils/ms_context.h"
  31. #include "utils/interpret_node_recorder.h"
  32. #include "debug/trace.h"
  33. #include "mindspore/core/ir/cell.h"
  34. namespace mindspore {
  35. namespace parse {
  36. FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mod_get_parse_method) {
  37. (void)python_adapter::set_python_scoped();
  38. if (!obj || py::isinstance<py::none>(obj)) {
  39. MS_LOG(ERROR) << "Parse the python code failed, obj is nullptr or none";
  40. return nullptr;
  41. }
  42. auto ast = std::make_shared<ParseFunctionAst>(obj);
  43. bool success = ast->InitParseAstInfo(python_mod_get_parse_method);
  44. if (!success) {
  45. MS_LOG(ERROR) << "Parse code to ast tree failed.";
  46. return nullptr;
  47. }
  48. auto parser = std::make_shared<Parser>(ast);
  49. FuncGraphPtr func_graph = parser->ParseFuncGraph();
  50. if (func_graph == nullptr) {
  51. MS_LOG(ERROR) << "Parse python code failed, errcode = " << parser->errcode();
  52. return nullptr;
  53. }
  54. return func_graph;
  55. }
  56. TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) {
  57. MS_EXCEPTION_IF_NULL(func_graph);
  58. if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
  59. return kFloat32;
  60. } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
  61. return kFloat16;
  62. } else {
  63. return nullptr;
  64. }
  65. }
  66. // If any mixed precision flag add a cast node after the parameter node.
  67. AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
  68. MS_EXCEPTION_IF_NULL(func_graph);
  69. TypePtr dst_type;
  70. if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
  71. dst_type = kFloat32;
  72. } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
  73. dst_type = kFloat16;
  74. } else {
  75. return param;
  76. }
  77. auto cast_helper = prim::kPrimMixedPrecisionCast;
  78. auto cast = func_graph->NewCNodeAfter(param, {NewValueNode(cast_helper), NewValueNode(dst_type), param});
  79. return cast;
  80. }
  81. FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();
  82. Parser::Parser(const std::shared_ptr<ParseFunctionAst> &ast) : ast_(ast) {
  83. max_for_loop_count_str_ = common::GetEnv("ENV_FOR_TO_WHILE_LOOP");
  84. support_fallback_ = common::GetEnv("ENV_SUPPORT_FALLBACK");
  85. errcode_ = PARSE_SUCCESS;
  86. BuildMethodMap();
  87. }
  88. void Parser::BuildMethodMap() {
  89. stmt_method_map_["Return"] = &Parser::ParseReturn;
  90. stmt_method_map_["Expr"] = &Parser::ParseExpr;
  91. stmt_method_map_["If"] = &Parser::ParseIf;
  92. stmt_method_map_["Assign"] = &Parser::ParseAssign;
  93. stmt_method_map_["While"] = &Parser::ParseWhile;
  94. stmt_method_map_["For"] = &Parser::ParseFor;
  95. stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef;
  96. stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign;
  97. stmt_method_map_["Global"] = &Parser::ParseGlobal;
  98. stmt_method_map_["Break"] = &Parser::ParseBreak;
  99. stmt_method_map_["Continue"] = &Parser::ParseContinue;
  100. stmt_method_map_["Pass"] = &Parser::ParsePass;
  101. expr_method_map_["NoneType"] = &Parser::ParseNone;
  102. expr_method_map_["BinOp"] = &Parser::ParseBinOp;
  103. expr_method_map_["Name"] = &Parser::ParseName;
  104. expr_method_map_["Num"] = &Parser::ParseNum;
  105. expr_method_map_["Str"] = &Parser::ParseStr;
  106. expr_method_map_["Constant"] = &Parser::ParseConstant;
  107. expr_method_map_["NameConstant"] = &Parser::ParseNameConstant;
  108. expr_method_map_["Call"] = &Parser::ParseCall;
  109. expr_method_map_["IfExp"] = &Parser::ParseIfExp;
  110. expr_method_map_["Attribute"] = &Parser::ParseAttribute;
  111. expr_method_map_["Compare"] = &Parser::ParseCompare;
  112. expr_method_map_["BoolOp"] = &Parser::ParseBoolOp;
  113. expr_method_map_["Lambda"] = &Parser::ParseLambda;
  114. expr_method_map_["Tuple"] = &Parser::ParseTuple;
  115. expr_method_map_["List"] = &Parser::ParseList;
  116. expr_method_map_["Subscript"] = &Parser::ParseSubscript;
  117. expr_method_map_["Slice"] = &Parser::ParseSlice;
  118. expr_method_map_["ExtSlice"] = &Parser::ParseExtSlice;
  119. expr_method_map_["Index"] = &Parser::ParseIndex;
  120. expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp;
  121. expr_method_map_["Dict"] = &Parser::ParseDict;
  122. expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis;
  123. expr_method_map_["ListComp"] = &Parser::ParseListComp;
  124. expr_method_map_["GeneratorExp"] = &Parser::ParseListComp; // We treat 'GeneratorExp' the same as 'ListComp'.
  125. }
  126. void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); }
  127. void Parser::InitParserEnvironment(const py::object &obj) {
  128. Parser::top_func_graph_ = FuncGraphWeakPtr();
  129. ScopeManager::GetInstance().ClearScope();
  130. (void)python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GENERATE_SCOPE, obj);
  131. }
  132. void Parser::CleanParserResource() {
  133. Parser::top_func_graph_ = FuncGraphWeakPtr();
  134. ScopeManager::GetInstance().ClearScope();
  135. }
  136. void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseFunctionAst> &ast) {
  137. // Check whether the functions referred by this function and itself are missing 'return' statement
  138. auto mng = Manage(fn, false);
  139. MS_EXCEPTION_IF_NULL(ast);
  140. for (const auto &func_graph : mng->func_graphs()) {
  141. MS_EXCEPTION_IF_NULL(func_graph);
  142. if (func_graph->get_return() != nullptr) {
  143. continue;
  144. }
  145. py::object node = ast->GetAstNode();
  146. py::list ret = ast->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  147. constexpr auto min_list_size = 2;
  148. if (ret.size() < min_list_size) {
  149. MS_LOG(EXCEPTION) << "list size:" << ret.size() << " is less than 2.";
  150. }
  151. py::str desc =
  152. python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]);
  153. MS_EXCEPTION(TypeError) << "Function must has 'return' statement, but missing in " << desc.cast<std::string>()
  154. << ". FuncGraph: " << func_graph->ToString();
  155. }
  156. }
  157. FuncGraphPtr Parser::ParseFuncGraph() {
  158. // Get ast FunctionDef node
  159. py::object node = ast_->GetAstNode();
  160. constexpr char function_def_name[] = "FunctionDef";
  161. constexpr char lambda_name[] = "Lambda";
  162. FunctionBlockPtr fn_block = nullptr;
  163. if (ast_->GetNodeType(node)->node_name() == function_def_name) {
  164. fn_block = ParseDefFunction(node);
  165. } else {
  166. auto lambda_node = python_adapter::GetPyObjAttr(node, "value");
  167. if (py::isinstance<py::none>(lambda_node) || ast_->GetNodeType(lambda_node)->node_name() != lambda_name) {
  168. MS_EXCEPTION(TypeError) << "Parse Lambda Function Fail. Node type must be Lambda, but got "
  169. << ast_->GetNodeType(lambda_node)->node_name() << ".";
  170. }
  171. fn_block = ParseLambdaFunction(lambda_node);
  172. }
  173. if (errcode() != PARSE_SUCCESS) {
  174. MS_LOG(ERROR) << "Parse function error, code is " << errcode();
  175. return nullptr;
  176. }
  177. RemoveUnnecessaryPhis();
  178. MS_EXCEPTION_IF_NULL(fn_block);
  179. CheckFuncReturn(fn_block->func_graph(), ast_);
  180. return fn_block->func_graph();
  181. }
  182. void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
  183. py::object func_args = python_adapter::GetPyObjAttr(fn_node, "args");
  184. py::object var_arg_node = python_adapter::GetPyObjAttr(func_args, "vararg");
  185. MS_EXCEPTION_IF_NULL(block);
  186. auto block_fg = block->func_graph();
  187. block_fg->set_has_vararg(!py::isinstance<py::none>(var_arg_node));
  188. py::object kw_arg_node = python_adapter::GetPyObjAttr(func_args, "kwarg");
  189. block_fg->set_has_kwarg(!py::isinstance<py::none>(kw_arg_node));
  190. py::list kwonly_args = python_adapter::GetPyObjAttr(func_args, "kwonlyargs");
  191. block_fg->set_kwonlyargs_count(SizeToInt(kwonly_args.size()));
  192. MS_EXCEPTION_IF_NULL(ast_);
  193. py::list args = ast_->GetArgs(fn_node);
  194. for (std::size_t i = 0; i < args.size(); i++) {
  195. std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
  196. if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  197. if (arg_name == "self") {
  198. continue;
  199. }
  200. }
  201. TraceGuard guard(GetLocation(args[i]));
  202. auto para_node = std::make_shared<Parameter>(block_fg);
  203. MS_EXCEPTION_IF_NULL(para_node);
  204. para_node->set_name(arg_name);
  205. para_node->debug_info()->set_name(arg_name);
  206. block_fg->add_parameter(para_node);
  207. AnfNodePtr para_after_cast = GetMixedPrecisionCastHelp(block_fg, para_node);
  208. MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name;
  209. block->WriteVariable(arg_name, para_after_cast);
  210. }
  211. }
  212. void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
  213. MS_EXCEPTION_IF_NULL(block);
  214. py::list defaults = ast_->GetArgsDefaultValues(fn_node);
  215. py::list args = ast_->GetArgs(fn_node);
  216. std::vector<std::string> namelist_for_default_value;
  217. std::vector<AnfNodePtr> default_values;
  218. for (std::size_t i = 0; i < args.size(); i++) {
  219. std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
  220. if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  221. if (arg_name == "self") {
  222. continue;
  223. }
  224. }
  225. namelist_for_default_value.push_back(arg_name);
  226. if (i >= defaults.size()) {
  227. MS_LOG(EXCEPTION) << "Index:" << i << " out of range:" << defaults.size();
  228. }
  229. if (py::isinstance<py::none>(defaults[i])) {
  230. default_values.push_back(NewValueNode(kNull));
  231. } else {
  232. default_values.push_back(ParseExprNode(block, defaults[i]));
  233. }
  234. }
  235. block->func_graph()->SetDefaultValues(namelist_for_default_value, default_values);
  236. }
  237. ScopePtr Parser::GetScopeForParseFunction() {
  238. ScopePtr scope = ScopeManager::GetInstance().GetCurrentScope();
  239. if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  240. py::object scope_str = python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GET_SCOPE_NAME, ast_->obj());
  241. if (!py::isinstance<py::none>(scope_str)) {
  242. auto scope_name = py::cast<std::string>(scope_str);
  243. scope = std::make_shared<Scope>(scope_name);
  244. }
  245. }
  246. return scope;
  247. }
  248. FunctionBlockPtr Parser::ParseDefFunction(const py::object &node, const FunctionBlockPtr &block) {
  249. ScopePtr scope = GetScopeForParseFunction();
  250. // The node created in the parsefunction context, will inherit the scope created using scope_guard
  251. ScopeGuard scope_guard(scope);
  252. TraceGuard trace_guard(data_converter::GetObjKey(ast_->obj())[0], GetLocation(node));
  253. FunctionBlockPtr func_block = MakeFunctionBlock(*this);
  254. if (block != nullptr) {
  255. func_block->AddPrevBlock(block);
  256. } else {
  257. func_graph_ = func_block->func_graph();
  258. }
  259. func_block->Mature();
  260. auto current_fg = func_block->func_graph();
  261. auto function_name = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "name"));
  262. MS_LOG(DEBUG) << "The function name is " << function_name;
  263. current_fg->debug_info()->set_name(function_name);
  264. MS_EXCEPTION_IF_NULL(ast_);
  265. py::list deco_list = node.attr("decorator_list");
  266. if (!deco_list.empty()) {
  267. current_fg->debug_info()->set_deco_location(GetLocation(deco_list));
  268. }
  269. bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg);
  270. if (!ast_->obj().is(ast_->function())) {
  271. set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg);
  272. }
  273. if (!set_flag) {
  274. MS_LOG(ERROR) << "Set flags failed";
  275. return nullptr;
  276. }
  277. GenerateArgsNodeForFunction(func_block, node);
  278. // When parsing the top graph of construct, save the top graph
  279. if (GetTopFuncGraph() == nullptr) {
  280. UpdateTopFuncGraph(func_block->func_graph());
  281. }
  282. // Save the function node to block
  283. func_block->WriteVariable(function_name, NewValueNode(current_fg));
  284. py::object funcObj = python_adapter::GetPyObjAttr(node, "body");
  285. (void)ParseStatements(func_block, funcObj);
  286. // Add unused variables as isolate nodes.
  287. for (auto &func_block_item : func_block_list_) {
  288. MS_EXCEPTION_IF_NULL(func_block_item);
  289. if (func_block_item->func_graph()->get_return() != nullptr) {
  290. // Find unused variables.
  291. func_block_item->FindIsolatedNodes();
  292. // Attach all isolated nodes.
  293. func_block_item->AttachIsolatedNodesBeforeReturn();
  294. }
  295. }
  296. if (current_fg->get_return() == nullptr) {
  297. py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  298. py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, node, ret[0], ret[1]);
  299. MS_EXCEPTION(TypeError) << "Function must has 'return' statement, but missing in " << desc.cast<std::string>()
  300. << ".";
  301. }
  302. GenerateArgsDefaultValueForFunction(func_block, node);
  303. return func_block;
  304. }
  305. FunctionBlockPtr Parser::ParseLambdaFunction(const py::object &node, const FunctionBlockPtr &block) {
  306. MS_EXCEPTION_IF_NULL(ast_);
  307. ScopePtr scope = GetScopeForParseFunction();
  308. ScopeGuard scope_guard(scope);
  309. TraceGuard trace_guard(data_converter::GetObjKey(ast_->obj())[0], GetLocation(node));
  310. FunctionBlockPtr func_block = MakeFunctionBlock(*this);
  311. if (block != nullptr) {
  312. func_block->AddPrevBlock(block);
  313. } else {
  314. func_graph_ = func_block->func_graph();
  315. }
  316. func_block->Mature();
  317. auto current_fg = func_block->func_graph();
  318. auto function_name = ast_->function_name();
  319. MS_LOG(DEBUG) << "The function name is " << function_name;
  320. current_fg->debug_info()->set_name(function_name);
  321. GenerateArgsNodeForFunction(func_block, node);
  322. py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  323. AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
  324. current_fg->set_output(lambda_body_node);
  325. GenerateArgsDefaultValueForFunction(func_block, node);
  326. return func_block;
  327. }
  328. FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr block, const py::object &nodes) {
  329. auto node_list = py::cast<py::list>(nodes);
  330. size_t count = py::len(node_list);
  331. MS_LOG(DEBUG) << "The nodes count is " << count;
  332. for (size_t i = 0; i < count; ++i) {
  333. MS_LOG(DEBUG) << "Start parse statement[" << i << "]: " << py::str(node_list[i]);
  334. auto node = node_list[i];
  335. block = ParseStatement(block, node);
  336. MS_EXCEPTION_IF_NULL(block);
  337. // Insert appropriate depended items for the function block if it has a return node
  338. if (block->func_graph()->get_return() != nullptr || block->is_dead_block()) {
  339. // If break is not the last expr.
  340. if (i != count - 1) {
  341. TraceGuard trace_guard(GetLocation(node_list[i + 1]));
  342. MS_LOG(EXCEPTION) << "Dead code exist, please remove it.";
  343. }
  344. // Skip statements after 'return' (or 'break', 'continue').
  345. break;
  346. }
  347. }
  348. return block;
  349. }
  350. FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py::object &node) {
  351. TraceGuard trace_guard(GetLocation(node));
  352. auto node_type = ast_->GetNodeType(node);
  353. // Check the node type
  354. AstMainType nodeType = node_type->main_type();
  355. if (nodeType != AST_MAIN_TYPE_STMT) {
  356. MS_LOG(INFO) << "Node type is error : " << nodeType;
  357. return block;
  358. }
  359. // Call the process function
  360. std::string node_name = node_type->node_name();
  361. MS_LOG(DEBUG) << "Ast node is " << node_name;
  362. if (stmt_method_map_.count(node_name)) {
  363. auto stmt_block = (this->*stmt_method_map_[node_name])(block, node);
  364. TraceManager::ClearParseOrResolveDebugInfo();
  365. return stmt_block;
  366. } else {
  367. errcode_ = PARSE_NODE_METHOD_UNSUPPORTED;
  368. MS_LOG(EXCEPTION) << "Unsupported statement '" << node_name
  369. << "'.\nMore details please refer to syntax support at https://www.mindspore.cn";
  370. }
  371. }
  372. AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object &node) {
  373. MS_LOG(DEBUG) << "Process ast expr.";
  374. TraceGuard trace_guard(GetLocation(node));
  375. auto node_type = ast_->GetNodeType(node);
  376. // Check the node type
  377. AstMainType node_main_type = node_type->main_type();
  378. if (node_main_type != AST_MAIN_TYPE_EXPR) {
  379. errcode_ = PARSE_NODE_TYPE_NO_MATCH;
  380. MS_LOG(EXCEPTION) << "Node type is error : " << node_main_type;
  381. }
  382. // Call the process function
  383. std::string node_name = node_type->node_name();
  384. MS_LOG(DEBUG) << "Ast node is " << node_name;
  385. if (expr_method_map_.count(node_name)) {
  386. auto expr_node = (this->*expr_method_map_[node_name])(block, node);
  387. TraceManager::ClearParseOrResolveDebugInfo();
  388. return expr_node;
  389. } else {
  390. errcode_ = PARSE_NODE_METHOD_UNSUPPORTED;
  391. MS_LOG(EXCEPTION) << "Unsupported expression '" << node_name
  392. << "'.\nMore details please refer to syntax support at https://www.mindspore.cn";
  393. }
  394. }
  395. // Process the expr statement and expand it
  396. FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::object &node) {
  397. MS_LOG(DEBUG) << "Process ast Expr";
  398. // Expr only have value, no target
  399. py::tuple expand_info = ast_->CallParseModFunction(PYTHON_PARSE_EXPAND_EXPR_STATEMENT, node);
  400. // Refer python function expand_expr_statement, expand_info is one of the following:
  401. // True, expr.value, x
  402. // True, expr.value
  403. // False, None, None
  404. //
  405. // Check the expand info result
  406. if (expand_info.empty()) {
  407. MS_LOG(EXCEPTION) << "Empty expand_info.";
  408. }
  409. auto is_expand = py::cast<bool>(expand_info[0]);
  410. if (is_expand) {
  411. // Process the expr statement
  412. constexpr size_t expect_size = 2;
  413. if (expand_info.size() < expect_size) {
  414. MS_LOG(EXCEPTION) << "expand_info size:" << expand_info.size() << " less than " << expect_size << ".";
  415. }
  416. py::object value_object = expand_info[1];
  417. // Make a Expr CNode.
  418. AnfNodePtr call_node = ParseExprNode(block, value_object);
  419. if (py::len(expand_info) == 2) {
  420. // Expression that not assigned to any variable.
  421. // This is usually a call with side effects.
  422. // e.g.: print(x)
  423. // We save it as an isolated node.
  424. auto &no_return_node = call_node;
  425. MS_LOG(INFO) << "Isolated node found(NoReturn), no_return_node: " << no_return_node->DebugString(2)
  426. << ", block: " << block << "/"
  427. << (block->func_graph() ? block->func_graph()->ToString() : "FG(Null)")
  428. << ", Line: " << trace::GetDebugInfo(no_return_node->debug_info(), "", kSourceLineTipDiscard);
  429. block->AddIsolatedNode(no_return_node);
  430. } else {
  431. // Expand the assign statement,
  432. // e.g.: x.append(y) -> x = x.append(y)
  433. py::object target_node = expand_info[2];
  434. WriteAssignVars(block, target_node, call_node);
  435. }
  436. }
  437. return block;
  438. }
  439. LocationPtr Parser::GetLocation(const py::object &node) const {
  440. MS_EXCEPTION_IF_NULL(ast_);
  441. py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  442. constexpr size_t list_size = 5;
  443. if (ret.size() < list_size) {
  444. MS_LOG(EXCEPTION) << "List size should not be less than 5.";
  445. }
  446. const size_t file_name_index = 0;
  447. const size_t line_index = 1;
  448. const size_t column_index = 2;
  449. const size_t line_end_index = 3;
  450. const size_t column_end_index = 4;
  451. // Refer to Location::Location() for each member of ret: line, column, line_end, column_end.
  452. auto location = std::make_shared<Location>(ret[file_name_index].cast<std::string>(), ret[line_index].cast<int64_t>(),
  453. ret[column_index].cast<int64_t>(), ret[line_end_index].cast<int64_t>(),
  454. ret[column_end_index].cast<int64_t>());
  455. return location;
  456. }
  457. void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
  458. const FunctionBlockPtr &false_block) {
  459. MS_EXCEPTION_IF_NULL(true_block);
  460. MS_EXCEPTION_IF_NULL(false_block);
  461. true_block->AddPrevBlock(pre_block);
  462. true_block->Mature();
  463. false_block->AddPrevBlock(pre_block);
  464. false_block->Mature();
  465. }
  466. FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::object &node) {
  467. MS_LOG(DEBUG) << "Process ast return";
  468. MS_EXCEPTION_IF_NULL(block);
  469. // Parse the return Statements value.
  470. py::object value_object = python_adapter::GetPyObjAttr(node, "value");
  471. AnfNodePtr return_expr_node = ParseExprNode(block, value_object);
  472. // Check if need interpreting.
  473. return_expr_node = HandleInterpret(block, return_expr_node, value_object);
  474. // Create the `return` CNode.
  475. auto func_graph = block->func_graph();
  476. CNodePtr return_cnode = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimReturn), return_expr_node});
  477. func_graph->set_return(return_cnode);
  478. return block;
  479. }
  480. // Process binary operators,eg: `a + b`, `a | b`, etc.
  481. AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &node) {
  482. MS_LOG(DEBUG) << "Process ast BinOP";
  483. py::object left = python_adapter::GetPyObjAttr(node, "left");
  484. py::object right = python_adapter::GetPyObjAttr(node, "right");
  485. py::object op = python_adapter::GetPyObjAttr(node, "op");
  486. // Create left and right ANF node
  487. AnfNodePtr left_node = ParseExprNode(block, left);
  488. if (left_node == nullptr) {
  489. MS_LOG(EXCEPTION) << "DoBinOp process left node failed: " << errcode();
  490. }
  491. AnfNodePtr right_node = ParseExprNode(block, right);
  492. if (right_node == nullptr) {
  493. MS_LOG(EXCEPTION) << "DoBinOp process right node failed:" << errcode();
  494. }
  495. // Resolve the op
  496. MS_EXCEPTION_IF_NULL(block);
  497. AnfNodePtr op_node = block->MakeResolveAstOp(op);
  498. // Create apply node
  499. MS_EXCEPTION_IF_NULL(block->func_graph());
  500. return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
  501. }
  502. AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &node) {
  503. MS_LOG(DEBUG) << "Process ast Name";
  504. auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id"));
  505. MS_LOG(DEBUG) << "The Name id is " << name_id;
  506. MS_EXCEPTION_IF_NULL(block);
  507. if (block->IsGlobalVar(name_id)) {
  508. MS_LOG(DEBUG) << "name_id: " << name_id;
  509. return block->MakeResolveSymbol(name_id);
  510. }
  511. return block->ReadVariable(name_id);
  512. }
  513. AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) {
  514. MS_LOG(DEBUG) << "Process ast NoneType";
  515. return NewValueNode(kNone);
  516. }
  517. AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) {
  518. MS_LOG(DEBUG) << "Process ast Ellipsis";
  519. return NewValueNode(kEllipsis);
  520. }
  521. AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
  522. MS_LOG(DEBUG) << "Process ast Num";
  523. py::object obj = python_adapter::GetPyObjAttr(node, "n");
  524. if (py::isinstance<py::int_>(obj)) {
  525. MS_LOG(INFO) << "The Num is int64_t:" << (std::string)py::str(obj);
  526. auto data = py::cast<int64_t>(obj);
  527. return NewValueNode(data);
  528. } else if (py::isinstance<py::float_>(obj)) {
  529. MS_LOG(INFO) << "The Num is float:" << (std::string)py::str(obj);
  530. auto data = py::cast<float>(obj);
  531. return NewValueNode(data);
  532. } else {
  533. // no else actually
  534. errcode_ = PARSE_NODE_TYPE_UNKNOWN;
  535. MS_EXCEPTION(TypeError) << "Only support 'Number' type of 'int` and 'float', but got type: " << obj.get_type()
  536. << " Value:" << py::str(obj);
  537. }
  538. }
  539. AnfNodePtr Parser::ParseStr(const FunctionBlockPtr &, const py::object &node) {
  540. MS_LOG(DEBUG) << "Process ast Str";
  541. auto str_s = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "s"));
  542. return NewValueNode(str_s);
  543. }
  544. AnfNodePtr Parser::ParseConstant(const FunctionBlockPtr &, const py::object &node) {
  545. MS_LOG(DEBUG) << "Process ast Constant";
  546. py::object obj = python_adapter::GetPyObjAttr(node, "value");
  547. if (py::isinstance<py::bool_>(obj)) {
  548. MS_LOG(INFO) << "The Constant is bool:" << (std::string)py::str(obj);
  549. return NewValueNode(py::cast<bool>(obj));
  550. } else if (py::isinstance<py::int_>(obj)) {
  551. MS_LOG(INFO) << "The Constant is int64_t:" << (std::string)py::str(obj);
  552. return NewValueNode(py::cast<int64_t>(obj));
  553. } else if (py::isinstance<py::float_>(obj)) {
  554. MS_LOG(INFO) << "The Constant is float:" << (std::string)py::str(obj);
  555. return NewValueNode(py::cast<float>(obj));
  556. } else if (py::isinstance<py::str>(obj)) {
  557. MS_LOG(INFO) << "The Constant is string:" << (std::string)py::str(obj);
  558. return NewValueNode(py::cast<std::string>(obj));
  559. } else if (py::isinstance<py::none>(obj)) {
  560. MS_LOG(INFO) << "The Constant is none:" << (std::string)py::str(obj);
  561. return NewValueNode(kNone);
  562. } else if (py::isinstance<py::ellipsis>(obj)) {
  563. MS_LOG(INFO) << "The Constance is ellipsis:" << (std::string)py::str(obj);
  564. return NewValueNode(kEllipsis);
  565. } else {
  566. // no else actually
  567. MS_EXCEPTION(TypeError) << "Unsupported Constant type : " << (std::string)py::str(obj);
  568. }
  569. }
  570. AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object &node) {
  571. MS_LOG(DEBUG) << "Process ast NameConstant";
  572. py::object obj = python_adapter::GetPyObjAttr(node, "value");
  573. if (py::isinstance<py::bool_>(obj)) {
  574. MS_LOG(INFO) << "The NameConstant is bool:" << (std::string)py::str(obj);
  575. auto data = py::cast<bool>(obj);
  576. return NewValueNode(data);
  577. } else if (py::isinstance<py::none>(obj)) {
  578. MS_LOG(INFO) << "The NameConstant is none:" << (std::string)py::str(obj);
  579. return NewValueNode(kNone);
  580. }
  581. // no else actually
  582. errcode_ = PARSE_NODE_TYPE_UNKNOWN;
  583. MS_LOG(EXCEPTION) << "Unsupported NameConstant type: " << (std::string)py::str(obj);
  584. }
  585. AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes) {
  586. MS_EXCEPTION_IF_NULL(block);
  587. AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
  588. std::vector<AnfNodePtr> make_tuple_nodes;
  589. make_tuple_nodes.push_back(make_tuple_op);
  590. (void)std::transform(element_nodes.begin(), element_nodes.end(), std::back_inserter(make_tuple_nodes),
  591. [](AnfNodePtr arg) -> AnfNodePtr { return arg; });
  592. return block->func_graph()->NewCNodeInOrder(std::move(make_tuple_nodes));
  593. }
  594. AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &args) {
  595. MS_EXCEPTION_IF_NULL(block);
  596. py::object father_class;
  597. const size_t expect_args_size = 2;
  598. if (args.empty()) {
  599. father_class = py::none();
  600. } else if (args.size() == expect_args_size) {
  601. father_class = args[0];
  602. auto arg_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, args[1])));
  603. if (arg_type != AST_SUB_TYPE_NAME || py::cast<std::string>(python_adapter::GetPyObjAttr(args[1], "id")) != "self") {
  604. MS_EXCEPTION(ArgumentError) << "Argument 2 of 'super()' must be 'self', but got '"
  605. << py::cast<std::string>(python_adapter::GetPyObjAttr(args[1], "id")) << "'.";
  606. }
  607. } else {
  608. MS_EXCEPTION(ArgumentError) << "Arguments number of 'super()' should be 0 or 2, but got " << args.size() << ".";
  609. }
  610. py::object target_class_instance = ast_->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast_->obj());
  611. py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance);
  612. NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
  613. SymbolPtr symbol = std::make_shared<Symbol>("namespace");
  614. MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
  615. return block->MakeResolve(name_space, symbol);
  616. }
  617. // Process function call, eg : f1(x, y) ...
  618. AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) {
  619. MS_LOG(DEBUG) << "Process ast Call";
  620. // Process function call
  621. py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func");
  622. py::list args = python_adapter::GetPyObjAttr(node, "args");
  623. auto arg_type =
  624. AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, function_ast_node)));
  625. if (arg_type == AST_SUB_TYPE_NAME) {
  626. auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(function_ast_node, "id"));
  627. if (name_id == "super") {
  628. return ParseSuper(block, args);
  629. }
  630. }
  631. AnfNodePtr call_function_node = ParseExprNode(block, function_ast_node);
  632. // Function call arguments should be passed in as groups and unpacked later using unpack call
  633. std::vector<AnfNodePtr> packed_arguments;
  634. std::vector<AnfNodePtr> group_arguments;
  635. bool need_unpack_args = ParseArgsInCall(block, args, &packed_arguments, &group_arguments);
  636. bool need_unpack_keywords = ParseKeywordsInCall(block, node, &packed_arguments);
  637. // If there is stared or keyword argument, unpack may be needed
  638. bool need_unpack = need_unpack_args || need_unpack_keywords;
  639. auto call_cnode = GenerateAnfNodeForCall(block, call_function_node, packed_arguments, group_arguments, need_unpack);
  640. if (call_function_node->interpret()) {
  641. call_cnode->set_interpret(true);
  642. }
  643. return call_cnode;
  644. }
  645. CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_node,
  646. const std::vector<AnfNodePtr> &packed_arguments) {
  647. MS_EXCEPTION_IF_NULL(func_graph);
  648. std::vector<AnfNodePtr> unpack_call_nodes;
  649. auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL));
  650. unpack_call_nodes.push_back(unpack_call_op);
  651. unpack_call_nodes.push_back(call_function_node);
  652. (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes),
  653. [](AnfNodePtr node) -> AnfNodePtr { return node; });
  654. CNodePtr unpack_call = func_graph->NewCNodeInOrder(std::move(unpack_call_nodes));
  655. return unpack_call;
  656. }
  657. AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_node,
  658. const std::vector<AnfNodePtr> &packed_arguments,
  659. const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const {
  660. // If there is keyword arguments or starred, using an unpack_call op to unpack the argument
  661. MS_EXCEPTION_IF_NULL(block);
  662. if (need_unpack) {
  663. return MakeUnpackCall(block->func_graph(), call_function_node, packed_arguments);
  664. }
  665. // else there is no keyword arguments and starred, parsed as normal arguments without unpack
  666. std::vector<AnfNodePtr> func_call_nodes;
  667. func_call_nodes.push_back(call_function_node);
  668. (void)std::transform(group_arguments.begin(), group_arguments.end(), std::back_inserter(func_call_nodes),
  669. [](AnfNodePtr node) -> AnfNodePtr { return node; });
  670. CNodePtr call_anf_node = block->func_graph()->NewCNodeInOrder(std::move(func_call_nodes));
  671. return call_anf_node;
  672. }
  673. bool Parser::ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args,
  674. std::vector<AnfNodePtr> *packed_arguments, std::vector<AnfNodePtr> *group_arguments) {
  675. MS_LOG(DEBUG) << "Process ast args in call";
  676. MS_EXCEPTION_IF_NULL(packed_arguments);
  677. MS_EXCEPTION_IF_NULL(group_arguments);
  678. bool need_unpack = false;
  679. for (size_t i = 0; i < args.size(); i++) {
  680. auto arg_node = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, args[i])));
  681. if (arg_node == AST_SUB_TYPE_STARRED) {
  682. if (!group_arguments->empty()) {
  683. packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments));
  684. }
  685. packed_arguments->push_back(ParseExprNode(block, python_adapter::GetPyObjAttr(args[i], "value")));
  686. group_arguments->clear();
  687. need_unpack = true;
  688. } else {
  689. auto node = ParseExprNode(block, args[i]);
  690. node = HandleInterpret(block, node, args[i]);
  691. group_arguments->push_back(node);
  692. }
  693. }
  694. if (!group_arguments->empty()) {
  695. packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments));
  696. }
  697. return need_unpack;
  698. }
  699. bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node,
  700. std::vector<AnfNodePtr> *packed_arguments) {
  701. MS_LOG(DEBUG) << "Process ast key words in call";
  702. bool need_unpack = false;
  703. py::list keywords = python_adapter::GetPyObjAttr(node, "keywords");
  704. if (!keywords.empty()) {
  705. MS_EXCEPTION_IF_NULL(block);
  706. MS_EXCEPTION_IF_NULL(packed_arguments);
  707. need_unpack = true;
  708. std::vector<AnfNodePtr> keys;
  709. std::vector<AnfNodePtr> values;
  710. for (size_t index = 0; index < keywords.size(); index++) {
  711. auto kw_key = python_adapter::GetPyObjAttr(keywords[index], "arg");
  712. auto kw_value = python_adapter::GetPyObjAttr(keywords[index], "value");
  713. if (py::isinstance<py::none>(kw_key)) {
  714. packed_arguments->push_back(ParseExprNode(block, kw_value));
  715. } else {
  716. auto kw_key_c = kw_key.cast<std::string>();
  717. keys.push_back(NewValueNode(kw_key_c));
  718. auto ret_node = ParseExprNode(block, kw_value);
  719. ret_node = HandleInterpret(block, ret_node, kw_value);
  720. values.push_back(ret_node);
  721. }
  722. }
  723. auto keys_tuple = GenerateMakeTuple(block, keys);
  724. auto values_tuple = GenerateMakeTuple(block, values);
  725. auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
  726. std::vector<AnfNodePtr> make_dict_nodes;
  727. make_dict_nodes.push_back(make_dict_op);
  728. make_dict_nodes.push_back(keys_tuple);
  729. make_dict_nodes.push_back(values_tuple);
  730. packed_arguments->push_back(block->func_graph()->NewCNodeInOrder(std::move(make_dict_nodes)));
  731. }
  732. return need_unpack;
  733. }
  734. // Process call attributes of class type define, eg: x.y()
  735. AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::object &node) {
  736. MS_LOG(DEBUG) << "Process ast Attribute";
  737. // Process class value, eg: self.xx
  738. if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  739. if (ast_->IsClassMember(node)) {
  740. std::string var_name = "self.";
  741. std::string attr_name = node.attr("attr").cast<std::string>();
  742. (void)var_name.append(attr_name);
  743. auto attr_obj = ast()->obj().attr(attr_name.c_str());
  744. MS_EXCEPTION_IF_NULL(block);
  745. if (py::hasattr(ast()->obj(), attr_name.c_str()) &&
  746. (py::hasattr(attr_obj, PYTHON_PRIMITIVE_FLAG) || py::isinstance<py::int_>(attr_obj) ||
  747. py::isinstance<py::float_>(attr_obj) || py::isinstance<py::bool_>(attr_obj) ||
  748. py::isinstance<py::str>(attr_obj) || data_converter::IsCellInstance(attr_obj))) {
  749. MS_LOG(DEBUG) << "var_name: " << var_name;
  750. return block->MakeResolveSymbol(var_name);
  751. } else {
  752. return block->ReadVariable(var_name);
  753. }
  754. }
  755. }
  756. // Process the get attr
  757. // Use the Primitive replace the operation resolve node (getattr),
  758. // because the getattr will eventually be converted to Primitive node
  759. AnfNodePtr op_node = NewValueNode(prim::kPrimGetAttr);
  760. // Process the attr body
  761. py::object value_body = python_adapter::GetPyObjAttr(node, "value");
  762. AnfNodePtr value_node = ParseExprNode(block, value_body);
  763. if (value_node == nullptr) {
  764. MS_LOG(EXCEPTION) << "Parse attribute failed";
  765. }
  766. // Process the node attr
  767. auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast<std::string>();
  768. MS_LOG(DEBUG) << "Attr = " << attr_str;
  769. AnfNodePtr attr_node = nullptr;
  770. {
  771. TraceGuard guard(GetLocation(python_adapter::GetPyObjAttr(node, "attr")));
  772. attr_node = NewValueNode(attr_str);
  773. }
  774. // Create the apply node
  775. auto attr_cnode = block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
  776. if (value_node->interpret() || IsPrimitiveCNode(value_node, prim::kPrimPyInterpret)) {
  777. attr_cnode->set_interpret(true);
  778. }
  779. return attr_cnode;
  780. }
  781. // Process comparison expression : a == b. a > b etc.
  782. AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object &node) {
  783. MS_LOG(DEBUG) << "Process ast Compare";
  784. // For python comparison ,there may be if x>y>5 ,
  785. // Which there is two ops , but we only support one now
  786. py::list ops = python_adapter::GetPyObjAttr(node, "ops");
  787. if (ops.size() != MAX_COMPARISON_OPS_SUPPORTED) {
  788. MS_EXCEPTION(NotSupportError) << "Only support comparison with 1 operator, but got " << ops.size() << ", which is "
  789. << py::str(ops);
  790. }
  791. py::object left = python_adapter::GetPyObjAttr(node, "left");
  792. py::list comparators = python_adapter::GetPyObjAttr(node, "comparators");
  793. if (comparators.empty()) {
  794. MS_LOG(EXCEPTION) << "Comparators can't be empty.";
  795. }
  796. AnfNodePtr left_node = ParseExprNode(block, left);
  797. AnfNodePtr right_node = ParseExprNode(block, comparators[0]);
  798. MS_EXCEPTION_IF_NULL(block);
  799. AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]);
  800. return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
  801. }
  802. AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode) {
  803. // If there is only one bool op now
  804. MS_EXCEPTION_IF_NULL(block);
  805. if (value_list.empty()) {
  806. MS_LOG(EXCEPTION) << "value list is empty.";
  807. }
  808. if (value_list.size() == 1) {
  809. AnfNodePtr first_node = ParseExprNode(block, value_list[0]);
  810. return first_node;
  811. } else {
  812. py::object first = value_list[0];
  813. py::list rest;
  814. for (size_t i = 1; i < value_list.size(); i++) {
  815. rest.append(value_list[i]);
  816. }
  817. FunctionBlockPtr true_block = nullptr;
  818. FunctionBlockPtr false_block = nullptr;
  819. auto block_fg = block->func_graph();
  820. {
  821. TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block_fg->debug_info()));
  822. true_block = MakeFunctionBlock(*this);
  823. }
  824. {
  825. TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block_fg->debug_info()));
  826. false_block = MakeFunctionBlock(*this);
  827. }
  828. MakeConditionBlocks(block, true_block, false_block);
  829. FunctionBlockPtr b1, b2;
  830. // If it is and, we need to process the rest nodes;
  831. // If it is or, we continue to next
  832. if (mode == AST_SUB_TYPE_AND) {
  833. b1 = true_block;
  834. b2 = false_block;
  835. } else if (mode == AST_SUB_TYPE_OR) {
  836. b2 = true_block;
  837. b1 = false_block;
  838. } else {
  839. MS_LOG(ERROR) << "Not supported mode: " << mode;
  840. return nullptr;
  841. }
  842. AnfNodePtr test_node = ParseExprNode(block, first);
  843. AnfNodePtr rest_node = ProcessBoolOpValueList(b1, rest, mode);
  844. b1->func_graph()->set_output(rest_node);
  845. b2->func_graph()->set_output(test_node);
  846. auto cond_node = block->ForceToBoolNode(test_node);
  847. auto switch_app =
  848. block_fg->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), cond_node, NewValueNode(true_block->func_graph()),
  849. NewValueNode(false_block->func_graph())});
  850. std::vector<AnfNodePtr> call_graph_nodes{switch_app};
  851. auto switch_app_call = block_fg->NewCNodeInOrder(std::move(call_graph_nodes));
  852. return switch_app_call;
  853. }
  854. }
  855. // Process comparison expression : a and b. a or b .
  856. AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) {
  857. MS_LOG(DEBUG) << "Process ast BoolOp";
  858. py::object op_node = python_adapter::GetPyObjAttr(node, "op");
  859. AstSubType op_type = ast_->GetOpType(op_node);
  860. if (op_type == AST_SUB_TYPE_UNKNOWN) {
  861. MS_LOG(EXCEPTION) << "ProcessBoolOp, got unknown op type";
  862. }
  863. py::list op_values = python_adapter::GetPyObjAttr(node, "values");
  864. return ProcessBoolOpValueList(block, op_values, op_type);
  865. }
  866. // Process a function def
  867. FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) {
  868. MS_LOG(DEBUG) << "Process ast FunctionDef";
  869. FunctionBlockPtr function_block = ParseDefFunction(node, block);
  870. MS_EXCEPTION_IF_NULL(function_block);
  871. // Get function name
  872. py::str name = python_adapter::GetPyObjAttr(node, "name");
  873. std::string function_name = name;
  874. ValueNodePtr valuenode_graph = NewValueNode(function_block->func_graph());
  875. block->WriteVariable(function_name, valuenode_graph);
  876. return block;
  877. }
  878. // Process a lambda expression . like lambda x,y: x + y
  879. AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) {
  880. MS_LOG(DEBUG) << "Process ast Lambda";
  881. FunctionBlockPtr function_block = ParseLambdaFunction(node, block);
  882. MS_EXCEPTION_IF_NULL(function_block);
  883. auto block_fg = function_block->func_graph();
  884. ValueNodePtr const_graph = NewValueNode(block_fg);
  885. return const_graph;
  886. }
  887. // Process a tuple
  888. AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &node) {
  889. MS_LOG(DEBUG) << "Process ast Tuple";
  890. MS_EXCEPTION_IF_NULL(block);
  891. py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
  892. if (elts.empty()) {
  893. auto empty_tuple = std::vector<ValuePtr>();
  894. return NewValueNode(std::make_shared<ValueTuple>(empty_tuple));
  895. }
  896. std::vector<AnfNodePtr> tuple_vec;
  897. AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
  898. tuple_vec.emplace_back(make_tuple_op);
  899. for (size_t i = 0; i < elts.size(); i++) {
  900. AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
  901. tuple_vec.emplace_back(node_ptr);
  902. }
  903. CNodePtr tuple_app = block->func_graph()->NewCNodeInOrder(std::move(tuple_vec));
  904. return tuple_app;
  905. }
  906. // Process a list
  907. AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) {
  908. MS_LOG(DEBUG) << "Process ast List";
  909. MS_EXCEPTION_IF_NULL(block);
  910. py::list elts = python_adapter::GetPyObjAttr(node, "elts");
  911. if (elts.empty()) {
  912. auto empty_list = std::vector<ValuePtr>();
  913. return NewValueNode(std::make_shared<ValueList>(empty_list));
  914. }
  915. std::vector<AnfNodePtr> list_vec;
  916. AnfNodePtr make_list_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST);
  917. list_vec.emplace_back(make_list_op);
  918. for (size_t i = 0; i < elts.size(); i++) {
  919. AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
  920. list_vec.emplace_back(node_ptr);
  921. }
  922. CNodePtr list_app = block->func_graph()->NewCNodeInOrder(std::move(list_vec));
  923. return list_app;
  924. }
  925. // Process a subscript, such as x[y] , node expressed as value[slice]
  926. AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::object &node) {
  927. MS_LOG(DEBUG) << "Process ast Subscript";
  928. MS_EXCEPTION_IF_NULL(block);
  929. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  930. py::object value_node = python_adapter::GetPyObjAttr(node, "value");
  931. py::object slice_node = python_adapter::GetPyObjAttr(node, "slice");
  932. AnfNodePtr value = ParseExprNode(block, value_node);
  933. AnfNodePtr slice = ParseExprNode(block, slice_node);
  934. return block->func_graph()->NewCNodeInOrder({op_getitem, value, slice});
  935. }
  936. // Process a slice, get the slice value
  937. AnfNodePtr Parser::ParseSlice(const FunctionBlockPtr &block, const py::object &node) {
  938. MS_LOG(DEBUG) << "Process ast Slice";
  939. MS_EXCEPTION_IF_NULL(block);
  940. AnfNodePtr op_makeslice = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKESLICE);
  941. py::object start = python_adapter::GetPyObjAttr(node, "lower");
  942. py::object stop = python_adapter::GetPyObjAttr(node, "upper");
  943. py::object step = python_adapter::GetPyObjAttr(node, "step");
  944. AnfNodePtr start_node = ParseExprNode(block, start);
  945. AnfNodePtr stop_node = ParseExprNode(block, stop);
  946. AnfNodePtr step_node = ParseExprNode(block, step);
  947. return block->func_graph()->NewCNodeInOrder({op_makeslice, start_node, stop_node, step_node});
  948. }
  949. // Process a extslice
  950. AnfNodePtr Parser::ParseExtSlice(const FunctionBlockPtr &block, const py::object &node) {
  951. MS_LOG(DEBUG) << "Process ast ExtSlice";
  952. MS_EXCEPTION_IF_NULL(block);
  953. AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
  954. py::tuple slice_tuple = python_adapter::GetPyObjAttr(node, "dims");
  955. std::vector<AnfNodePtr> node_vec;
  956. node_vec.emplace_back(make_tuple_op);
  957. for (size_t i = 0; i < slice_tuple.size(); i++) {
  958. AnfNodePtr node_ptr = ParseExprNode(block, slice_tuple[i]);
  959. node_vec.emplace_back(node_ptr);
  960. }
  961. CNodePtr tuple_conde = block->func_graph()->NewCNodeInOrder(std::move(node_vec));
  962. return tuple_conde;
  963. }
  964. // Process a index, get the index number
  965. AnfNodePtr Parser::ParseIndex(const FunctionBlockPtr &block, const py::object &node) {
  966. MS_LOG(DEBUG) << "Process ast Index";
  967. py::object value_node = python_adapter::GetPyObjAttr(node, "value");
  968. return ParseExprNode(block, value_node);
  969. }
  970. // Process a UnaryOp, +a, -b
  971. AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node) {
  972. MS_LOG(DEBUG) << "Process ast UnaryOp";
  973. py::object op = python_adapter::GetPyObjAttr(node, "op");
  974. MS_EXCEPTION_IF_NULL(block);
  975. // Resolve the op
  976. AnfNodePtr op_node = block->MakeResolveAstOp(op);
  977. py::object operand = python_adapter::GetPyObjAttr(node, "operand");
  978. AnfNodePtr operand_node = ParseExprNode(block, operand);
  979. return block->func_graph()->NewCNodeInOrder({op_node, operand_node});
  980. }
  981. // Process a dict ast node expression
  982. AnfNodePtr Parser::ParseDictByKeysAndValues(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &key_nodes,
  983. const std::vector<AnfNodePtr> &value_nodes) {
  984. auto keys_tuple = GenerateMakeTuple(block, key_nodes);
  985. auto values_tuple = GenerateMakeTuple(block, value_nodes);
  986. MS_EXCEPTION_IF_NULL(block);
  987. auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
  988. return block->func_graph()->NewCNodeInOrder({make_dict_op, keys_tuple, values_tuple});
  989. }
  990. AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) {
  991. MS_LOG(DEBUG) << "Process ast Dict";
  992. py::list keys = node.attr("keys");
  993. py::list values = node.attr("values");
  994. std::vector<AnfNodePtr> key_nodes;
  995. std::vector<AnfNodePtr> value_nodes;
  996. for (size_t i = 0; i < keys.size(); i++) {
  997. key_nodes.push_back(ParseExprNode(block, keys[i]));
  998. value_nodes.push_back(ParseExprNode(block, values[i]));
  999. }
  1000. return ParseDictByKeysAndValues(block, key_nodes, value_nodes);
  1001. }
  1002. // Process a augment assign such as a += b or mat[stride_slice] += b.
  1003. FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py::object &node) {
  1004. MS_LOG(DEBUG) << "Process ast AugAssign";
  1005. MS_EXCEPTION_IF_NULL(block);
  1006. MS_EXCEPTION_IF_NULL(ast_);
  1007. py::object target_object = python_adapter::GetPyObjAttr(node, "target");
  1008. py::object op_object = python_adapter::GetPyObjAttr(node, "op");
  1009. py::object value_object = python_adapter::GetPyObjAttr(node, "value");
  1010. AnfNodePtr target_node = nullptr;
  1011. AnfNodePtr op_node = block->MakeResolveAstOp(op_object);
  1012. AnfNodePtr value_node = ParseExprNode(block, value_object);
  1013. auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_object)));
  1014. if (ast_type == AST_SUB_TYPE_NAME) {
  1015. target_node = ParseName(block, target_object);
  1016. } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
  1017. target_node = ParseSubscript(block, target_object);
  1018. } else if (ast_->IsClassMember(target_object)) {
  1019. target_node = ParseAttribute(block, target_object);
  1020. } else if (ast_type == AST_SUB_TYPE_ATTRIBUTE) {
  1021. TraceGuard(GetLocation(target_object));
  1022. MS_EXCEPTION(TypeError) << "Only support augassign to attribute of self, but got attribute of "
  1023. << py::str(target_object.attr("value").attr("id")) << ".\n"
  1024. << "More details please refer to syntax support at https://www.mindspore.cn";
  1025. } else {
  1026. TraceGuard(GetLocation(target_object));
  1027. MS_EXCEPTION(TypeError) << "Only supported augassign to attribute of self, variable and index value, but got "
  1028. << target_object.get_type()
  1029. << ".\nMore details please refer to syntax support at https://www.mindspore.cn";
  1030. }
  1031. if (target_node == nullptr) {
  1032. MS_LOG(EXCEPTION) << "Can not get target node ";
  1033. }
  1034. CNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, value_node});
  1035. WriteAssignVars(block, target_object, augassign_app);
  1036. return block;
  1037. }
  1038. // Process global declaration such as 'global x';
  1039. FunctionBlockPtr Parser::ParseGlobal(const FunctionBlockPtr &block, const py::object &node) {
  1040. MS_LOG(DEBUG) << "Process ast Global";
  1041. MS_EXCEPTION_IF_NULL(block);
  1042. py::list vars = python_adapter::GetPyObjAttr(node, "names");
  1043. for (auto &item : vars) {
  1044. block->AddGlobalVar(py::cast<std::string>(item));
  1045. }
  1046. return block;
  1047. }
  1048. // Process a if statement
  1049. FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object &node) {
  1050. MS_LOG(DEBUG) << "Process ast If";
  1051. py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  1052. AnfNodePtr condition_node = ParseExprNode(block, test_node);
  1053. MS_EXCEPTION_IF_NULL(block);
  1054. CNodePtr bool_node = block->ForceToBoolNode(condition_node);
  1055. FunctionBlockPtr true_block = nullptr;
  1056. FunctionBlockPtr false_block = nullptr;
  1057. auto block_fg = block->func_graph();
  1058. {
  1059. TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block_fg->debug_info()));
  1060. true_block = MakeFunctionBlock(*this);
  1061. }
  1062. {
  1063. TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block_fg->debug_info()));
  1064. false_block = MakeFunctionBlock(*this);
  1065. }
  1066. MakeConditionBlocks(block, true_block, false_block);
  1067. FunctionBlockPtr after_block = nullptr;
  1068. {
  1069. TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block_fg->debug_info()));
  1070. after_block = MakeFunctionBlock(*this);
  1071. }
  1072. if (MsContext::GetInstance()->backend_policy() != "ge") {
  1073. // For backends excludes 'ge', it can handle multi graph call, use this flag to
  1074. // generate call not inline `after_block` graph to reduce if by if switch expansion.
  1075. after_block->func_graph()->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true);
  1076. }
  1077. // Process the if-true branch
  1078. py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
  1079. FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode);
  1080. // If the return_ is set, it has its own continuation block
  1081. if (true_end->func_graph()->get_return() == nullptr) {
  1082. MS_LOG(DEBUG) << "true end jump to after.";
  1083. true_end->Jump(after_block, {});
  1084. }
  1085. // Process the orelse branch
  1086. py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
  1087. FunctionBlockPtr false_end = ParseStatements(false_block, orelseNode);
  1088. // If the return_ is set, it has its own continuation block
  1089. if (false_end->func_graph()->get_return() == nullptr) {
  1090. MS_LOG(DEBUG) << "false_end jump to after.";
  1091. false_end->Jump(after_block, {});
  1092. }
  1093. block->ConditionalJump(bool_node, true_block, false_block);
  1094. if (after_block->prev_blocks().empty()) {
  1095. after_block->SetAsDeadBlock();
  1096. }
  1097. after_block->Mature();
  1098. return after_block;
  1099. }
  1100. FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::object &node) {
  1101. MS_LOG(DEBUG) << "Process ast While";
  1102. MS_EXCEPTION_IF_NULL(block);
  1103. FunctionBlockPtr header_block = nullptr;
  1104. FunctionBlockPtr body_block = nullptr;
  1105. FunctionBlockPtr after_block = nullptr;
  1106. {
  1107. TraceGuard guard(std::make_shared<TraceWhileHeader>(block->func_graph()->debug_info()));
  1108. header_block = MakeFunctionBlock(*this);
  1109. auto func_graph = header_block->func_graph();
  1110. MS_EXCEPTION_IF_NULL(func_graph);
  1111. func_graph->set_flag(GRAPH_FLAG_IS_WHILE_HEADER, true);
  1112. }
  1113. {
  1114. TraceGuard guard(std::make_shared<TraceWhileBody>(block->func_graph()->debug_info()));
  1115. body_block = MakeFunctionBlock(*this);
  1116. }
  1117. {
  1118. TraceGuard guard(std::make_shared<TraceWhileAfter>(block->func_graph()->debug_info()));
  1119. after_block = MakeFunctionBlock(*this);
  1120. }
  1121. body_block->AddPrevBlock(header_block);
  1122. after_block->AddPrevBlock(header_block);
  1123. block->Jump(header_block, {});
  1124. py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  1125. AnfNodePtr condition_node = ParseExprNode(header_block, test_node);
  1126. condition_node = header_block->ForceToWhileCond(condition_node);
  1127. body_block->Mature();
  1128. header_block->ConditionalJump(condition_node, body_block, after_block);
  1129. // Parse loop body statements with loop context.
  1130. LoopContext loop_context{&loops_, header_block, nullptr};
  1131. py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  1132. FunctionBlockPtr after_body = ParseStatements(body_block, body_node);
  1133. if (after_body->func_graph()->get_return() == nullptr) {
  1134. after_body->Jump(header_block, {});
  1135. }
  1136. header_block->Mature();
  1137. after_block->Mature();
  1138. auto &end_block = loop_context.EndBlock();
  1139. // end_block exists if we encounter 'break' in loop body.
  1140. if (end_block) {
  1141. after_block->Jump(end_block, {});
  1142. end_block->Mature();
  1143. return end_block;
  1144. }
  1145. // No 'break', no end_block.
  1146. return after_block;
  1147. }
  1148. CNodePtr Parser::GenerateIteratorInFor(const FunctionBlockPtr &block, const py::object &node,
  1149. const AnfNodePtr &op_iter) {
  1150. py::object iter_node = python_adapter::GetPyObjAttr(node, "iter");
  1151. AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node);
  1152. return block->func_graph()->NewCNodeInOrder({op_iter, iter_anf_node});
  1153. }
  1154. CNodePtr Parser::GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block,
  1155. const AnfNodePtr &op_hasnext) {
  1156. MS_EXCEPTION_IF_NULL(header_block);
  1157. return header_block->func_graph()->NewCNodeInOrder({op_hasnext, iter_param});
  1158. }
  1159. FunctionBlockPtr Parser::GenerateBlock(const TraceInfoPtr &trace_info) {
  1160. TraceGuard trace_guard(trace_info);
  1161. FunctionBlockPtr block = MakeFunctionBlock(*this);
  1162. MS_EXCEPTION_IF_NULL(block);
  1163. return block;
  1164. }
  1165. int64_t Parser::GetForTransToWhileLoop() {
  1166. // int64 support 63bits positive num mostly.
  1167. constexpr auto max_num_length = 10;
  1168. if (max_for_loop_count_str_.size() > max_num_length || max_for_loop_count_str_.empty()) {
  1169. return MAX_FOR_LOOP_COUNT;
  1170. }
  1171. if (std::any_of(max_for_loop_count_str_.begin(), max_for_loop_count_str_.end(),
  1172. [](char c) { return c < '0' || c > '9'; })) {
  1173. return MAX_FOR_LOOP_COUNT;
  1174. }
  1175. int64_t loop_count;
  1176. std::stringstream ss;
  1177. ss << max_for_loop_count_str_;
  1178. ss >> loop_count;
  1179. return loop_count;
  1180. }
  1181. // A for loop will generate 3 functions :the test, the body, and the continuation
  1182. // for x in xs:
  1183. // body
  1184. // It is compiled to be following statement
  1185. // if len(xs) < max_loop_cnt, ParseForIter. Use iter to implement for loop, which always unroll loop
  1186. // else, ParseForLoop. Use loop var to implement for loop, which always sink loop
  1187. FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
  1188. MS_LOG(DEBUG) << "Process ast For, create an if else statement";
  1189. MS_EXCEPTION_IF_NULL(block);
  1190. // Create statement 'len(xs) < MAX_FOR_LOOP_COUNT'
  1191. AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
  1192. py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER);
  1193. AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
  1194. CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
  1195. CNodePtr bool_node = block->func_graph()->NewCNodeInOrder(
  1196. {NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(GetForTransToWhileLoop())});
  1197. // Create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop'
  1198. FunctionBlockPtr true_block = nullptr;
  1199. FunctionBlockPtr false_block = nullptr;
  1200. {
  1201. TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
  1202. true_block = MakeFunctionBlock(*this);
  1203. }
  1204. {
  1205. TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
  1206. false_block = MakeFunctionBlock(*this);
  1207. }
  1208. MakeConditionBlocks(block, true_block, false_block);
  1209. FunctionBlockPtr after_block = nullptr;
  1210. {
  1211. TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
  1212. after_block = MakeFunctionBlock(*this);
  1213. }
  1214. FunctionBlockPtr true_end = ParseForIter(true_block, node);
  1215. true_end->Jump(after_block, {});
  1216. FunctionBlockPtr false_end = ParseForLoop(false_block, node);
  1217. false_end->Jump(after_block, {});
  1218. block->ConditionalJump(bool_node, true_block, false_block);
  1219. after_block->Mature();
  1220. return after_block;
  1221. }
  1222. // A for loop will generate 3 functions :the test, the body, and the continuation
  1223. // for x in xs:
  1224. // body
  1225. // It is compiled to be following statement
  1226. // it = iter(xs)
  1227. // while hastnext(it)
  1228. // x, it = next(it)
  1229. // body
  1230. FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::object &node) {
  1231. MS_LOG(DEBUG) << "Process ast For";
  1232. MS_EXCEPTION_IF_NULL(block);
  1233. AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
  1234. AnfNodePtr op_next = block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
  1235. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  1236. AnfNodePtr op_hasnext = block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
  1237. // Generate the iterator apply
  1238. CNodePtr iter_apply = GenerateIteratorInFor(block, node, op_iter);
  1239. MS_EXCEPTION_IF_NULL(iter_apply);
  1240. FunctionBlockPtr header_block = GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
  1241. MS_EXCEPTION_IF_NULL(header_block);
  1242. // Generate the hasnext apply which is a condition
  1243. ParameterPtr iter_param = header_block->func_graph()->add_parameter();
  1244. CNodePtr cond_apply = GenerateCondInFor(iter_param, header_block, op_hasnext);
  1245. // Generate the body of the for statement
  1246. FunctionBlockPtr body_block = GenerateBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  1247. MS_EXCEPTION_IF_NULL(body_block);
  1248. body_block->AddPrevBlock(header_block);
  1249. // Generate the iterator next apply
  1250. // Process as following: `app = next(it); target = app[0]; it = app[1];`
  1251. CNodePtr app = body_block->func_graph()->NewCNodeInOrder({op_next, iter_param});
  1252. CNodePtr target_app =
  1253. body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast<int64_t>(0))});
  1254. py::object target_node = python_adapter::GetPyObjAttr(node, "target");
  1255. CNodePtr iter2_app =
  1256. body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast<int64_t>(1))});
  1257. WriteAssignVars(body_block, target_node, target_app);
  1258. // Link the variable name with the target
  1259. auto it_info = std::make_shared<TraceIterator>(target_app->debug_info());
  1260. iter_param->debug_info()->set_trace_info(it_info);
  1261. iter2_app->debug_info()->set_trace_info(it_info);
  1262. iter_apply->debug_info()->set_trace_info(it_info);
  1263. FunctionBlockPtr after_block = nullptr;
  1264. {
  1265. TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
  1266. after_block = MakeFunctionBlock(*this);
  1267. }
  1268. MS_EXCEPTION_IF_NULL(after_block);
  1269. after_block->AddPrevBlock(header_block);
  1270. block->Jump(header_block, {iter_apply});
  1271. body_block->Mature();
  1272. header_block->ConditionalJump(cond_apply, body_block, after_block);
  1273. // Parse loop body statements with loop context.
  1274. LoopContext loop_context{&loops_, header_block, iter2_app};
  1275. py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  1276. FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
  1277. if (after_body_block->func_graph()->get_return() == nullptr) {
  1278. after_body_block->Jump(header_block, {iter2_app});
  1279. }
  1280. header_block->Mature();
  1281. after_block->Mature();
  1282. auto &end_block = loop_context.EndBlock();
  1283. if (end_block) {
  1284. // end_block exists if we encounter 'break' in loop body.
  1285. after_block->Jump(end_block, {});
  1286. end_block->Mature();
  1287. return end_block;
  1288. }
  1289. // No 'break', no end_block.
  1290. return after_block;
  1291. }
  1292. // A for loop will generate 3 functions :the test, the body, and the continuation
  1293. // for x in xs:
  1294. // body
  1295. // It is compiled to be following statement
  1296. // i = 0
  1297. // while i < len(xs)
  1298. // x = xs[i]
  1299. // i = i + 1
  1300. // body
  1301. FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::object &node) {
  1302. MS_LOG(DEBUG) << "Process ast For by loop variable";
  1303. MS_EXCEPTION_IF_NULL(block);
  1304. AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
  1305. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  1306. // Get variable name of 'x' in statement 'for x in xs'
  1307. py::object target_node = python_adapter::GetPyObjAttr(node, "target");
  1308. // Create statement 'len(xs)'
  1309. py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
  1310. AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
  1311. MS_EXCEPTION_IF_NULL(iter_node);
  1312. // Generate node for loop count and convert it to tensor, to make the loop not unroll
  1313. CNodePtr scalar_len = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
  1314. auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations");
  1315. auto scalar_to_tensor_node = block->func_graph()->NewCNodeInOrder({NewValueNode(scalar_to_tensor)});
  1316. CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, scalar_len});
  1317. FunctionBlockPtr header_block = GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
  1318. MS_EXCEPTION_IF_NULL(header_block);
  1319. // Create loop variable 'i'
  1320. ParameterPtr loop_var = header_block->func_graph()->add_parameter();
  1321. // Create loop condition 'i < len(xs)'
  1322. auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations");
  1323. auto less_node = header_block->func_graph()->NewCNodeInOrder({NewValueNode(prim_less)});
  1324. CNodePtr cond_node = header_block->func_graph()->NewCNodeInOrder({less_node, loop_var, len_iter});
  1325. // Generate the body of the for statement
  1326. FunctionBlockPtr body_block = GenerateBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  1327. MS_EXCEPTION_IF_NULL(body_block);
  1328. body_block->AddPrevBlock(header_block);
  1329. // Create 'x = xs[i]'
  1330. auto body_func_graph = body_block->func_graph();
  1331. CNodePtr target_var = body_func_graph->NewCNodeInOrder({op_getitem, iter_node, loop_var});
  1332. WriteAssignVars(body_block, target_node, target_var);
  1333. // Create 'i = i + 1'
  1334. auto prim_add = prim::GetPythonOps("Add", "mindspore.ops.operations");
  1335. auto add_node = body_func_graph->NewCNodeInOrder({NewValueNode(prim_add)});
  1336. auto body_scalar_to_tensor_node = body_func_graph->NewCNodeInOrder({NewValueNode(scalar_to_tensor)});
  1337. auto add_tensor_node =
  1338. body_func_graph->NewCNodeInOrder({body_scalar_to_tensor_node, NewValueNode(static_cast<int64_t>(1))});
  1339. CNodePtr loop_var_inc = body_func_graph->NewCNodeInOrder({add_node, loop_var, add_tensor_node});
  1340. body_block->WriteVariable(loop_var->name(), loop_var_inc);
  1341. // Link the variable name with the target
  1342. auto it_info = std::make_shared<TraceIterator>(loop_var_inc->debug_info());
  1343. loop_var->debug_info()->set_trace_info(it_info);
  1344. len_iter->debug_info()->set_trace_info(it_info);
  1345. FunctionBlockPtr after_block = nullptr;
  1346. {
  1347. TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
  1348. after_block = MakeFunctionBlock(*this);
  1349. }
  1350. MS_EXCEPTION_IF_NULL(after_block);
  1351. after_block->AddPrevBlock(header_block);
  1352. CNodePtr zero_tensor =
  1353. block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, NewValueNode(static_cast<int64_t>(0))});
  1354. block->Jump(header_block, {zero_tensor});
  1355. body_block->Mature();
  1356. header_block->ConditionalJump(cond_node, body_block, after_block, false);
  1357. // Parse loop body statements with loop context.
  1358. LoopContext loop_context{&loops_, header_block, loop_var_inc};
  1359. py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  1360. FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
  1361. if (after_body_block->func_graph()->get_return() == nullptr) {
  1362. after_body_block->Jump(header_block, {loop_var_inc});
  1363. }
  1364. header_block->Mature();
  1365. after_block->Mature();
  1366. auto &end_block = loop_context.EndBlock();
  1367. if (end_block) {
  1368. // end_block exists if we encounter 'break' in loop body.
  1369. after_block->Jump(end_block, {});
  1370. end_block->Mature();
  1371. return end_block;
  1372. }
  1373. // No 'break', no end_block.
  1374. return after_block;
  1375. }
  1376. AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) {
  1377. MS_LOG(DEBUG) << "Process ast IfExp";
  1378. MS_EXCEPTION_IF_NULL(block);
  1379. py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  1380. AnfNodePtr condition_node = ParseExprNode(block, test_node);
  1381. CNodePtr bool_node = block->ForceToBoolNode(condition_node);
  1382. FunctionBlockPtr true_block = nullptr;
  1383. FunctionBlockPtr false_block = nullptr;
  1384. {
  1385. TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
  1386. true_block = MakeFunctionBlock(*this);
  1387. }
  1388. {
  1389. TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
  1390. false_block = MakeFunctionBlock(*this);
  1391. }
  1392. MakeConditionBlocks(block, true_block, false_block);
  1393. // Process the if-true branch
  1394. py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
  1395. true_block->func_graph()->debug_info()->set_location(GetLocation(bodyNode));
  1396. AnfNodePtr true_node = ParseExprNode(true_block, bodyNode);
  1397. // Process the orelse branch
  1398. py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
  1399. false_block->func_graph()->debug_info()->set_location(GetLocation(orelseNode));
  1400. AnfNodePtr false_node = ParseExprNode(false_block, orelseNode);
  1401. true_block->func_graph()->set_output(true_node);
  1402. false_block->func_graph()->set_output(false_node);
  1403. // Use the Primitive replace the operation resolve node (switch),
  1404. // because the switch will eventually be converted to Primitive node
  1405. CNodePtr switch_app = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), bool_node,
  1406. NewValueNode(true_block->func_graph()),
  1407. NewValueNode(false_block->func_graph())});
  1408. std::vector<AnfNodePtr> call_graph_nodes{switch_app};
  1409. CNodePtr switch_app_call = block->func_graph()->NewCNodeInOrder(std::move(call_graph_nodes));
  1410. return switch_app_call;
  1411. }
  1412. FunctionBlockPtr Parser::ParseListCompIter(const FunctionBlockPtr &block, const py::object &node,
  1413. const py::object &generator_node) {
  1414. // Create a header block.
  1415. FunctionBlockPtr top_block = GenerateBlock(std::make_shared<TraceListComp>(block->func_graph()->debug_info()));
  1416. // Handle iter attribute.
  1417. py::object iter_node = python_adapter::GetPyObjAttr(generator_node, "iter");
  1418. AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node);
  1419. AnfNodePtr op_iter = top_block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
  1420. CNodePtr iter_apply = top_block->func_graph()->NewCNodeInOrder({op_iter, iter_anf_node});
  1421. // Create header graph.
  1422. FunctionBlockPtr list_header_block =
  1423. GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
  1424. list_header_block->AddPrevBlock(top_block);
  1425. // Create hasNext apply.
  1426. AnfNodePtr op_hasnext = top_block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
  1427. ParameterPtr iter_param = list_header_block->func_graph()->add_parameter();
  1428. constexpr auto iter_param_name = "iter";
  1429. iter_param->set_name(iter_param_name);
  1430. iter_param->debug_info()->set_name(iter_param_name);
  1431. CNodePtr cond_apply = list_header_block->func_graph()->NewCNodeInOrder({op_hasnext, iter_param});
  1432. // Call the header graph with iter.
  1433. ParameterPtr list_param = list_header_block->func_graph()->add_parameter();
  1434. constexpr auto list_param_name = "list";
  1435. list_param->set_name(list_param_name);
  1436. list_param->debug_info()->set_name(list_param_name);
  1437. auto empty_list = std::vector<ValuePtr>();
  1438. AnfNodePtr empty_list_node = NewValueNode(std::make_shared<ValueList>(empty_list));
  1439. top_block->Jump(list_header_block, {iter_apply, empty_list_node});
  1440. // Create body graph.
  1441. FunctionBlockPtr list_body_block = GenerateBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  1442. list_body_block->AddPrevBlock(list_header_block);
  1443. AnfNodePtr op_next = top_block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
  1444. CNodePtr next_apply = list_body_block->func_graph()->NewCNodeInOrder({op_next, iter_param});
  1445. AnfNodePtr op_getitem = top_block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  1446. CNodePtr item_apply =
  1447. list_body_block->func_graph()->NewCNodeInOrder({op_getitem, next_apply, NewValueNode(static_cast<int64_t>(0))});
  1448. CNodePtr new_iter =
  1449. list_body_block->func_graph()->NewCNodeInOrder({op_getitem, next_apply, NewValueNode(static_cast<int64_t>(1))});
  1450. // Save the `target` in a variable.
  1451. py::object gen_target_node = python_adapter::GetPyObjAttr(generator_node, "target");
  1452. WriteAssignVars(list_body_block, gen_target_node, item_apply);
  1453. auto ifs_new_list = ParseListCompIfs(list_body_block, list_param, node, generator_node);
  1454. list_body_block->Jump(list_header_block, {new_iter, ifs_new_list});
  1455. // Create after graph.
  1456. FunctionBlockPtr list_after_block = GenerateBlock(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
  1457. list_after_block->AddPrevBlock(list_header_block);
  1458. // Return the list in after graph.
  1459. list_after_block->func_graph()->set_output(list_param);
  1460. // Run the branches.
  1461. list_header_block->ConditionalJump(cond_apply, list_body_block, list_after_block);
  1462. top_block->Mature();
  1463. list_header_block->Mature();
  1464. list_body_block->Mature();
  1465. list_after_block->Mature();
  1466. return top_block;
  1467. }
  1468. AnfNodePtr Parser::ParseListCompIfs(const FunctionBlockPtr &list_body_block, const ParameterPtr &list_param,
  1469. const py::object &node, const py::object &generator_node) {
  1470. // Handle ifs attribute.
  1471. py::list ifs_node = python_adapter::GetPyObjAttr(generator_node, "ifs");
  1472. AnfNodePtr ifs_bool_node;
  1473. if (ifs_node.empty()) {
  1474. ifs_bool_node = NewValueNode(true);
  1475. } else {
  1476. ifs_bool_node = ProcessBoolOpValueList(list_body_block, ifs_node, AST_SUB_TYPE_AND);
  1477. }
  1478. // Create if-true graph.
  1479. FunctionBlockPtr if_true_block =
  1480. GenerateBlock(std::make_shared<TraceIfStmtTrueBranch>(list_body_block->func_graph()->debug_info()));
  1481. if_true_block->AddPrevBlock(list_body_block);
  1482. // Handle elt attribute in body block.
  1483. py::object elt_obj = python_adapter::GetPyObjAttr(node, "elt");
  1484. AnfNodePtr elt_node = ParseExprNode(list_body_block, elt_obj);
  1485. // Append the element.
  1486. auto list_append_op = prim::kPrimListAppend;
  1487. auto new_list = list_body_block->func_graph()->NewCNodeInOrder({NewValueNode(list_append_op), list_param, elt_node});
  1488. // Return new list in true branch graph.
  1489. if_true_block->func_graph()->set_output(new_list);
  1490. // Create if-false graph.
  1491. FunctionBlockPtr if_false_block =
  1492. GenerateBlock(std::make_shared<TraceIfStmtFalseBranch>(list_body_block->func_graph()->debug_info()));
  1493. if_false_block->AddPrevBlock(list_body_block);
  1494. // Return original list in false branch graph.
  1495. if_false_block->func_graph()->set_output(list_param);
  1496. // We don't want to create a header graph, where to get and wrap the result of Switch().
  1497. // So just call ConditionalJump() to set Switch() as output, and reset it later, as tricky.
  1498. list_body_block->ConditionalJump(ifs_bool_node, if_true_block, if_false_block);
  1499. // Output is Switch() result, i.e. updated list.
  1500. auto switch_apply_node = list_body_block->func_graph()->output();
  1501. auto ifs_new_list = switch_apply_node;
  1502. // Since we call ConditionalJump() above, to reset the Return as null before call Jump().
  1503. list_body_block->func_graph()->set_return(nullptr);
  1504. if_true_block->Mature();
  1505. if_false_block->Mature();
  1506. return ifs_new_list;
  1507. }
  1508. // A ListComp contains: `elt` and `generators`.
  1509. // `generators` contains: `target`, `iter` and `ifs`.
  1510. // For example:
  1511. // [x * x for x in range(0, 10) if x % 2 == 0]
  1512. // It is compiled to be following statement:
  1513. // list = []
  1514. // for x in range(0, 10):
  1515. // if x % 2 == 0:
  1516. // list.append(x * x)
  1517. // return list
  1518. AnfNodePtr Parser::ParseListComp(const FunctionBlockPtr &block, const py::object &node) {
  1519. MS_LOG(DEBUG) << "Process ast ListComp";
  1520. MS_EXCEPTION_IF_NULL(block);
  1521. // Handle generators attribute.
  1522. py::list generators_node = python_adapter::GetPyObjAttr(node, "generators");
  1523. if (generators_node.size() != 1) {
  1524. MS_EXCEPTION(TypeError) << "The 'generators' supports 1 'comprehension' in ListComp/GeneratorExp, but got "
  1525. << generators_node.size() << " comprehensions.";
  1526. }
  1527. py::object generator_node = generators_node[0];
  1528. auto generator_node_type = ast_->GetNodeType(generator_node);
  1529. auto generator_node_name = generator_node_type->node_name();
  1530. constexpr auto comprehension_name = "comprehension";
  1531. if (generator_node_name != comprehension_name) {
  1532. MS_LOG(EXCEPTION) << "Generator node name should be " << comprehension_name << ", but got " << generator_node_name;
  1533. }
  1534. // Parse ListComp's `iter` and add `elt` in it.
  1535. auto top_block = ParseListCompIter(block, node, generator_node);
  1536. // Call the top graph and return the list.
  1537. auto call_function_node = NewValueNode(top_block->func_graph());
  1538. std::vector<AnfNodePtr> func_call_nodes;
  1539. func_call_nodes.push_back(call_function_node);
  1540. AnfNodePtr output = block->func_graph()->NewCNodeInOrder(std::move(func_call_nodes));
  1541. return output;
  1542. }
  1543. void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &target_object,
  1544. const AnfNodePtr &assigned_node) {
  1545. MS_EXCEPTION_IF_NULL(block);
  1546. MS_EXCEPTION_IF_NULL(assigned_node);
  1547. py::str name = python_adapter::GetPyObjAttr(target_object, "id");
  1548. std::string name_id = name;
  1549. assigned_node->debug_info()->set_name(name_id);
  1550. // Set the debug name of the constant graph
  1551. if (IsValueNode<FuncGraph>(assigned_node)) {
  1552. // The value should be graph
  1553. auto fg = GetValueNode<FuncGraphPtr>(assigned_node);
  1554. if (fg->debug_info()->name().empty()) {
  1555. fg->debug_info()->set_name(name_id);
  1556. }
  1557. }
  1558. MS_LOG(DEBUG) << "Assign name: `" << name_id << "` to node: " << assigned_node->DebugString();
  1559. block->WriteVariable(name_id, assigned_node);
  1560. }
  1561. void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &target_object,
  1562. const AnfNodePtr &assigned_node) {
  1563. MS_EXCEPTION_IF_NULL(block);
  1564. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  1565. py::list items = python_adapter::GetPyObjAttr(target_object, "elts");
  1566. for (size_t i = 0; i < items.size(); i++) {
  1567. // Use the Primitive replace the operation resolve node (getitem),
  1568. // because the getitem will eventually be converted to Primitive node
  1569. CNodePtr item_apply =
  1570. block->func_graph()->NewCNodeInOrder({op_getitem, assigned_node, NewValueNode(static_cast<int64_t>(i))});
  1571. py::object elt = items[i];
  1572. WriteAssignVars(block, elt, item_apply);
  1573. }
  1574. }
  1575. void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &target_object,
  1576. const AnfNodePtr &assigned_node) {
  1577. // Now only support the self.xx = xxxxx, can't support x.y = xxxx
  1578. AnfNodePtr target_node = ParseExprNode(block, target_object);
  1579. MS_EXCEPTION_IF_NULL(target_node);
  1580. auto attr_name = target_object.attr("attr").cast<std::string>();
  1581. std::string var_name = "self." + attr_name;
  1582. // Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
  1583. if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
  1584. MS_EXCEPTION(TypeError)
  1585. << "'" << var_name << "' should be initialized as a 'Parameter' in the '__init__' function before assigning.\n\n"
  1586. << trace::GetDebugInfo(target_node->debug_info());
  1587. }
  1588. auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
  1589. auto obj_type = obj.attr("__class__").attr("__name__");
  1590. if (!py::hasattr(obj, "__parameter__")) {
  1591. MS_EXCEPTION(TypeError) << "'" << var_name
  1592. << "' should be initialized as a 'Parameter' type in the '__init__' function, but got '"
  1593. << py::str(obj).cast<std::string>() << "' with type '"
  1594. << py::str(obj_type).cast<std::string>() << ".\n\n"
  1595. << trace::GetDebugInfo(target_node->debug_info());
  1596. }
  1597. MS_EXCEPTION_IF_NULL(block);
  1598. MS_LOG(DEBUG) << "SetState write " << var_name << " : " << target_node->ToString();
  1599. block->SetStateAssign(target_node, assigned_node);
  1600. }
  1601. void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &target_object,
  1602. const AnfNodePtr &assigned_node) {
  1603. MS_EXCEPTION_IF_NULL(block);
  1604. AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM);
  1605. py::object value_obj = python_adapter::GetPyObjAttr(target_object, "value");
  1606. py::object slice_obj = python_adapter::GetPyObjAttr(target_object, "slice");
  1607. AnfNodePtr value_node = ParseExprNode(block, value_obj);
  1608. AnfNodePtr slice_node = ParseExprNode(block, slice_obj);
  1609. CNodePtr setitem_app = block->func_graph()->NewCNodeInOrder({op_setitem, value_node, slice_node, assigned_node});
  1610. // Getitem apply should return the sequence data structure itself
  1611. std::string var_name;
  1612. if (ast_->IsClassMember(value_obj)) {
  1613. auto attr_name = value_obj.attr("attr").cast<std::string>();
  1614. var_name = "self." + attr_name;
  1615. if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
  1616. MS_EXCEPTION(TypeError)
  1617. << "'" << var_name
  1618. << "' should be initialized as a 'Parameter' in the '__init__' function before assigning.\n\n"
  1619. << trace::GetDebugInfo(value_node->debug_info());
  1620. }
  1621. auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
  1622. auto obj_type = obj.attr("__class__").attr("__name__");
  1623. if (!py::hasattr(obj, "__parameter__")) {
  1624. MS_EXCEPTION(TypeError) << "'" << var_name
  1625. << "' should be initialized as a 'Parameter' in the '__init__' function, but got '"
  1626. << py::str(obj).cast<std::string>() << "' with type '"
  1627. << py::str(obj_type).cast<std::string>() << "'.\n\n"
  1628. << trace::GetDebugInfo(value_node->debug_info());
  1629. }
  1630. block->WriteVariable(var_name, setitem_app);
  1631. return;
  1632. }
  1633. if (AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, value_obj))) ==
  1634. AST_SUB_TYPE_SUBSCRIPT) {
  1635. HandleAssignSubscript(block, value_obj, setitem_app);
  1636. return;
  1637. }
  1638. if (!py::hasattr(value_obj, "id")) {
  1639. MS_EXCEPTION(TypeError) << "Attribute id not found in " << py::str(value_obj).cast<std::string>() << "\n\n"
  1640. << trace::GetDebugInfo(value_node->debug_info());
  1641. }
  1642. var_name = value_obj.attr("id").cast<std::string>();
  1643. block->WriteVariable(var_name, setitem_app);
  1644. }
  1645. void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &target_object,
  1646. const AnfNodePtr &value_node) {
  1647. MS_EXCEPTION_IF_NULL(value_node);
  1648. MS_LOG(DEBUG) << "Process WriteAssignVars";
  1649. auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_object)));
  1650. if (ast_type == AST_SUB_TYPE_NAME) {
  1651. HandleAssignName(block, target_object, value_node);
  1652. } else if (ast_type == AST_SUB_TYPE_TUPLE) {
  1653. HandleAssignTuple(block, target_object, value_node);
  1654. } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
  1655. HandleAssignSubscript(block, target_object, value_node);
  1656. } else if (ast_->IsClassMember(target_object)) {
  1657. HandleAssignClassMember(block, target_object, value_node);
  1658. } else if (ast_type == AST_SUB_TYPE_ATTRIBUTE) {
  1659. TraceGuard(GetLocation(target_object));
  1660. MS_EXCEPTION(TypeError) << "Only support assign to attribute of self, but got attribute of "
  1661. << py::str(target_object.attr("value").attr("id")) << ".\n"
  1662. << "More details please refer to syntax support at https://www.mindspore.cn";
  1663. } else {
  1664. TraceGuard(GetLocation(target_object));
  1665. MS_EXCEPTION(TypeError) << "Only supported augassign to attribute of self, variable and index value, but got "
  1666. << target_object.get_type()
  1667. << ".\nMore details please refer to syntax support at https://www.mindspore.cn";
  1668. }
  1669. }
  1670. AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
  1671. const py::object &value_object) {
  1672. // The fallback feature is enabled in default.
  1673. // Not support change the flag during the process is alive.
  1674. static const auto use_fallback = (support_fallback() != "0");
  1675. if (!use_fallback || !value_node->interpret()) {
  1676. return value_node;
  1677. }
  1678. const auto script_text = py::cast<std::string>(ast()->GetAstNodeText(value_object));
  1679. // Prepare global parameters.
  1680. py::dict global_dict = block->global_py_params();
  1681. ValuePtr globals_converted_value = nullptr;
  1682. if (!ConvertData(global_dict, &globals_converted_value)) {
  1683. MS_LOG(EXCEPTION) << "Convert data failed";
  1684. }
  1685. auto global_dict_node = NewValueNode(globals_converted_value);
  1686. // Prepare local parameters.
  1687. auto [keys, values] = block->local_py_params();
  1688. auto local_dict_node = ParseDictByKeysAndValues(block, keys, values);
  1689. // Update the valued node if it need interpreting.
  1690. constexpr int recursive_level = 2;
  1691. MS_LOG(INFO) << "[" << block->func_graph()->ToString() << "] script_text: `" << script_text
  1692. << "`,\nvalue_node: " << value_node->DebugString(recursive_level)
  1693. << ",\nglobal_dict_node: " << global_dict_node->ToString()
  1694. << ",\nlocal_dict_node: " << local_dict_node->DebugString(recursive_level);
  1695. AnfNodePtr interpreted_node = block->MakeInterpret(script_text, global_dict_node, local_dict_node, value_node);
  1696. // Print a hint for user.
  1697. auto line_info = trace::GetDebugInfo(value_node->debug_info());
  1698. MS_LOG(INFO) << "Found unsupported syntax in Graph mode, those codes would be fallen back to Python interpreter:"
  1699. << "\n\n"
  1700. << line_info;
  1701. InterpretNodeRecorder::GetInstance().PushLineInfo(line_info);
  1702. return interpreted_node;
  1703. }
  1704. // Process a assign statement, such as a = b, a, b = tuple(xx, xx)
  1705. FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) {
  1706. MS_LOG(DEBUG) << "Process ast assign";
  1707. py::object value_object = python_adapter::GetPyObjAttr(node, "value");
  1708. AnfNodePtr value_node = ParseExprNode(block, value_object);
  1709. value_node = HandleInterpret(block, value_node, value_object);
  1710. py::object targets_object = python_adapter::GetPyObjAttr(node, "targets");
  1711. py::int_ pcount = python_adapter::CallPyObjMethod(targets_object, "__len__");
  1712. size_t count = LongToSize(pcount);
  1713. MS_LOG(DEBUG) << "The nodes count is " << count;
  1714. for (size_t i = 0; i < count; i++) {
  1715. auto target_node = py::cast<py::list>(targets_object)[i];
  1716. WriteAssignVars(block, target_node, value_node);
  1717. }
  1718. return block;
  1719. }
  1720. FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::object &node) {
  1721. if (loops_.empty()) {
  1722. // Report error if loop context not set for the 'break' statement.
  1723. MS_LOG(EXCEPTION) << "Unexpected 'break'.";
  1724. }
  1725. // Get current loop.
  1726. Loop &loop = loops_.top();
  1727. if (loop.end == nullptr) {
  1728. // Create end_block if it is not existed.
  1729. TraceGuard trace_guard(std::make_shared<TraceLoopEnd>(block->func_graph()->debug_info()));
  1730. loop.end = MakeFunctionBlock(*this);
  1731. }
  1732. // Jump to the end_block.
  1733. block->Jump(loop.end, {});
  1734. return block;
  1735. }
  1736. FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::object &node) {
  1737. if (loops_.empty()) {
  1738. // Report error if loop context not set for the 'continue' statement.
  1739. MS_LOG(EXCEPTION) << "Unexpected 'continue.";
  1740. }
  1741. // Jump to the header of the loop with iterator called.
  1742. Loop &loop = loops_.top();
  1743. std::vector<AnfNodePtr> args;
  1744. if (loop.iterator != nullptr) {
  1745. args.emplace_back(loop.iterator);
  1746. }
  1747. block->Jump(loop.header, args);
  1748. return block;
  1749. }
  1750. FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::object &node) {
  1751. // We just bypass 'pass' statement.
  1752. return block;
  1753. }
  1754. AnfNodePtr FindPhis(const mindspore::HashMap<ParameterPtr, AnfNodePtr> &removable_phis, const AnfNodePtr &node) {
  1755. MS_EXCEPTION_IF_NULL(node);
  1756. const auto &inp = node->cast<ParameterPtr>();
  1757. const auto &iter = removable_phis.find(inp);
  1758. if (iter == removable_phis.end()) {
  1759. return node;
  1760. }
  1761. return FindPhis(removable_phis, iter->second);
  1762. }
  1763. void Parser::RemoveUnnecessaryPhis() {
  1764. // Merge all removable phis to one map;
  1765. mindspore::HashMap<ParameterPtr, AnfNodePtr> removable_phis;
  1766. std::vector<ParameterPtr> phis;
  1767. for (FunctionBlockPtr &block : func_block_list_) {
  1768. MS_EXCEPTION_IF_NULL(block);
  1769. removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end());
  1770. std::transform(block->removable_phis().begin(), block->removable_phis().end(), std::back_inserter(phis),
  1771. [](const auto &pair) { return pair.first; });
  1772. }
  1773. if (removable_phis.empty()) {
  1774. return;
  1775. }
  1776. auto mng = Manage(func_graph_, false);
  1777. // Replace the nodes
  1778. // Remove from inside to outside
  1779. for (int64_t idx = SizeToLong(phis.size() - 1); idx >= 0; idx--) {
  1780. auto phi = phis[LongToSize(idx)];
  1781. auto new_node = FindPhis(removable_phis, phi);
  1782. mng->Replace(phi, new_node);
  1783. }
  1784. // Remove the parameter
  1785. for (FunctionBlockPtr &block : func_block_list_) {
  1786. MS_EXCEPTION_IF_NULL(block);
  1787. auto &local_removable_phis = block->removable_phis();
  1788. if (local_removable_phis.empty()) {
  1789. continue;
  1790. }
  1791. auto func_graph = block->func_graph();
  1792. auto &parameters = func_graph->parameters();
  1793. std::vector<AnfNodePtr> new_parameters(parameters.size());
  1794. auto it = std::copy_if(
  1795. parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](const AnfNodePtr &param) {
  1796. MS_EXCEPTION_IF_NULL(param);
  1797. return local_removable_phis.find(param->cast<ParameterPtr>()) == local_removable_phis.end();
  1798. });
  1799. // Shrink container to new size
  1800. new_parameters.resize(static_cast<size_t>(std::distance(new_parameters.begin(), it)));
  1801. func_graph->set_parameters(new_parameters);
  1802. }
  1803. }
// ParseFunctionAst class code
// Resolve obj_ into the Python function to parse, then run the Python-side parser on it.
// Depending on obj_'s resolve type:
//   - a function is parsed directly;
//   - a method is parsed with its bound instance becoming obj_;
//   - a class instance has its parse method looked up via `python_mod_get_parse_method`.
// On success this fills function_, parser_, ast_tokens_, ast_tree_ and the function_module_,
// function_name_, function_filename_ and function_line_offset_ members, and returns true.
// Returns false (after logging) when obj_ cannot be resolved to a parsable function.
// Raises NameError if the Python parser's `parse` result does not have exactly 2 items.
// python_mod_get_parse_method: name of the parse-module function called with (obj_, parse_method) to fetch
//   the method to parse when obj_ is a class instance.
bool ParseFunctionAst::InitParseAstInfo(const std::string &python_mod_get_parse_method) {
  // Init the type
  target_type_ = PARSE_TARGET_UNKNOW;
  // Call python parse, get the parser fn
  module_ = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  // obj_'s PYTHON_EXTERN_PARSE_METHOD attribute; it is forwarded to the Python helpers called below.
  py::object parse_method = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_PARSE_METHOD);
  // Get the obj type
  auto type = data_converter::GetObjType(obj_);
  if (type == RESOLVE_TYPE_FUNCTION) {
    target_type_ = PARSE_TARGET_FUNCTION;
    function_ = obj_;
  } else if (type == RESOLVE_TYPE_METHOD) {
    // Process the method, need to get the method's self object.
    // NOTE(review): if the lookup below fails, target_type_ is left as PARSE_TARGET_METHOD; on success it is
    // overwritten with PARSE_TARGET_OBJECT_INSTANCE a few lines later.
    target_type_ = PARSE_TARGET_METHOD;
    py::object method_object = python_adapter::GetPyObjAttr(obj_, PYTHON_GET_METHOD_SELF_CLASS);
    if (py::isinstance<py::none>(method_object)) {
      MS_LOG(ERROR) << "Get method's self object instance failed.";
      return false;
    }
    // Parse the bound method itself, but treat its owning instance as the object being parsed.
    target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
    function_ = obj_;
    obj_ = method_object;
  } else if (type == RESOLVE_TYPE_CLASS_INSTANCE) {
    // obj is class instance, get the method to parse.
    function_ = python_adapter::CallPyModFn(module_, python_mod_get_parse_method, obj_, parse_method);
    if (py::isinstance<py::none>(function_)) {
      MS_LOG(ERROR) << "Get obj method function failed.";
      return false;
    }
    target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
    // Check the fn is method
    auto obj_type = data_converter::GetObjType(function_);
    if (obj_type != RESOLVE_TYPE_METHOD) {
      MS_LOG(WARNING) << "Parse method function is invalid.";
      return false;
    }
  } else {
    MS_LOG(WARNING) << "Parse obj is invalid, only can parse function and obj, type = " << type;
    return false;
  }
  // Call python parse get ast tree
  parser_ = python_adapter::CallPyModFn(module_, PYTHON_MOD_PARSE_OBJECT_FUNCTION, function_, parse_method);
  // parser_.parse() is expected to return a 2-tuple: (ast tokens, ast tree).
  py::tuple ast_info = python_adapter::CallPyObjMethod(parser_, "parse");
  const size_t ast_info_size = 2;
  if (ast_info.size() != ast_info_size) {
    MS_EXCEPTION(NameError) << "ast info size is not equal to 2.";
  }
  ast_tokens_ = ast_info[0];
  ast_tree_ = ast_info[1];
  // Get fn name and module
  function_module_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "function_module"));
  function_name_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "function_name"));
  function_filename_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "filename"));
  function_line_offset_ = py::cast<int64_t>(python_adapter::GetPyObjAttr(parser_, "line_offset"));
  return true;
}
  1861. // Get ast tree node : is the tree bode list[0]
  1862. py::object ParseFunctionAst::GetAstNode() {
  1863. py::list tree_body = python_adapter::GetPyObjAttr(ast_tree_, "body");
  1864. py::object ast_node = tree_body[0];
  1865. return ast_node;
  1866. }
  1867. // Get ast tokens node text.
  1868. py::str ParseFunctionAst::GetAstNodeText(const py::object &node_obj) {
  1869. return python_adapter::CallPyObjMethod(ast_tokens_, "get_text", node_obj);
  1870. }
  1871. py::list ParseFunctionAst::GetArgs(const py::object &func_node) {
  1872. py::list ret = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_ARGS, func_node);
  1873. return ret;
  1874. }
  1875. py::list ParseFunctionAst::GetArgsDefaultValues(const py::object &func_node) {
  1876. py::list ret = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES, func_node);
  1877. return ret;
  1878. }
  1879. AstNodeTypePtr ParseFunctionAst::GetNodeType(const py::object &node) {
  1880. py::list list_value = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_NODE_TYPE, node);
  1881. const size_t list_value_size = 2;
  1882. if (list_value.size() < list_value_size) {
  1883. MS_LOG(EXCEPTION) << "The node of python method must has 2 values.";
  1884. }
  1885. auto node_name = py::cast<std::string>(list_value[0]);
  1886. auto type = AstMainType(py::cast<int32_t>(list_value[1]));
  1887. return std::make_shared<AstNodeType>(node, node_name, type);
  1888. }
  1889. AstSubType ParseFunctionAst::GetOpType(const py::object &node) {
  1890. auto op_type = AstSubType(python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_AST_TYPE, node).cast<int32_t>());
  1891. return op_type;
  1892. }
  1893. bool ParseFunctionAst::IsClassMember(const py::object &node) {
  1894. py::object ret = CallParseModFunction(PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER, node);
  1895. if (!py::isinstance<py::bool_>(ret)) {
  1896. MS_LOG(ERROR) << "The result of mod function parse, should be bool type.";
  1897. return false;
  1898. }
  1899. return ret.cast<bool>();
  1900. }
  1901. void SetMixedPrecisionFlag(const py::object &obj, const FuncGraphPtr &func_graph) {
  1902. MS_EXCEPTION_IF_NULL(func_graph);
  1903. if (!py::isinstance<Cell>(obj)) {
  1904. return;
  1905. }
  1906. auto cell = py::cast<CellPtr>(obj);
  1907. MS_EXCEPTION_IF_NULL(cell);
  1908. auto mixed_type = cell->GetMixedPrecisionType();
  1909. if (mixed_type != MixedPrecisionType::kNotSet) {
  1910. func_graph->set_flag(GRAPH_FLAG_MIX_PRECISION_FP16, mixed_type == MixedPrecisionType::kFP16);
  1911. func_graph->set_flag(GRAPH_FLAG_MIX_PRECISION_FP32, mixed_type == MixedPrecisionType::kFP32);
  1912. }
  1913. }
  1914. bool UpdateFuncGraphFlags(const py::object &obj, const FuncGraphPtr &func_graph) {
  1915. if (func_graph == nullptr) {
  1916. MS_LOG(ERROR) << "FuncGraph is null";
  1917. return false;
  1918. }
  1919. SetMixedPrecisionFlag(obj, func_graph);
  1920. if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) {
  1921. MS_LOG(DEBUG) << "No flags";
  1922. return true;
  1923. }
  1924. py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG);
  1925. for (auto &item : flags) {
  1926. if (!py::isinstance<py::str>(item.first)) {
  1927. MS_LOG(ERROR) << "Type error in flags dict convert";
  1928. return false;
  1929. }
  1930. auto name = py::cast<std::string>(item.first);
  1931. if (py::isinstance<py::bool_>(item.second)) {
  1932. auto value = py::cast<bool>(item.second);
  1933. MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
  1934. func_graph->set_flag(name, value);
  1935. } else if (py::isinstance<py::str>(item.second)) {
  1936. auto value = py::cast<std::string>(item.second);
  1937. MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
  1938. func_graph->set_attr(name, MakeValue(value));
  1939. } else {
  1940. MS_LOG(ERROR) << "Type error in flags/attrs dict convert";
  1941. return false;
  1942. }
  1943. }
  1944. return true;
  1945. }
  1946. // Generate and copy a ValueNode, or a CNode with its child nodes
  1947. static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr &func_graph, const AnfNodePtr &param_node) {
  1948. MS_EXCEPTION_IF_NULL(param_node);
  1949. if (param_node->isa<ValueNode>()) {
  1950. return std::make_shared<ValueNode>(param_node->cast<ValueNodePtr>()->value());
  1951. }
  1952. // Parameter default value is CNode.
  1953. std::size_t index = 0;
  1954. std::vector<AnfNodePtr> old_cnodes;
  1955. old_cnodes.emplace_back(param_node);
  1956. MS_EXCEPTION_IF_NULL(func_graph);
  1957. auto res = func_graph->NewCNodeInOrder({});
  1958. std::vector<CNodePtr> new_cnodes;
  1959. new_cnodes.emplace_back(res);
  1960. while (index < old_cnodes.size()) {
  1961. auto current = old_cnodes[index];
  1962. auto current_new_cnode = new_cnodes[index];
  1963. index++;
  1964. if (current->isa<CNode>()) {
  1965. auto &inputs = current->cast<CNodePtr>()->inputs();
  1966. for (auto it = inputs.begin(); it != inputs.end(); it++) {
  1967. AnfNodePtr input = *it;
  1968. if (input != nullptr && input->isa<CNode>()) {
  1969. old_cnodes.emplace_back(input);
  1970. auto new_cnode = func_graph->NewCNodeInOrder({});
  1971. new_cnodes.emplace_back(new_cnode);
  1972. current_new_cnode->add_input(new_cnode);
  1973. } else if (input->isa<ValueNode>()) {
  1974. current_new_cnode->add_input(std::make_shared<ValueNode>(input->cast<ValueNodePtr>()->value()));
  1975. } else {
  1976. MS_LOG(EXCEPTION) << "Wrong type item in default parameters: " << input->ToString();
  1977. }
  1978. }
  1979. }
  1980. }
  1981. return res;
  1982. }
  1983. FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
  1984. auto current_graph = dyn_cast<FuncGraph>(cell_ptr);
  1985. if (current_graph == nullptr) {
  1986. MS_LOG(EXCEPTION) << "Current graph cast failed from " << cell_ptr->ToString();
  1987. }
  1988. auto func_graph = std::make_shared<FuncGraph>();
  1989. func_graph->debug_info()->set_name(current_graph->debug_info()->name() + "_wrapper");
  1990. func_graph->debug_info()->set_location(current_graph->debug_info()->location());
  1991. // Copy all parameters information
  1992. for (auto &para : current_graph->parameters()) {
  1993. auto param = func_graph->add_parameter();
  1994. auto orig_param = para->cast<ParameterPtr>();
  1995. auto name = orig_param->name();
  1996. param->set_name(name);
  1997. param->debug_info()->set_name(name);
  1998. param->debug_info()->set_location(param->debug_info()->location());
  1999. param->set_is_top_graph_param(true);
  2000. }
  2001. func_graph->set_has_vararg(current_graph->has_vararg());
  2002. func_graph->set_has_kwarg(current_graph->has_kwarg());
  2003. func_graph->set_kwonlyargs_count(current_graph->kwonlyargs_count());
  2004. // Copy all default values
  2005. for (auto &d : current_graph->parameter_default_value()) {
  2006. func_graph->set_param_default_value(d.first, CopyNodesFromParamDefaultValue(func_graph, d.second));
  2007. }
  2008. // cell_obj
  2009. MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell));
  2010. parse::UpdateFuncGraphFlags(cell, func_graph);
  2011. // Top graph's construct flag
  2012. if (py::hasattr(cell, "construct")) {
  2013. parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph);
  2014. }
  2015. auto unpacking = func_graph->has_vararg() || func_graph->has_kwarg();
  2016. if (!unpacking) {
  2017. std::vector<AnfNodePtr> inputs;
  2018. inputs.emplace_back(NewValueNode(cell_ptr));
  2019. auto &params = func_graph->parameters();
  2020. (void)std::transform(params.begin(), params.end(), std::back_inserter(inputs),
  2021. [](AnfNodePtr node) -> AnfNodePtr { return node; });
  2022. auto call_node = func_graph->NewCNodeInOrder(std::move(inputs));
  2023. TraceGuard guard(current_graph->get_return()->debug_info()->location());
  2024. func_graph->set_output(call_node);
  2025. } else {
  2026. // ret = cell_obj(*arg, *kwargs)
  2027. auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), func_graph->parameters());
  2028. TraceGuard guard(current_graph->get_return()->debug_info()->location());
  2029. // Set output as ret
  2030. func_graph->set_output(call_fn);
  2031. }
  2032. return func_graph;
  2033. }
  2034. } // namespace parse
  2035. } // namespace mindspore