You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

parse.cc 79 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/jit/parse/parse.h"
  19. #include <utility>
  20. #include <string>
  21. #include <memory>
  22. #include <unordered_map>
  23. #include <algorithm>
  24. #include "pipeline/jit/parse/resolve.h"
  25. #include "frontend/operator/ops.h"
  26. #include "pipeline/jit/parse/data_converter.h"
  27. #include "frontend/operator/composite/composite.h"
  28. #include "utils/ms_context.h"
  29. #include "debug/trace.h"
  30. namespace mindspore {
  31. namespace parse {
  32. FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mod_get_parse_method) {
  33. (void)python_adapter::set_python_scoped();
  34. if (!obj || py::isinstance<py::none>(obj)) {
  35. MS_LOG(ERROR) << "Parse the python code failed, obj is nullptr or none";
  36. return nullptr;
  37. }
  38. auto ast = std::make_shared<ParseAst>(obj);
  39. bool success = ast->InitParseAstInfo(python_mod_get_parse_method);
  40. if (!success) {
  41. MS_LOG(ERROR) << "Parse code to ast tree failed.";
  42. return nullptr;
  43. }
  44. auto parser = std::make_shared<Parser>(ast);
  45. FuncGraphPtr func_graph = parser->ParseFuncGraph();
  46. if (func_graph == nullptr) {
  47. MS_LOG(ERROR) << "Parse python code failed, errcode = " << parser->errcode();
  48. return nullptr;
  49. }
  50. return func_graph;
  51. }
  52. TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) {
  53. if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
  54. return kFloat32;
  55. } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
  56. return kFloat16;
  57. } else {
  58. return nullptr;
  59. }
  60. }
  61. // if any mixed precision flag add a cast node after the parameter node.
  62. AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
  63. TypePtr dst_type;
  64. if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
  65. dst_type = kFloat32;
  66. } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
  67. dst_type = kFloat16;
  68. } else {
  69. return param;
  70. }
  71. auto cast_helper = prim::kPrimMixedPrecisionCast;
  72. auto cast = func_graph->NewCNode({NewValueNode(cast_helper), NewValueNode(dst_type), param});
  73. return cast;
  74. }
  75. FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();
  76. Parser::Parser(const std::shared_ptr<ParseAst> &ast) : ast_(ast) {
  77. errcode_ = PARSE_SUCCESS;
  78. BuildMethodMap();
  79. }
  80. void Parser::BuildMethodMap() {
  81. stmt_method_map_["Return"] = &Parser::ParseReturn;
  82. stmt_method_map_["Expr"] = &Parser::ParseExpr;
  83. stmt_method_map_["If"] = &Parser::ParseIf;
  84. stmt_method_map_["Assign"] = &Parser::ParseAssign;
  85. stmt_method_map_["While"] = &Parser::ParseWhile;
  86. stmt_method_map_["For"] = &Parser::ParseFor;
  87. stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef;
  88. stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign;
  89. stmt_method_map_["Global"] = &Parser::ParseGlobal;
  90. stmt_method_map_["Break"] = &Parser::ParseBreak;
  91. stmt_method_map_["Continue"] = &Parser::ParseContinue;
  92. stmt_method_map_["Pass"] = &Parser::ParsePass;
  93. expr_method_map_["NoneType"] = &Parser::ParseNone;
  94. expr_method_map_["BinOp"] = &Parser::ParseBinOp;
  95. expr_method_map_["Name"] = &Parser::ParseName;
  96. expr_method_map_["Num"] = &Parser::ParseNum;
  97. expr_method_map_["Str"] = &Parser::ParseStr;
  98. expr_method_map_["NameConstant"] = &Parser::ParseNameConstant;
  99. expr_method_map_["Call"] = &Parser::ParseCall;
  100. expr_method_map_["IfExp"] = &Parser::ParseIfExp;
  101. expr_method_map_["Attribute"] = &Parser::ParseAttribute;
  102. expr_method_map_["Compare"] = &Parser::ParseCompare;
  103. expr_method_map_["BoolOp"] = &Parser::ParseBoolOp;
  104. expr_method_map_["Lambda"] = &Parser::ParseLambda;
  105. expr_method_map_["Tuple"] = &Parser::ParseTuple;
  106. expr_method_map_["List"] = &Parser::ParseList;
  107. expr_method_map_["Subscript"] = &Parser::ParseSubscript;
  108. expr_method_map_["Slice"] = &Parser::ParseSlice;
  109. expr_method_map_["ExtSlice"] = &Parser::ParseExtSlice;
  110. expr_method_map_["Index"] = &Parser::ParseIndex;
  111. expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp;
  112. expr_method_map_["Dict"] = &Parser::ParseDict;
  113. expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis;
  114. }
  115. void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); }
  116. void Parser::InitParserEnvironment(const py::object &obj) {
  117. Parser::top_func_graph_ = FuncGraphWeakPtr();
  118. ScopeManager::GetInstance().ClearScope();
  119. (void)python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GENERATE_SCOPE, obj);
  120. }
  121. void Parser::CleanParserResource() {
  122. Parser::top_func_graph_ = FuncGraphWeakPtr();
  123. ScopeManager::GetInstance().ClearScope();
  124. }
  125. AnfNodePtr AppendParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) {
  126. MS_EXCEPTION_IF_NULL(func_graph);
  127. auto value = py::cast<tensor::MetaTensorPtr>(obj);
  128. // parameter object should not be none
  129. if (value == nullptr || !value->is_parameter()) {
  130. MS_LOG(EXCEPTION) << "Parameter error: because obj is not Parameter object.";
  131. }
  132. // get the parameter name from parameter object
  133. auto param_name = value->param_info()->name();
  134. auto top_graph = func_graph;
  135. // if the parameter node has been created , return it
  136. AnfNodePtr para_node = nullptr;
  137. for (auto param : top_graph->parameters()) {
  138. auto param_node = dyn_cast<Parameter>(param);
  139. if (param_node != nullptr && param_node->name() == param_name) {
  140. para_node = param;
  141. break;
  142. }
  143. }
  144. if (para_node == nullptr) {
  145. auto node = top_graph->AddWeightParameter(param_name);
  146. node->set_default_param(value);
  147. // set_abstract for parameter
  148. auto abs = value->ToAbstract();
  149. // boarden value
  150. abs = abs->Broaden();
  151. node->set_abstract(abs);
  152. para_node = node;
  153. }
  154. return para_node;
  155. }
  156. void UpdataParam(const FuncGraphPtr &top_graph, const py::object &cell) {
  157. auto params = py::list(cell.attr("get_parameters")()).cast<std::vector<py::object>>();
  158. for (size_t i = 0; i < params.size(); i++) {
  159. (void)AppendParameterObj(top_graph, params[i]);
  160. }
  161. }
  162. void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &ast) {
  163. // check whether the functions refered by this function and itself are missing 'return' statement
  164. auto mng = Manage(fn, false);
  165. for (auto func_graph : mng->func_graphs()) {
  166. if (func_graph->get_return() != nullptr) {
  167. continue;
  168. }
  169. py::object node = ast->GetAstNode();
  170. py::list ret = ast->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  171. py::str desc =
  172. python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]);
  173. MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
  174. }
  175. // clear manager info after checking missing return
  176. for (auto fg : mng->func_graphs()) {
  177. fg->ClearAllManagerInfo();
  178. }
  179. }
  180. FuncGraphPtr Parser::ParseFuncGraph() {
  181. // get ast FunctionDef node
  182. py::object node = ast_->GetAstNode();
  183. FunctionBlockPtr pFnBlock = ParseFunction(node);
  184. if (errcode() != PARSE_SUCCESS) {
  185. MS_LOG(ERROR) << "Parse function error, code is " << errcode();
  186. return nullptr;
  187. }
  188. RemoveUnnecessaryPhis();
  189. MS_EXCEPTION_IF_NULL(pFnBlock);
  190. CheckFuncReturn(pFnBlock->func_graph(), ast_);
  191. return pFnBlock->func_graph();
  192. }
  193. void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
  194. py::object func_args = python_adapter::GetPyObjAttr(fn_node, "args");
  195. py::object var_arg_node = python_adapter::GetPyObjAttr(func_args, "vararg");
  196. block->func_graph()->set_has_vararg(!py::isinstance<py::none>(var_arg_node));
  197. py::object kw_arg_node = python_adapter::GetPyObjAttr(func_args, "kwarg");
  198. block->func_graph()->set_has_kwarg(!py::isinstance<py::none>(kw_arg_node));
  199. py::list kwonly_args = python_adapter::GetPyObjAttr(func_args, "kwonlyargs");
  200. block->func_graph()->set_kwonlyargs_count(SizeToInt(kwonly_args.size()));
  201. MS_EXCEPTION_IF_NULL(ast_);
  202. py::list args = ast_->GetArgs(fn_node);
  203. for (std::size_t i = 0; i < args.size(); i++) {
  204. std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
  205. if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  206. if (arg_name == "self") {
  207. continue;
  208. }
  209. }
  210. TraceManager::DebugTrace(GetLocation(args[i]));
  211. auto para_node = std::make_shared<Parameter>(block->func_graph());
  212. MS_EXCEPTION_IF_NULL(para_node);
  213. TraceManager::EndTrace();
  214. para_node->set_name(arg_name);
  215. para_node->debug_info()->set_name(arg_name);
  216. block->func_graph()->add_parameter(para_node);
  217. AnfNodePtr para_after_cast = GetMixedPrecisionCastHelp(block->func_graph(), para_node);
  218. block->WriteVariable(arg_name, para_after_cast);
  219. MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name;
  220. }
  221. }
  222. void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
  223. py::list defaults = ast_->GetArgsDefaultValues(fn_node);
  224. py::list args = ast_->GetArgs(fn_node);
  225. std::vector<std::string> namelist_for_default_value;
  226. std::vector<AnfNodePtr> default_values;
  227. for (std::size_t i = 0; i < args.size(); i++) {
  228. std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
  229. if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  230. if (arg_name == "self") {
  231. continue;
  232. }
  233. }
  234. namelist_for_default_value.push_back(arg_name);
  235. if (py::isinstance<py::none>(defaults[i])) {
  236. default_values.push_back(NewValueNode(kNull));
  237. } else {
  238. default_values.push_back(ParseExprNode(block, defaults[i]));
  239. }
  240. }
  241. block->func_graph()->SetDefaultValues(namelist_for_default_value, default_values);
  242. }
  243. ScopePtr Parser::GetScopeForParseFunction() {
  244. ScopePtr scope = ScopeManager::GetInstance().GetCurrentScope();
  245. if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  246. py::object scope_str = python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GET_SCOPE_NAME, ast_->obj());
  247. if (!py::isinstance<py::none>(scope_str)) {
  248. auto scope_name = py::cast<std::string>(scope_str);
  249. scope = std::make_shared<Scope>(scope_name);
  250. }
  251. }
  252. return scope;
  253. }
  254. FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlockPtr &block) {
  255. ScopePtr scope = GetScopeForParseFunction();
  256. // the node created in the parsefunction context, will inherit the scope created using scope_guard
  257. ScopeGuard scope_guard(scope);
  258. TraceGuard trace_guard(data_converter::GetObjKey(ast()->obj())[0], GetLocation(node));
  259. FunctionBlockPtr pFunBlock = MakeFunctionBlock(*this);
  260. if (block != nullptr) {
  261. pFunBlock->AddPrevBlock(block);
  262. } else {
  263. func_graph_ = pFunBlock->func_graph();
  264. }
  265. pFunBlock->Mature();
  266. auto current_fg = pFunBlock->func_graph();
  267. auto function_name = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "name"));
  268. MS_LOG(DEBUG) << "The function name is " << function_name;
  269. current_fg->debug_info()->set_name(function_name);
  270. MS_EXCEPTION_IF_NULL(ast_);
  271. py::list deco_list = node.attr("decorator_list");
  272. if (deco_list.size() > 0) {
  273. current_fg->debug_info()->set_deco_location(GetLocation(deco_list));
  274. }
  275. bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg);
  276. if (!ast_->obj().is(ast_->function())) {
  277. set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg);
  278. }
  279. if (!set_flag) {
  280. MS_LOG(ERROR) << "Set flags failed";
  281. return nullptr;
  282. }
  283. GenerateArgsNodeForFunction(pFunBlock, node);
  284. // when parsing the top graph of construct, save the top graph
  285. if (GetTopFuncGraph() == nullptr) {
  286. UpdateTopFuncGraph(pFunBlock->func_graph());
  287. }
  288. // save the function node to block
  289. pFunBlock->WriteVariable(function_name, NewValueNode(current_fg));
  290. py::object funcObj = python_adapter::GetPyObjAttr(node, "body");
  291. (void)ParseStatements(pFunBlock, funcObj);
  292. if (current_fg->get_return() == nullptr) {
  293. py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  294. py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, node, ret[0], ret[1]);
  295. MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
  296. }
  297. GenerateArgsDefaultValueForFunction(pFunBlock, node);
  298. return pFunBlock;
  299. }
  300. FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr fn_block, const py::object &nodes) {
  301. py::int_ pcount = python_adapter::CallPyObjMethod(nodes, "__len__");
  302. size_t count = IntToSize(pcount);
  303. MS_LOG(DEBUG) << "The nodes count is " << count;
  304. for (size_t i = 0; i < count; i++) {
  305. auto node = py::cast<py::list>(nodes)[i];
  306. TraceManager::DebugTrace(GetLocation(node));
  307. fn_block = ParseStatement(fn_block, node);
  308. TraceManager::EndTrace();
  309. // insert appropriate depended items for the function block if it has a return node
  310. if (fn_block->func_graph()->get_return() != nullptr) {
  311. fn_block->InsertDependItemsBeforeReturn();
  312. // Skip statements after 'return' (or 'break', 'continue').
  313. break;
  314. }
  315. }
  316. return fn_block;
  317. }
  318. FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py::object &node) {
  319. auto node_type = ast_->GetNodeType(node);
  320. // check the node type
  321. AstMainType nodeType = node_type->main_type();
  322. if (nodeType != AST_MAIN_TYPE_STMT) {
  323. MS_LOG(INFO) << "Node type is error : " << nodeType;
  324. return block;
  325. }
  326. // call the process function
  327. std::string node_name = node_type->node_name();
  328. MS_LOG(DEBUG) << "Ast node is " << node_name;
  329. if (stmt_method_map_.count(node_name)) {
  330. TraceManager::DebugTrace(GetLocation(node));
  331. auto stmt_block = (this->*stmt_method_map_[node_name])(block, node);
  332. TraceManager::EndTrace();
  333. return stmt_block;
  334. } else {
  335. errcode_ = PARSE_NODE_METHOD_UNSUPPORTED;
  336. py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  337. if (location.size() < 2) {
  338. MS_LOG(EXCEPTION) << "List size should not be less than 2.";
  339. }
  340. auto filename = location[0].cast<std::string>();
  341. auto line_no = location[1].cast<int>();
  342. auto fn_loc = block->func_graph()->debug_info()->location();
  343. py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(),
  344. fn_loc->file_name(), fn_loc->line());
  345. MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no << " in "
  346. << desc.cast<std::string>() << ".";
  347. }
  348. }
  349. AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object &node) {
  350. MS_LOG(DEBUG) << "Process ast expr";
  351. auto node_type = ast_->GetNodeType(node);
  352. // check the node type
  353. AstMainType node_main_type = node_type->main_type();
  354. if (node_main_type != AST_MAIN_TYPE_EXPR) {
  355. MS_LOG(ERROR) << "Node type is error : " << node_main_type;
  356. errcode_ = PARSE_NODE_TYPE_NO_MATCH;
  357. return nullptr;
  358. }
  359. // call the process function
  360. std::string node_name = node_type->node_name();
  361. MS_LOG(DEBUG) << "Ast node is " << node_name;
  362. if (expr_method_map_.count(node_name)) {
  363. TraceManager::DebugTrace(GetLocation(node));
  364. auto expr_node = (this->*expr_method_map_[node_name])(block, node);
  365. TraceManager::EndTrace();
  366. return expr_node;
  367. } else {
  368. errcode_ = PARSE_NODE_METHOD_UNSUPPORTED;
  369. py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  370. auto filename = ret[0].cast<std::string>();
  371. auto line_no = ret[1].cast<int>();
  372. auto fn_loc = block->func_graph()->debug_info()->location();
  373. py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(),
  374. fn_loc->file_name(), fn_loc->line());
  375. MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no << " in "
  376. << desc.cast<std::string>() << ".";
  377. }
  378. }
  379. // process the expr statement and expand it
  380. // eg: x.append(y) -> x = x.append(y)
  381. FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::object &node) {
  382. MS_LOG(DEBUG) << "Process ast Expr";
  383. // Expr only have value , no target
  384. py::tuple expand_info = ast_->CallParserObjMethod(PYTHON_PARSE_EXPAND_EXPR_STATEMENT, node);
  385. // refer python function expand_expr_statement, expand_info is one of the following:
  386. // True, expr.value, x
  387. // True, expr.value
  388. // False, None, None
  389. // check the expand info result
  390. auto is_expand = py::cast<bool>(expand_info[0]);
  391. if (is_expand) {
  392. // process the expr statement
  393. py::object value_object = expand_info[1];
  394. AnfNodePtr value_node = ParseExprNode(block, value_object);
  395. if (py::len(expand_info) == 2) {
  396. // add to depend list and insert before output
  397. block->AddAutoDepend(value_node);
  398. } else {
  399. // expand the assign statement
  400. py::object target_node = expand_info[2];
  401. WriteAssignVars(block, target_node, value_node);
  402. }
  403. }
  404. return block;
  405. }
  406. LocationPtr Parser::GetLocation(const py::object &node) const {
  407. MS_EXCEPTION_IF_NULL(ast_);
  408. py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  409. if (ret.size() < 5) {
  410. MS_LOG(EXCEPTION) << "List size should not be less than 5.";
  411. }
  412. // refer to Location::Location() for each member of ret: line, column, line_end, column_end.
  413. auto location = std::make_shared<Location>(ret[0].cast<std::string>(), ret[1].cast<int>(), ret[2].cast<int>(),
  414. ret[3].cast<int>(), ret[4].cast<int>());
  415. return location;
  416. }
  417. void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
  418. const FunctionBlockPtr &false_block) {
  419. true_block->AddPrevBlock(pre_block);
  420. true_block->Mature();
  421. false_block->AddPrevBlock(pre_block);
  422. false_block->Mature();
  423. }
  424. FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::object &node) {
  425. MS_LOG(DEBUG) << "Process ast return";
  426. MS_EXCEPTION_IF_NULL(block);
  427. // create return valuenode
  428. AnfNodePtr pReturnValueNode = NewValueNode(prim::kPrimReturn);
  429. // parse the return Statements value
  430. py::object value = python_adapter::GetPyObjAttr(node, "value");
  431. AnfNodePtr pReturnStatementNode = ParseExprNode(block, value);
  432. // Create the cnode
  433. CNodePtr pReturnCNode = block->func_graph()->NewCNode({pReturnValueNode, pReturnStatementNode});
  434. block->func_graph()->set_return(pReturnCNode);
  435. return block;
  436. }
  437. // Process binary operators,eg: `a + b`, `a | b`, etc.
  438. AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &node) {
  439. MS_LOG(DEBUG) << "Process ast BinOP";
  440. py::object left = python_adapter::GetPyObjAttr(node, "left");
  441. py::object right = python_adapter::GetPyObjAttr(node, "right");
  442. py::object op = python_adapter::GetPyObjAttr(node, "op");
  443. // create left and right ANF node
  444. AnfNodePtr left_node = ParseExprNode(block, left);
  445. if (left_node == nullptr) {
  446. MS_LOG(WARNING) << "DoBinOp process left node failed: " << errcode();
  447. return nullptr;
  448. }
  449. AnfNodePtr right_node = ParseExprNode(block, right);
  450. if (right_node == nullptr) {
  451. MS_LOG(WARNING) << "DoBinOp process right node failed:" << errcode();
  452. return nullptr;
  453. }
  454. // resolve the op
  455. AnfNodePtr op_node = block->MakeResolveAstOp(op);
  456. // create apply node
  457. return block->func_graph()->NewCNode({op_node, left_node, right_node});
  458. }
  459. AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &node) {
  460. MS_LOG(DEBUG) << "Process ast Name";
  461. auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id"));
  462. MS_LOG(DEBUG) << "The Name id is " << name_id;
  463. TraceGuard trace_guard(GetLocation(node));
  464. if (block->IsGlobalVar(name_id)) {
  465. return block->MakeResolveSymbol(name_id);
  466. }
  467. return block->ReadVariable(name_id);
  468. }
  469. AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) {
  470. MS_LOG(DEBUG) << "Process ast NoneType";
  471. return NewValueNode(kNone);
  472. }
  473. AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) {
  474. MS_LOG(DEBUG) << "Process ast Ellipsis";
  475. return NewValueNode(kEllipsis);
  476. }
  477. AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
  478. MS_LOG(DEBUG) << "Process ast Num";
  479. py::object obj = python_adapter::GetPyObjAttr(node, "n");
  480. TraceGuard trace_guard(GetLocation(node));
  481. if (py::isinstance<py::int_>(obj)) {
  482. MS_LOG(INFO) << "The Num is int:" << (std::string)py::str(obj);
  483. auto data = py::cast<int>(obj);
  484. return NewValueNode(data);
  485. } else if (py::isinstance<py::float_>(obj)) {
  486. MS_LOG(INFO) << "The Num is float:" << (std::string)py::str(obj);
  487. auto data = py::cast<float>(obj);
  488. return NewValueNode(data);
  489. } else {
  490. // no else actually
  491. MS_LOG(ERROR) << "Unsupported Num type : " << (std::string)py::str(obj) << GetLocation(node)->ToString();
  492. errcode_ = PARSE_NODE_TYPE_UNKOWN;
  493. return nullptr;
  494. }
  495. }
  496. AnfNodePtr Parser::ParseStr(const FunctionBlockPtr &, const py::object &node) {
  497. MS_LOG(DEBUG) << "Process ast Str";
  498. auto str_s = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "s"));
  499. return NewValueNode(str_s);
  500. }
  501. AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object &node) {
  502. MS_LOG(DEBUG) << "Process ast NameConstant";
  503. py::object obj = python_adapter::GetPyObjAttr(node, "value");
  504. TraceGuard trace_guard(GetLocation(node));
  505. if (py::isinstance<py::bool_>(obj)) {
  506. MS_LOG(INFO) << "The NameConstant is bool:" << (std::string)py::str(obj);
  507. auto data = py::cast<bool>(obj);
  508. return NewValueNode(data);
  509. } else if (py::isinstance<py::none>(obj)) {
  510. MS_LOG(INFO) << "The NameConstant is none:" << (std::string)py::str(obj);
  511. return NewValueNode(kNone);
  512. } else {
  513. // no else actually
  514. MS_LOG(ERROR) << "Unsupported NameConstant type: " << (std::string)py::str(obj) << GetLocation(node)->ToString();
  515. errcode_ = PARSE_NODE_TYPE_UNKOWN;
  516. return nullptr;
  517. }
  518. }
  519. AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes) {
  520. AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
  521. std::vector<AnfNodePtr> make_tuple_nodes;
  522. make_tuple_nodes.push_back(make_tuple_op);
  523. (void)std::transform(element_nodes.begin(), element_nodes.end(), std::back_inserter(make_tuple_nodes),
  524. [](AnfNodePtr arg) -> AnfNodePtr { return arg; });
  525. return block->func_graph()->NewCNode(make_tuple_nodes);
  526. }
  527. AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &args) {
  528. py::object father_class;
  529. if (args.empty()) {
  530. father_class = py::none();
  531. } else if (args.size() == 2) {
  532. father_class = args[0];
  533. auto arg_type = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, args[1])));
  534. if (arg_type != AST_SUB_TYPE_NAME || py::cast<std::string>(python_adapter::GetPyObjAttr(args[1], "id")) != "self") {
  535. MS_EXCEPTION(ArgumentError) << "When call 'super', the second arg should be 'self'.";
  536. }
  537. } else {
  538. MS_EXCEPTION(ArgumentError) << "When call 'super', the args number should be 0 or 2, but got" << args.size() << ".";
  539. }
  540. py::object target_class_instance = ast()->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast()->obj());
  541. py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance);
  542. NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
  543. SymbolPtr symbol = std::make_shared<Symbol>("namespace");
  544. return block->MakeResolve(name_space, symbol);
  545. }
  546. // process function call, eg : f1(x, y) ...
  547. AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) {
  548. MS_LOG(DEBUG) << "Process ast Call";
  549. // process function call
  550. py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func");
  551. py::list args = python_adapter::GetPyObjAttr(node, "args");
  552. auto arg_type =
  553. AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, function_ast_node)));
  554. if (arg_type == AST_SUB_TYPE_NAME) {
  555. auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(function_ast_node, "id"));
  556. if (name_id == "super") {
  557. return ParseSuper(block, args);
  558. }
  559. }
  560. AnfNodePtr call_function_anf_node = ParseExprNode(block, function_ast_node);
  561. // function call arguments should be passed in as groups and unpacked later using unpack call
  562. std::vector<AnfNodePtr> packed_arguments;
  563. std::vector<AnfNodePtr> group_arguments;
  564. bool need_unpack_args = ParseArgsInCall(block, args, &packed_arguments, &group_arguments);
  565. bool need_unpack_keywords = ParseKeywordsInCall(block, node, &packed_arguments);
  566. // if there is stared or keyword argument, unpack may be needed
  567. bool need_unpack = need_unpack_args || need_unpack_keywords;
  568. return GenerateAnfNodeForCall(block, call_function_anf_node, packed_arguments, group_arguments, need_unpack);
  569. }
  570. CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_anf_node,
  571. const std::vector<AnfNodePtr> &packed_arguments) {
  572. std::vector<AnfNodePtr> unpack_call_nodes;
  573. auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL));
  574. unpack_call_nodes.push_back(unpack_call_op);
  575. unpack_call_nodes.push_back(call_function_anf_node);
  576. (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes),
  577. [](AnfNodePtr node) -> AnfNodePtr { return node; });
  578. CNodePtr unpack_call = func_graph->NewCNode(unpack_call_nodes);
  579. return unpack_call;
  580. }
  581. AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node,
  582. const std::vector<AnfNodePtr> &packed_arguments,
  583. const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const {
  584. // if there is keyword arguments or starred, using an unpack_call op to unpack the argument
  585. if (need_unpack) {
  586. return MakeUnpackCall(block->func_graph(), call_function_anf_node, packed_arguments);
  587. }
  588. // else there is no keyword arguments and starred, parsed as normal arguments without unpack
  589. std::vector<AnfNodePtr> func_call_nodes;
  590. func_call_nodes.push_back(call_function_anf_node);
  591. (void)std::transform(group_arguments.begin(), group_arguments.end(), std::back_inserter(func_call_nodes),
  592. [](AnfNodePtr node) -> AnfNodePtr { return node; });
  593. CNodePtr call_anf_node = block->func_graph()->NewCNode(func_call_nodes);
  594. return call_anf_node;
  595. }
  596. bool Parser::ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args,
  597. std::vector<AnfNodePtr> *packed_arguments, std::vector<AnfNodePtr> *group_arguments) {
  598. bool need_unpack = false;
  599. for (size_t i = 0; i < args.size(); i++) {
  600. auto arg_node = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, args[i])));
  601. if (arg_node == AST_SUB_TYPE_STARRED) {
  602. if (!group_arguments->empty()) {
  603. packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments));
  604. }
  605. packed_arguments->push_back(ParseExprNode(block, python_adapter::GetPyObjAttr(args[i], "value")));
  606. group_arguments->clear();
  607. need_unpack = true;
  608. } else {
  609. group_arguments->push_back(ParseExprNode(block, args[i]));
  610. }
  611. }
  612. if (!group_arguments->empty()) {
  613. packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments));
  614. }
  615. return need_unpack;
  616. }
  617. bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node,
  618. std::vector<AnfNodePtr> *packed_arguments) {
  619. bool need_unpack = false;
  620. py::list keywords = python_adapter::GetPyObjAttr(node, "keywords");
  621. if (!keywords.empty()) {
  622. need_unpack = true;
  623. std::vector<AnfNodePtr> keys;
  624. std::vector<AnfNodePtr> values;
  625. for (size_t index = 0; index < keywords.size(); index++) {
  626. auto kw_key = python_adapter::GetPyObjAttr(keywords[index], "arg");
  627. auto kw_value = python_adapter::GetPyObjAttr(keywords[index], "value");
  628. if (py::isinstance<py::none>(kw_key)) {
  629. packed_arguments->push_back(ParseExprNode(block, kw_value));
  630. } else {
  631. auto kw_key_c = kw_key.cast<std::string>();
  632. keys.push_back(NewValueNode(kw_key_c));
  633. values.push_back(ParseExprNode(block, kw_value));
  634. }
  635. }
  636. auto keys_tuple = GenerateMakeTuple(block, keys);
  637. auto values_tuple = GenerateMakeTuple(block, values);
  638. auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
  639. std::vector<AnfNodePtr> make_dict_nodes;
  640. make_dict_nodes.push_back(make_dict_op);
  641. make_dict_nodes.push_back(keys_tuple);
  642. make_dict_nodes.push_back(values_tuple);
  643. packed_arguments->push_back(block->func_graph()->NewCNode(make_dict_nodes));
  644. }
  645. return need_unpack;
  646. }
  647. // process call attributes of class type define, eg: x.y()
  648. AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::object &node) {
  649. MS_LOG(DEBUG) << "Process ast Attribute";
  650. // process class value,eg: self.xx
  651. if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
  652. if (ast_->IsClassMember(node)) {
  653. std::string var_name = "self.";
  654. std::string attr_name = node.attr("attr").cast<std::string>();
  655. (void)var_name.append(attr_name);
  656. auto attr_obj = ast()->obj().attr(attr_name.c_str());
  657. if (py::hasattr(ast()->obj(), attr_name.c_str()) &&
  658. (py::hasattr(attr_obj, PYTHON_PRIMITIVE_FLAG) || py::isinstance<py::int_>(attr_obj) ||
  659. py::isinstance<py::float_>(attr_obj) || py::isinstance<py::bool_>(attr_obj) ||
  660. py::isinstance<py::str>(attr_obj) || data_converter::IsCellInstance(attr_obj))) {
  661. return block->MakeResolveSymbol(var_name);
  662. } else {
  663. return block->ReadVariable(var_name);
  664. }
  665. }
  666. }
  667. // process the get attr
  668. // Use the Primitive replace the operation resolve node (getattr)
  669. // because the getattr will eventually be converted to Primitive node
  670. AnfNodePtr op_node = NewValueNode(prim::kPrimGetAttr);
  671. // process the attr body
  672. py::object value_body = python_adapter::GetPyObjAttr(node, "value");
  673. AnfNodePtr value_node = ParseExprNode(block, value_body);
  674. if (value_node == nullptr) {
  675. MS_LOG(WARNING) << "Parse attribute failed";
  676. return nullptr;
  677. }
  678. // process the node attr
  679. auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast<std::string>();
  680. MS_LOG(DEBUG) << "Attr = " << attr_str;
  681. TraceManager::DebugTrace(GetLocation(python_adapter::GetPyObjAttr(node, "attr")));
  682. AnfNodePtr attr_node = NewValueNode(attr_str);
  683. TraceManager::EndTrace();
  684. // create the apply node
  685. return block->func_graph()->NewCNode({op_node, value_node, attr_node});
  686. }
  687. // Process comparison expression : a == b. a > b etc.
  688. AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object &node) {
  689. MS_LOG(DEBUG) << "Process ast Compare";
  690. // for python comparison ,there may be if x>y>5 ,
  691. // which there is two ops , but we only support one now
  692. py::list ops = python_adapter::GetPyObjAttr(node, "ops");
  693. if (ops.size() > MAX_COMPARISON_OPS_SUPPORTED) {
  694. MS_LOG(ERROR) << "MindSpore does not support comparison with operators more than one now, ops size =" << ops.size();
  695. return nullptr;
  696. }
  697. py::object left = python_adapter::GetPyObjAttr(node, "left");
  698. py::list comparators = python_adapter::GetPyObjAttr(node, "comparators");
  699. AnfNodePtr left_node = ParseExprNode(block, left);
  700. AnfNodePtr right_node = ParseExprNode(block, comparators[0]);
  701. MS_EXCEPTION_IF_NULL(block);
  702. AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]);
  703. return block->func_graph()->NewCNode({op_node, left_node, right_node});
  704. }
  705. AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode) {
  706. // if there is only one bool op now
  707. if (value_list.size() == 1) {
  708. AnfNodePtr first_node = ParseExprNode(block, value_list[0]);
  709. return first_node;
  710. } else {
  711. py::object first = value_list[0];
  712. py::list rest;
  713. for (size_t i = 1; i < value_list.size(); i++) {
  714. rest.append(value_list[i]);
  715. }
  716. MS_EXCEPTION_IF_NULL(block);
  717. TraceManager::DebugTrace(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
  718. FunctionBlockPtr true_block = MakeFunctionBlock(*this);
  719. TraceManager::EndTrace();
  720. TraceManager::DebugTrace(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
  721. FunctionBlockPtr false_block = MakeFunctionBlock(*this);
  722. TraceManager::EndTrace();
  723. MakeConditionBlocks(block, true_block, false_block);
  724. FunctionBlockPtr b1, b2;
  725. // if it is and, we need to process the rest nodes;
  726. // if it is or, we continue to next
  727. if (mode == AST_SUB_TYPE_AND) {
  728. b1 = true_block;
  729. b2 = false_block;
  730. } else if (mode == AST_SUB_TYPE_OR) {
  731. b2 = true_block;
  732. b1 = false_block;
  733. } else {
  734. MS_LOG(ERROR) << "Not supported mode: " << mode;
  735. return nullptr;
  736. }
  737. AnfNodePtr test_node = ParseExprNode(block, first);
  738. AnfNodePtr rest_node = ProcessBoolOpValueList(b1, rest, mode);
  739. b1->func_graph()->set_output(rest_node);
  740. b2->func_graph()->set_output(test_node);
  741. auto cond_node = block->ForceToBoolNode(test_node);
  742. auto switch_app =
  743. block->func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), cond_node, NewValueNode(true_block->func_graph()),
  744. NewValueNode(false_block->func_graph())});
  745. std::vector<AnfNodePtr> call_graph_nodes{switch_app};
  746. auto switch_app_call = block->func_graph()->NewCNode(call_graph_nodes);
  747. return switch_app_call;
  748. }
  749. }
  750. // Process comparison expression : a and b. a or b .
  751. AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) {
  752. MS_LOG(DEBUG) << "Process ast BoolOp";
  753. py::object op_node = python_adapter::GetPyObjAttr(node, "op");
  754. AstSubType op_type = ast_->GetOpType(op_node);
  755. if (op_type == AST_SUB_TYPE_UNKNOWN) {
  756. MS_LOG(WARNING) << "ProcessBoolOp, got unkown op type";
  757. return nullptr;
  758. }
  759. py::list op_values = python_adapter::GetPyObjAttr(node, "values");
  760. return ProcessBoolOpValueList(block, op_values, op_type);
  761. }
  762. // Process a function def
  763. FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) {
  764. MS_LOG(DEBUG) << "Process ast FunctionDef";
  765. FunctionBlockPtr function_block = ParseFunction(node, block);
  766. MS_EXCEPTION_IF_NULL(function_block);
  767. // get function name
  768. py::str name = python_adapter::GetPyObjAttr(node, "name");
  769. std::string function_name = name;
  770. ValueNodePtr valuenode_graph = NewValueNode(function_block->func_graph());
  771. block->WriteVariable(function_name, valuenode_graph);
  772. return block;
  773. }
  774. // Process a lambda expression . like lambda x,y: x + y
  775. AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) {
  776. MS_LOG(DEBUG) << "Process ast Lambda";
  777. FunctionBlockPtr func_block = MakeFunctionBlock(*this);
  778. func_block->AddPrevBlock(block);
  779. func_block->Mature();
  780. // get lambda args
  781. py::list args = ast_->GetArgs(node);
  782. for (std::size_t i = 0; i < args.size(); i++) {
  783. std::string arg = py::cast<std::string>(args[i].attr("arg"));
  784. TraceManager::DebugTrace(GetLocation(args[i]));
  785. auto para_node = std::make_shared<Parameter>(func_block->func_graph());
  786. TraceManager::EndTrace();
  787. para_node->debug_info()->set_name(arg);
  788. func_block->func_graph()->add_parameter(para_node);
  789. func_block->WriteVariable(arg, para_node);
  790. MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg;
  791. }
  792. py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  793. AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
  794. func_block->func_graph()->set_output(lambda_body_node);
  795. ValueNodePtr const_graph = NewValueNode(func_block->func_graph());
  796. return const_graph;
  797. }
  798. // process a tuple
  799. AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &node) {
  800. MS_LOG(DEBUG) << "Process ast Tuple";
  801. MS_EXCEPTION_IF_NULL(block);
  802. py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
  803. if (elts.size() == 0) {
  804. auto empty_tuple = std::vector<ValuePtr>();
  805. return NewValueNode(std::make_shared<ValueTuple>(empty_tuple));
  806. }
  807. std::vector<AnfNodePtr> tuple_vec;
  808. AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
  809. tuple_vec.emplace_back(make_tuple_op);
  810. for (size_t i = 0; i < elts.size(); i++) {
  811. AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
  812. tuple_vec.emplace_back(node_ptr);
  813. }
  814. CNodePtr tuple_app = block->func_graph()->NewCNode(tuple_vec);
  815. return tuple_app;
  816. }
  817. // process a list
  818. AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) {
  819. MS_LOG(DEBUG) << "Process ast List";
  820. MS_EXCEPTION_IF_NULL(block);
  821. py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
  822. if (elts.size() == 0) {
  823. auto empty_list = std::vector<ValuePtr>();
  824. return NewValueNode(std::make_shared<ValueList>(empty_list));
  825. }
  826. std::vector<AnfNodePtr> list_vec;
  827. AnfNodePtr make_list_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST);
  828. list_vec.emplace_back(make_list_op);
  829. for (size_t i = 0; i < elts.size(); i++) {
  830. AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
  831. list_vec.emplace_back(node_ptr);
  832. }
  833. CNodePtr list_app = block->func_graph()->NewCNode(list_vec);
  834. return list_app;
  835. }
  836. // process a subscript, such as x[y] , node expressed as value[slice]
  837. AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::object &node) {
  838. MS_LOG(DEBUG) << "Process ast Subscript";
  839. MS_EXCEPTION_IF_NULL(block);
  840. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  841. py::object value_node = python_adapter::GetPyObjAttr(node, "value");
  842. py::object slice_node = python_adapter::GetPyObjAttr(node, "slice");
  843. AnfNodePtr value = ParseExprNode(block, value_node);
  844. AnfNodePtr slice = ParseExprNode(block, slice_node);
  845. return block->func_graph()->NewCNode({op_getitem, value, slice});
  846. }
  847. // process a slice, get the slice value
  848. AnfNodePtr Parser::ParseSlice(const FunctionBlockPtr &block, const py::object &node) {
  849. MS_LOG(DEBUG) << "Process ast Slice";
  850. MS_EXCEPTION_IF_NULL(block);
  851. AnfNodePtr op_makeslice = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKESLICE);
  852. py::object start = python_adapter::GetPyObjAttr(node, "lower");
  853. py::object stop = python_adapter::GetPyObjAttr(node, "upper");
  854. py::object step = python_adapter::GetPyObjAttr(node, "step");
  855. AnfNodePtr start_node = ParseExprNode(block, start);
  856. AnfNodePtr stop_node = ParseExprNode(block, stop);
  857. AnfNodePtr step_node = ParseExprNode(block, step);
  858. return block->func_graph()->NewCNode({op_makeslice, start_node, stop_node, step_node});
  859. }
  860. // process a extslice
  861. AnfNodePtr Parser::ParseExtSlice(const FunctionBlockPtr &block, const py::object &node) {
  862. MS_LOG(DEBUG) << "Process ast ExtSlice";
  863. MS_EXCEPTION_IF_NULL(block);
  864. AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
  865. py::tuple slice_tuple = python_adapter::GetPyObjAttr(node, "dims");
  866. std::vector<AnfNodePtr> node_vec;
  867. node_vec.emplace_back(make_tuple_op);
  868. for (size_t i = 0; i < slice_tuple.size(); i++) {
  869. AnfNodePtr node_ptr = ParseExprNode(block, slice_tuple[i]);
  870. node_vec.emplace_back(node_ptr);
  871. }
  872. CNodePtr tuple_conde = block->func_graph()->NewCNode(node_vec);
  873. return tuple_conde;
  874. }
  875. // process a index, get the index number
  876. AnfNodePtr Parser::ParseIndex(const FunctionBlockPtr &block, const py::object &node) {
  877. MS_LOG(DEBUG) << "Process ast Index";
  878. py::object value_node = python_adapter::GetPyObjAttr(node, "value");
  879. return ParseExprNode(block, value_node);
  880. }
  881. // process a UnaryOp, +a, -b
  882. AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node) {
  883. MS_LOG(DEBUG) << "Process ast UnaryOp";
  884. py::object op = python_adapter::GetPyObjAttr(node, "op");
  885. MS_EXCEPTION_IF_NULL(block);
  886. // resolve the op
  887. AnfNodePtr op_node = block->MakeResolveAstOp(op);
  888. py::object operand = python_adapter::GetPyObjAttr(node, "operand");
  889. AnfNodePtr operand_node = ParseExprNode(block, operand);
  890. return block->func_graph()->NewCNode({op_node, operand_node});
  891. }
  892. // process a dict ast node expression
  893. AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) {
  894. MS_LOG(DEBUG) << "Process ast Dict";
  895. py::list keys = node.attr("keys");
  896. py::list values = node.attr("values");
  897. std::vector<AnfNodePtr> key_nodes;
  898. std::vector<AnfNodePtr> value_nodes;
  899. for (size_t i = 0; i < keys.size(); i++) {
  900. key_nodes.push_back(ParseExprNode(block, keys[i]));
  901. value_nodes.push_back(ParseExprNode(block, values[i]));
  902. }
  903. auto keys_tuple = GenerateMakeTuple(block, key_nodes);
  904. auto values_tuple = GenerateMakeTuple(block, value_nodes);
  905. MS_EXCEPTION_IF_NULL(block);
  906. auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
  907. return block->func_graph()->NewCNode({make_dict_op, keys_tuple, values_tuple});
  908. }
  909. // process a augment assign such as a += b;
  910. FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py::object &node) {
  911. MS_LOG(DEBUG) << "Process ast AugAssign";
  912. py::object op = python_adapter::GetPyObjAttr(node, "op");
  913. MS_EXCEPTION_IF_NULL(block);
  914. // resolve the op
  915. AnfNodePtr op_node = block->MakeResolveAstOp(op);
  916. py::object target_node = python_adapter::GetPyObjAttr(node, "target");
  917. MS_EXCEPTION_IF_NULL(ast_);
  918. auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, target_node)));
  919. AnfNodePtr read_node = nullptr;
  920. if (ast_type == AST_SUB_TYPE_NAME) {
  921. read_node = ParseName(block, target_node);
  922. } else if (ast_->IsClassMember(target_node)) {
  923. read_node = ParseAttribute(block, target_node);
  924. } else {
  925. MS_LOG(EXCEPTION) << "Not supported augassign";
  926. }
  927. if (read_node == nullptr) {
  928. MS_LOG(EXCEPTION) << "Can not get target node ";
  929. }
  930. py::object value = python_adapter::GetPyObjAttr(node, "value");
  931. AnfNodePtr value_node = ParseExprNode(block, value);
  932. CNodePtr augassign_app = block->func_graph()->NewCNode({op_node, read_node, value_node});
  933. WriteAssignVars(block, target_node, augassign_app);
  934. return block;
  935. }
  936. // process global declaration such as 'global x';
  937. FunctionBlockPtr Parser::ParseGlobal(const FunctionBlockPtr &block, const py::object &node) {
  938. MS_LOG(DEBUG) << "Process ast Global";
  939. MS_EXCEPTION_IF_NULL(block);
  940. py::list vars = python_adapter::GetPyObjAttr(node, "names");
  941. for (auto &item : vars) {
  942. block->AddGlobalVar(py::cast<std::string>(item));
  943. }
  944. return block;
  945. }
  946. // process a if statement
  947. FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object &node) {
  948. MS_LOG(DEBUG) << "Process ast If";
  949. py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  950. AnfNodePtr condition_node = ParseExprNode(block, test_node);
  951. MS_EXCEPTION_IF_NULL(block);
  952. CNodePtr bool_node = block->ForceToBoolNode(condition_node);
  953. TraceManager::DebugTrace(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
  954. FunctionBlockPtr true_block = MakeFunctionBlock(*this);
  955. TraceManager::EndTrace();
  956. TraceManager::DebugTrace(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
  957. FunctionBlockPtr false_block = MakeFunctionBlock(*this);
  958. TraceManager::EndTrace();
  959. MakeConditionBlocks(block, true_block, false_block);
  960. TraceManager::DebugTrace(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
  961. FunctionBlockPtr after_block = MakeFunctionBlock(*this);
  962. TraceManager::EndTrace();
  963. if (MsContext::GetInstance()->backend_policy() != "ge") {
  964. // for backends excludes 'ge', it can handle multi graph call, use this flag to
  965. // generate call not inline `after_block` graph to reduce if by if switch expansion.
  966. after_block->func_graph()->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true);
  967. }
  968. // process the if-true branch
  969. py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
  970. FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode);
  971. // if the return_ is set ,it has its own continuation block
  972. if (true_end->func_graph()->get_return() == nullptr) {
  973. true_end->Jump(after_block, nullptr);
  974. }
  975. // process the orelse branch
  976. py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
  977. FunctionBlockPtr false_end = ParseStatements(false_block, orelseNode);
  978. // if the return_ is set ,it has its own continuation block
  979. if (false_end->func_graph()->get_return() == nullptr) {
  980. false_end->Jump(after_block, nullptr);
  981. }
  982. block->ConditionalJump(bool_node, true_block, false_block);
  983. after_block->Mature();
  984. return after_block;
  985. }
  986. FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::object &node) {
  987. MS_LOG(DEBUG) << "Process ast While";
  988. MS_EXCEPTION_IF_NULL(block);
  989. MS_LOG(INFO) << "Parse while statement";
  990. TraceManager::DebugTrace(std::make_shared<TraceWhileHeader>(block->func_graph()->debug_info()));
  991. FunctionBlockPtr header_block = MakeFunctionBlock(*this);
  992. TraceManager::EndTrace();
  993. TraceManager::DebugTrace(std::make_shared<TraceWhileBody>(block->func_graph()->debug_info()));
  994. FunctionBlockPtr body_block = MakeFunctionBlock(*this);
  995. TraceManager::EndTrace();
  996. TraceManager::DebugTrace(std::make_shared<TraceWhileAfter>(block->func_graph()->debug_info()));
  997. FunctionBlockPtr after_block = MakeFunctionBlock(*this);
  998. TraceManager::EndTrace();
  999. body_block->AddPrevBlock(header_block);
  1000. after_block->AddPrevBlock(header_block);
  1001. block->Jump(header_block, nullptr);
  1002. py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  1003. AnfNodePtr condition_node = ParseExprNode(header_block, test_node);
  1004. condition_node = header_block->ForceToWhileCond(condition_node);
  1005. body_block->Mature();
  1006. header_block->ConditionalJump(condition_node, body_block, after_block);
  1007. // Parse loop body statements with loop context.
  1008. LoopContext loop_context{&loops_, header_block, nullptr};
  1009. py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  1010. FunctionBlockPtr after_body = ParseStatements(body_block, body_node);
  1011. if (after_body->func_graph()->get_return() == nullptr) {
  1012. after_body->Jump(header_block, nullptr);
  1013. }
  1014. header_block->Mature();
  1015. after_block->Mature();
  1016. auto &end_block = loop_context.EndBlock();
  1017. if (end_block) {
  1018. // end_block exists if we encounter 'break' in loop body.
  1019. after_block->Jump(end_block, nullptr);
  1020. end_block->Mature();
  1021. return end_block;
  1022. }
  1023. // No 'break', no end_block.
  1024. return after_block;
  1025. }
  1026. CNodePtr Parser::GenerateIteratorInFor(const FunctionBlockPtr &block, const py::object &node,
  1027. const AnfNodePtr &op_iter) {
  1028. py::object iter_node = python_adapter::GetPyObjAttr(node, "iter");
  1029. AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node);
  1030. return block->func_graph()->NewCNode({op_iter, iter_anf_node});
  1031. }
  1032. CNodePtr Parser::GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block,
  1033. const AnfNodePtr &op_hasnext) {
  1034. MS_EXCEPTION_IF_NULL(header_block);
  1035. return header_block->func_graph()->NewCNode({op_hasnext, iter_param});
  1036. }
  1037. FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
  1038. TraceManager::DebugTrace(trace_info);
  1039. FunctionBlockPtr body_block = MakeFunctionBlock(*this);
  1040. TraceManager::EndTrace();
  1041. return body_block;
  1042. }
  1043. // A for loop will generate 3 functions :the test, the body, and the continuation
  1044. // for x in xs:
  1045. // body
  1046. // it is compiled to be following statement
  1047. // if len(xs) < max_loop_cnt:
  1048. // ParseForIter() // use iter to implement for loop, which always unroll loop
  1049. // else:
  1050. // ParseForLoop() // use loop var to implement for loop, which always sink loop
  1051. FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
  1052. MS_LOG(DEBUG) << "Process ast For, create an if else statement";
  1053. MS_EXCEPTION_IF_NULL(block);
  1054. // create statement 'len(xs) < MAX_FOR_LOOP_COUNT'
  1055. AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
  1056. py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER);
  1057. AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
  1058. CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node});
  1059. CNodePtr bool_node =
  1060. block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(MAX_FOR_LOOP_COUNT)});
  1061. // create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop'
  1062. TraceManager::DebugTrace(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
  1063. FunctionBlockPtr true_block = MakeFunctionBlock(*this);
  1064. TraceManager::EndTrace();
  1065. TraceManager::DebugTrace(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
  1066. FunctionBlockPtr false_block = MakeFunctionBlock(*this);
  1067. TraceManager::EndTrace();
  1068. MakeConditionBlocks(block, true_block, false_block);
  1069. TraceManager::DebugTrace(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
  1070. FunctionBlockPtr after_block = MakeFunctionBlock(*this);
  1071. TraceManager::EndTrace();
  1072. FunctionBlockPtr true_end = ParseForIter(true_block, node);
  1073. true_end->Jump(after_block, nullptr);
  1074. FunctionBlockPtr false_end = ParseForLoop(false_block, node);
  1075. false_end->Jump(after_block, nullptr);
  1076. block->ConditionalJump(bool_node, true_block, false_block);
  1077. after_block->Mature();
  1078. return after_block;
  1079. }
  1080. // A for loop will generate 3 functions :the test, the body, and the continuation
  1081. // for x in xs:
  1082. // body
  1083. // it is compiled to be following statement
  1084. // it = iter(xs)
  1085. // while hastnext(it)
  1086. // x, it = next(it)
  1087. // body
  1088. FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::object &node) {
  1089. MS_LOG(DEBUG) << "Process ast For";
  1090. MS_EXCEPTION_IF_NULL(block);
  1091. AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
  1092. AnfNodePtr op_next = block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
  1093. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  1094. AnfNodePtr op_hasnext = block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
  1095. // generate the iterator apply
  1096. CNodePtr iter_apply = GenerateIteratorInFor(block, node, op_iter);
  1097. MS_EXCEPTION_IF_NULL(iter_apply);
  1098. FunctionBlockPtr header_block =
  1099. GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
  1100. MS_EXCEPTION_IF_NULL(header_block);
  1101. // generate the hasnext apply which is a condition
  1102. ParameterPtr iter_param = header_block->func_graph()->add_parameter();
  1103. CNodePtr cond_apply = GenerateCondInFor(iter_param, header_block, op_hasnext);
  1104. // generate the body of the for statement
  1105. FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  1106. MS_EXCEPTION_IF_NULL(body_block);
  1107. body_block->AddPrevBlock(header_block);
  1108. // generate the iterator next apply
  1109. // process as following: `app = next(it); target = app[0]; it = app[1];`
  1110. CNodePtr app = body_block->func_graph()->NewCNode({op_next, iter_param});
  1111. CNodePtr target_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(0)});
  1112. py::object target_node = python_adapter::GetPyObjAttr(node, "target");
  1113. CNodePtr iter2_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(1)});
  1114. WriteAssignVars(body_block, target_node, target_app);
  1115. // link the variable name with the target
  1116. auto it_info = std::make_shared<TraceIterator>(target_app->debug_info());
  1117. iter_param->debug_info()->set_trace_info(it_info);
  1118. iter2_app->debug_info()->set_trace_info(it_info);
  1119. iter_apply->debug_info()->set_trace_info(it_info);
  1120. TraceManager::DebugTrace(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
  1121. FunctionBlockPtr after_block = MakeFunctionBlock(*this);
  1122. MS_EXCEPTION_IF_NULL(after_block);
  1123. TraceManager::EndTrace();
  1124. after_block->AddPrevBlock(header_block);
  1125. block->Jump(header_block, iter_apply);
  1126. body_block->Mature();
  1127. header_block->ConditionalJump(cond_apply, body_block, after_block);
  1128. // Parse loop body statements with loop context.
  1129. LoopContext loop_context{&loops_, header_block, iter2_app};
  1130. py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  1131. FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
  1132. if (after_body_block->func_graph()->get_return() == nullptr) {
  1133. after_body_block->Jump(header_block, iter2_app);
  1134. }
  1135. header_block->Mature();
  1136. after_block->Mature();
  1137. auto &end_block = loop_context.EndBlock();
  1138. if (end_block) {
  1139. // end_block exists if we encounter 'break' in loop body.
  1140. after_block->Jump(end_block, nullptr);
  1141. end_block->Mature();
  1142. return end_block;
  1143. }
  1144. // No 'break', no end_block.
  1145. return after_block;
  1146. }
  1147. // A for loop will generate 3 functions :the test, the body, and the continuation
  1148. // for x in xs:
  1149. // body
  1150. // it is compiled to be following statement
  1151. // i = 0
  1152. // while i < len(xs)
  1153. // x = xs[i]
  1154. // i = i + 1
  1155. // body
  1156. FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::object &node) {
  1157. MS_LOG(DEBUG) << "Process ast For by loop variable";
  1158. MS_EXCEPTION_IF_NULL(block);
  1159. AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
  1160. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  1161. // get varibale name of 'x' in statement 'for x in xs'
  1162. py::object target_node = python_adapter::GetPyObjAttr(node, "target");
  1163. // create statement 'len(xs)'
  1164. py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
  1165. AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
  1166. MS_EXCEPTION_IF_NULL(iter_node);
  1167. // Generate node for loop count and convert it to tensor, to make the loop not unroll
  1168. CNodePtr scalar_len = block->func_graph()->NewCNode({op_len, iter_node});
  1169. auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations");
  1170. auto scalar_to_tensor_node = block->func_graph()->NewCNode({NewValueNode(scalar_to_tensor)});
  1171. CNodePtr len_iter = block->func_graph()->NewCNode({scalar_to_tensor_node, scalar_len});
  1172. FunctionBlockPtr header_block =
  1173. GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
  1174. MS_EXCEPTION_IF_NULL(header_block);
  1175. // create loop variable 'i'
  1176. ParameterPtr loop_var = header_block->func_graph()->add_parameter();
  1177. // create loop condition 'i < len(xs)'
  1178. auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations");
  1179. auto less_node = header_block->func_graph()->NewCNode({NewValueNode(prim_less)});
  1180. CNodePtr cond_node = header_block->func_graph()->NewCNode({less_node, loop_var, len_iter});
  1181. // generate the body of the for statement
  1182. FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  1183. MS_EXCEPTION_IF_NULL(body_block);
  1184. body_block->AddPrevBlock(header_block);
  1185. // create 'x = xs[i]'
  1186. CNodePtr target_var = body_block->func_graph()->NewCNode({op_getitem, iter_node, loop_var});
  1187. WriteAssignVars(body_block, target_node, target_var);
  1188. // create 'i = i + 1'
  1189. CNodePtr loop_var_inc =
  1190. body_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarAdd), loop_var, NewValueNode(1)});
  1191. body_block->WriteVariable(loop_var->name(), loop_var_inc);
  1192. // link the variable name with the target
  1193. auto it_info = std::make_shared<TraceIterator>(loop_var_inc->debug_info());
  1194. loop_var->debug_info()->set_trace_info(it_info);
  1195. len_iter->debug_info()->set_trace_info(it_info);
  1196. TraceManager::DebugTrace(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
  1197. FunctionBlockPtr after_block = MakeFunctionBlock(*this);
  1198. MS_EXCEPTION_IF_NULL(after_block);
  1199. TraceManager::EndTrace();
  1200. after_block->AddPrevBlock(header_block);
  1201. block->Jump(header_block, NewValueNode(0));
  1202. body_block->Mature();
  1203. header_block->ConditionalJump(cond_node, body_block, after_block, false);
  1204. // Parse loop body statements with loop context.
  1205. LoopContext loop_context{&loops_, header_block, loop_var_inc};
  1206. py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  1207. FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
  1208. if (after_body_block->func_graph()->get_return() == nullptr) {
  1209. after_body_block->Jump(header_block, loop_var_inc);
  1210. }
  1211. header_block->Mature();
  1212. after_block->Mature();
  1213. auto &end_block = loop_context.EndBlock();
  1214. if (end_block) {
  1215. // end_block exists if we encounter 'break' in loop body.
  1216. after_block->Jump(end_block, nullptr);
  1217. end_block->Mature();
  1218. return end_block;
  1219. }
  1220. // No 'break', no end_block.
  1221. return after_block;
  1222. }
  1223. AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) {
  1224. MS_LOG(DEBUG) << "Process ast IfExp";
  1225. MS_EXCEPTION_IF_NULL(block);
  1226. py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  1227. AnfNodePtr condition_node = ParseExprNode(block, test_node);
  1228. CNodePtr bool_node = block->ForceToBoolNode(condition_node);
  1229. TraceManager::DebugTrace(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
  1230. FunctionBlockPtr true_block = MakeFunctionBlock(*this);
  1231. TraceManager::EndTrace();
  1232. TraceManager::DebugTrace(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
  1233. FunctionBlockPtr false_block = MakeFunctionBlock(*this);
  1234. TraceManager::EndTrace();
  1235. MakeConditionBlocks(block, true_block, false_block);
  1236. // process the if-true branch
  1237. py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
  1238. true_block->func_graph()->debug_info()->set_location(GetLocation(bodyNode));
  1239. AnfNodePtr true_node = ParseExprNode(true_block, bodyNode);
  1240. // process the orelse branch
  1241. py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
  1242. false_block->func_graph()->debug_info()->set_location(GetLocation(orelseNode));
  1243. AnfNodePtr false_node = ParseExprNode(false_block, orelseNode);
  1244. true_block->func_graph()->set_output(true_node);
  1245. false_block->func_graph()->set_output(false_node);
  1246. // Use the Primitive replace the operation resolve node (switch)
  1247. // because the switch will eventually be converted to Primitive node
  1248. CNodePtr switch_app =
  1249. block->func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), bool_node, NewValueNode(true_block->func_graph()),
  1250. NewValueNode(false_block->func_graph())});
  1251. std::vector<AnfNodePtr> call_graph_nodes{switch_app};
  1252. CNodePtr switch_app_call = block->func_graph()->NewCNode(call_graph_nodes);
  1253. return switch_app_call;
  1254. }
  1255. void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) {
  1256. MS_EXCEPTION_IF_NULL(block);
  1257. MS_EXCEPTION_IF_NULL(assigned_node);
  1258. py::str name = python_adapter::GetPyObjAttr(targ, "id");
  1259. std::string name_id = name;
  1260. assigned_node->debug_info()->set_name(name_id);
  1261. // set the debug name of the constant graph
  1262. if (IsValueNode<FuncGraph>(assigned_node)) {
  1263. // the value should be graph
  1264. auto fg = GetValueNode<FuncGraphPtr>(assigned_node);
  1265. if (fg->debug_info()->name().empty()) {
  1266. fg->debug_info()->set_name(name_id);
  1267. }
  1268. }
  1269. block->WriteVariable(name_id, assigned_node);
  1270. }
  1271. void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) {
  1272. MS_EXCEPTION_IF_NULL(block);
  1273. AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  1274. py::list items = python_adapter::GetPyObjAttr(targ, "elts");
  1275. for (size_t i = 0; i < items.size(); i++) {
  1276. // Use the Primitive replace the operation resolve node (getitem)
  1277. // because the getitem will eventually be converted to Primitive node
  1278. CNodePtr item_apply = block->func_graph()->NewCNode({op_getitem, assigned_node, NewValueNode(static_cast<int>(i))});
  1279. py::object elt = items[i];
  1280. WriteAssignVars(block, elt, item_apply);
  1281. }
  1282. }
  1283. void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ,
  1284. const AnfNodePtr &assigned_node) {
  1285. // Now only support the self.xx = xxxxx, can't support x.y = xxxx
  1286. AnfNodePtr target_node = ParseExprNode(block, targ);
  1287. MS_EXCEPTION_IF_NULL(target_node);
  1288. std::string attr_name = targ.attr("attr").cast<std::string>();
  1289. std::string var_name = "self.";
  1290. (void)var_name.append(attr_name);
  1291. MS_LOG(DEBUG) << "assign " << var_name;
  1292. // Get targ location info for error printing
  1293. py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, targ);
  1294. if (location.size() < 2) {
  1295. MS_LOG(EXCEPTION) << "List size should not be less than 2.";
  1296. }
  1297. auto filename = location[0].cast<std::string>();
  1298. auto line_no = location[1].cast<int>();
  1299. // Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
  1300. if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
  1301. MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but not defined, at " << filename << ":"
  1302. << line_no;
  1303. }
  1304. auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
  1305. auto obj_type = obj.attr("__class__").attr("__name__");
  1306. if (!py::hasattr(obj, "__parameter__")) {
  1307. MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '"
  1308. << py::str(obj).cast<std::string>() << "' with type '"
  1309. << py::str(obj_type).cast<std::string>() << "' at " << filename << ":" << line_no;
  1310. }
  1311. MS_EXCEPTION_IF_NULL(block);
  1312. block->WriteVariable(var_name, assigned_node);
  1313. MS_LOG(DEBUG) << "SetState write " << var_name << " : " << target_node->ToString();
  1314. block->SetStateAssgin(target_node, var_name);
  1315. }
  1316. void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ,
  1317. const AnfNodePtr &assigned_node) {
  1318. MS_EXCEPTION_IF_NULL(block);
  1319. AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM);
  1320. py::object value_obj = python_adapter::GetPyObjAttr(targ, "value");
  1321. py::object slice_obj = python_adapter::GetPyObjAttr(targ, "slice");
  1322. AnfNodePtr value_node = ParseExprNode(block, value_obj);
  1323. AnfNodePtr slice_node = ParseExprNode(block, slice_obj);
  1324. CNodePtr setitem_app = block->func_graph()->NewCNode({op_setitem, value_node, slice_node, assigned_node});
  1325. // getitem apply should return the sequence data structure itself
  1326. std::string var_name = "";
  1327. if (ast_->IsClassMember(value_obj)) {
  1328. std::string attr_name = value_obj.attr("attr").cast<std::string>();
  1329. var_name = "self." + attr_name;
  1330. if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
  1331. MS_EXCEPTION(TypeError) << "'" << var_name << "' was not defined in the class '__init__' function.";
  1332. }
  1333. auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
  1334. auto obj_type = obj.attr("__class__").attr("__name__");
  1335. if (!py::hasattr(obj, "__parameter__")) {
  1336. MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '"
  1337. << py::str(obj).cast<std::string>() << "' with type '"
  1338. << py::str(obj_type).cast<std::string>() << "'.";
  1339. }
  1340. } else {
  1341. var_name = value_obj.attr("id").cast<std::string>();
  1342. }
  1343. block->WriteVariable(var_name, setitem_app);
  1344. }
  1345. void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node) {
  1346. MS_EXCEPTION_IF_NULL(value_node);
  1347. MS_LOG(DEBUG) << "Process WriteAssignVars";
  1348. auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, targ)));
  1349. if (ast_type == AST_SUB_TYPE_NAME) {
  1350. HandleAssignName(block, targ, value_node);
  1351. } else if (ast_type == AST_SUB_TYPE_TUPLE) {
  1352. HandleAssignTuple(block, targ, value_node);
  1353. } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
  1354. HandleAssignSubscript(block, targ, value_node);
  1355. } else if (ast_->IsClassMember(targ)) {
  1356. HandleAssignClassMember(block, targ, value_node);
  1357. } else {
  1358. MS_LOG(EXCEPTION) << "Not supported assign type: " << ast_type
  1359. << " NodeInfo: " << trace::GetDebugInfo(value_node->debug_info());
  1360. }
  1361. }
  1362. // process a assign statement, such as a =b, a,b = tup
  1363. FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) {
  1364. MS_LOG(DEBUG) << "Process ast assgin";
  1365. py::object value_object = python_adapter::GetPyObjAttr(node, "value");
  1366. AnfNodePtr value_node = ParseExprNode(block, value_object);
  1367. py::object targets_object = python_adapter::GetPyObjAttr(node, "targets");
  1368. py::int_ pcount = python_adapter::CallPyObjMethod(targets_object, "__len__");
  1369. size_t count = IntToSize(pcount);
  1370. MS_LOG(DEBUG) << "The nodes count is " << count;
  1371. for (size_t i = 0; i < count; i++) {
  1372. auto target_node = py::cast<py::list>(targets_object)[i];
  1373. WriteAssignVars(block, target_node, value_node);
  1374. }
  1375. return block;
  1376. }
  1377. FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::object &node) {
  1378. if (loops_.empty()) {
  1379. // Report error if loop context not set for the 'break' statement.
  1380. py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  1381. if (location.size() < 2) {
  1382. MS_LOG(EXCEPTION) << "List size should not be less than 2.";
  1383. }
  1384. auto filename = location[0].cast<std::string>();
  1385. auto line_no = location[1].cast<int>();
  1386. MS_LOG(EXCEPTION) << "Unexpected 'break' at " << filename << ":" << line_no;
  1387. }
  1388. // Get current loop.
  1389. Loop &loop = loops_.top();
  1390. if (loop.end == nullptr) {
  1391. // Create end_block if it is not existed.
  1392. TraceManager::DebugTrace(std::make_shared<TraceLoopEnd>(block->func_graph()->debug_info()));
  1393. loop.end = MakeFunctionBlock(*this);
  1394. TraceManager::EndTrace();
  1395. }
  1396. // Jump to the end_block.
  1397. block->Jump(loop.end, nullptr);
  1398. return block;
  1399. }
  1400. FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::object &node) {
  1401. if (loops_.empty()) {
  1402. // Report error if loop context not set for the 'continue' statement.
  1403. py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
  1404. if (location.size() < 2) {
  1405. MS_LOG(EXCEPTION) << "List size should not be less than 2.";
  1406. }
  1407. auto filename = location[0].cast<std::string>();
  1408. auto line_no = location[1].cast<int>();
  1409. MS_LOG(EXCEPTION) << "Unexpected 'continue' at " << filename << ":" << line_no;
  1410. }
  1411. // Jump to the header of the loop with iterator called.
  1412. Loop &loop = loops_.top();
  1413. block->Jump(loop.header, loop.iterator);
  1414. return block;
  1415. }
  1416. FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::object &node) {
  1417. // We just bypass 'pass' statement.
  1418. return block;
  1419. }
  1420. AnfNodePtr FindPhis(const std::unordered_map<ParameterPtr, AnfNodePtr> &removable_phis, const AnfNodePtr &node) {
  1421. const auto &inp = node->cast<ParameterPtr>();
  1422. const auto &iter = removable_phis.find(inp);
  1423. if (iter == removable_phis.end()) {
  1424. return node;
  1425. }
  1426. return FindPhis(removable_phis, iter->second);
  1427. }
  1428. void Parser::RemoveUnnecessaryPhis() {
  1429. // merge all removable phis to one map;
  1430. std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis;
  1431. std::vector<ParameterPtr> phis;
  1432. for (FunctionBlockPtr &block : func_block_list_) {
  1433. MS_EXCEPTION_IF_NULL(block);
  1434. removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end());
  1435. std::transform(block->removable_phis().begin(), block->removable_phis().end(), std::back_inserter(phis),
  1436. [](std::pair<ParameterPtr, AnfNodePtr> pair) { return pair.first; });
  1437. }
  1438. if (removable_phis.size() == 0) {
  1439. return;
  1440. }
  1441. auto fg_name = func_graph_->ToString();
  1442. auto mng = Manage(func_graph_, false);
  1443. // replace the nodes
  1444. // remove from inside to outside
  1445. for (int idx = SizeToInt(phis.size() - 1); idx >= 0; idx--) {
  1446. auto phi = phis[IntToSize(idx)];
  1447. auto new_node = FindPhis(removable_phis, phi);
  1448. MS_LOG(DEBUG) << "phi " << phi->DebugString() << " to " << new_node->DebugString();
  1449. mng->Replace(phi, new_node);
  1450. }
  1451. // remove the parameter
  1452. for (FunctionBlockPtr &block : func_block_list_) {
  1453. MS_EXCEPTION_IF_NULL(block);
  1454. auto &local_removable_phis = block->removable_phis();
  1455. if (local_removable_phis.size() == 0) {
  1456. continue;
  1457. }
  1458. auto func_graph = block->func_graph();
  1459. auto &parameters = func_graph->parameters();
  1460. std::vector<AnfNodePtr> new_parameters(parameters.size());
  1461. auto it = std::copy_if(
  1462. parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](AnfNodePtr param) {
  1463. return local_removable_phis.find(param->cast<ParameterPtr>()) == local_removable_phis.end();
  1464. });
  1465. // shrink container to new size
  1466. new_parameters.resize(std::distance(new_parameters.begin(), it));
  1467. func_graph->set_parameters(new_parameters);
  1468. }
  1469. for (auto fg : mng->func_graphs()) {
  1470. fg->ClearAllManagerInfo();
  1471. }
  1472. }
  1473. // ParseAst class code
  1474. bool ParseAst::InitParseAstInfo(const std::string &python_mod_get_parse_method) {
  1475. // init the type
  1476. target_type_ = PARSE_TARGET_UNKNOW;
  1477. // call python parse, get the parser fn
  1478. module_ = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
  1479. py::object parse_method = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_PARSE_METHOD);
  1480. // get the obj type
  1481. auto type = data_converter::GetObjType(obj_);
  1482. if (type == RESOLVE_TYPE_FUNCTION) {
  1483. target_type_ = PARSE_TARGET_FUNCTION;
  1484. function_ = obj_;
  1485. } else if (type == RESOLVE_TYPE_METHOD) {
  1486. // process the method ,need get the method's self obj
  1487. target_type_ = PARSE_TARGET_METHOD;
  1488. py::object method_object = python_adapter::GetPyObjAttr(obj_, PYTHON_GET_METHOD_SELF_CLASS);
  1489. if (py::isinstance<py::none>(method_object)) {
  1490. MS_LOG(ERROR) << "Get method's self object instance failed.";
  1491. return false;
  1492. }
  1493. target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
  1494. function_ = obj_;
  1495. obj_ = method_object;
  1496. } else if (type == RESOLVE_TYPE_CLASS_INSTANCE) {
  1497. // obj is class instance, get the method to parse.
  1498. function_ = python_adapter::CallPyModFn(module_, python_mod_get_parse_method, obj_, parse_method);
  1499. if (py::isinstance<py::none>(function_)) {
  1500. MS_LOG(ERROR) << "Get obj method function failed.";
  1501. return false;
  1502. }
  1503. target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
  1504. // check the fn is method
  1505. auto obj_type = data_converter::GetObjType(function_);
  1506. if (obj_type != RESOLVE_TYPE_METHOD) {
  1507. MS_LOG(WARNING) << "Parse method function is invalid.";
  1508. return false;
  1509. }
  1510. } else {
  1511. MS_LOG(WARNING) << "Parse obj is invalid, only can parse function and obj, type = " << type;
  1512. return false;
  1513. }
  1514. // call python parse get ast tree
  1515. parser_ = python_adapter::CallPyModFn(module_, PYTHON_MOD_PARSE_OBJECT_FUNCTION, function_, parse_method);
  1516. ast_tree_ = python_adapter::CallPyObjMethod(parser_, "parse");
  1517. // get fn name and module
  1518. function_module_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "function_module"));
  1519. function_name_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "function_name"));
  1520. function_filename_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "filename"));
  1521. function_line_offset_ = py::cast<int>(python_adapter::GetPyObjAttr(parser_, "line_offset"));
  1522. return true;
  1523. }
  1524. // Get ast tree node : is the tree bode list[0]
  1525. py::object ParseAst::GetAstNode() {
  1526. py::list tree_body = python_adapter::GetPyObjAttr(ast_tree_, "body");
  1527. py::object ast_node = tree_body[0];
  1528. return ast_node;
  1529. }
  1530. py::list ParseAst::GetArgs(const py::object &func_node) {
  1531. py::list ret = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_ARGS, func_node);
  1532. return ret;
  1533. }
  1534. py::list ParseAst::GetArgsDefaultValues(const py::object &func_node) {
  1535. py::list ret = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES, func_node);
  1536. return ret;
  1537. }
  1538. AstNodeTypePtr ParseAst::GetNodeType(const py::object &node) {
  1539. py::list list_value = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_NODE_TYPE, node);
  1540. if (list_value.size() < 2) {
  1541. MS_LOG(ERROR) << "The node of python method must has 2 values.";
  1542. return nullptr;
  1543. }
  1544. auto node_name = py::cast<std::string>(list_value[0]);
  1545. auto type = AstMainType(py::cast<int32_t>(list_value[1]));
  1546. return std::make_shared<AstNodeType>(node, node_name, type);
  1547. }
  1548. AstSubType ParseAst::GetOpType(const py::object &node) {
  1549. auto op_type = AstSubType(python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_AST_TYPE, node).cast<int32_t>());
  1550. return op_type;
  1551. }
  1552. bool ParseAst::IsClassMember(const py::object &node) {
  1553. py::object ret = CallParseModFunction(PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER, node);
  1554. if (!py::isinstance<py::bool_>(ret)) {
  1555. MS_LOG(ERROR) << "The result of mod function parse, should be bool type.";
  1556. return false;
  1557. }
  1558. return ret.cast<bool>();
  1559. }
  1560. bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
  1561. if (func_graph == nullptr) {
  1562. MS_LOG(ERROR) << "FuncGraph is null";
  1563. return false;
  1564. }
  1565. if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) {
  1566. MS_LOG(DEBUG) << "No flags";
  1567. return true;
  1568. }
  1569. py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG);
  1570. for (auto &item : flags) {
  1571. if (!py::isinstance<py::str>(item.first)) {
  1572. MS_LOG(ERROR) << "Type error in flags dict convert";
  1573. return false;
  1574. }
  1575. auto name = py::cast<std::string>(item.first);
  1576. if (py::isinstance<py::bool_>(item.second)) {
  1577. auto value = py::cast<bool>(item.second);
  1578. MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
  1579. func_graph->set_flag(name, value);
  1580. } else if (py::isinstance<py::str>(item.second)) {
  1581. auto value = py::cast<std::string>(item.second);
  1582. MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
  1583. func_graph->set_attr(name, MakeValue(value));
  1584. } else {
  1585. MS_LOG(ERROR) << "Type error in flags/attrs dict convert";
  1586. return false;
  1587. }
  1588. }
  1589. return true;
  1590. }
  1591. // Generate and copy a ValueNode, or a CNode with its child nodes
  1592. static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr func_graph, const AnfNodePtr &param_node) {
  1593. MS_EXCEPTION_IF_NULL(param_node);
  1594. if (param_node->isa<ValueNode>()) {
  1595. return std::make_shared<ValueNode>(param_node->cast<ValueNodePtr>()->value());
  1596. }
  1597. // Parameter default value is CNode.
  1598. std::size_t index = 0;
  1599. std::vector<AnfNodePtr> old_cnodes;
  1600. old_cnodes.emplace_back(param_node);
  1601. auto res = func_graph->NewCNode({});
  1602. std::vector<CNodePtr> new_cnodes;
  1603. new_cnodes.emplace_back(res);
  1604. while (index < old_cnodes.size()) {
  1605. auto current = old_cnodes[index];
  1606. auto current_new_cnode = new_cnodes[index];
  1607. index++;
  1608. MS_EXCEPTION_IF_NULL(current);
  1609. if (current->isa<CNode>()) {
  1610. auto &inputs = current->cast<CNodePtr>()->inputs();
  1611. for (auto it = inputs.begin(); it != inputs.end(); it++) {
  1612. AnfNodePtr input = *it;
  1613. if (input != nullptr && input->isa<CNode>()) {
  1614. old_cnodes.emplace_back(input);
  1615. auto new_cnode = func_graph->NewCNode({});
  1616. new_cnodes.emplace_back(new_cnode);
  1617. current_new_cnode->add_input(new_cnode);
  1618. } else if (input->isa<ValueNode>()) {
  1619. current_new_cnode->add_input(std::make_shared<ValueNode>(input->cast<ValueNodePtr>()->value()));
  1620. } else {
  1621. MS_LOG(EXCEPTION) << "Wrong type item in default parameters: " << input->ToString();
  1622. }
  1623. }
  1624. }
  1625. }
  1626. return res;
  1627. }
  1628. FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
  1629. auto current_graph = dyn_cast<FuncGraph>(cell_ptr);
  1630. if (current_graph == nullptr) {
  1631. MS_LOG(EXCEPTION) << "Current graph cast failed from " << cell_ptr->ToString();
  1632. }
  1633. auto func_graph = std::make_shared<FuncGraph>();
  1634. func_graph->debug_info()->set_name(current_graph->debug_info()->name() + "_wrapper");
  1635. // Copy all parameters information
  1636. for (auto &para : current_graph->parameters()) {
  1637. auto param = func_graph->add_parameter();
  1638. auto orig_param = para->cast<ParameterPtr>();
  1639. auto name = orig_param->name();
  1640. param->set_name(name);
  1641. param->debug_info()->set_name(name);
  1642. }
  1643. func_graph->set_has_vararg(current_graph->has_vararg());
  1644. func_graph->set_has_kwarg(current_graph->has_kwarg());
  1645. func_graph->set_kwonlyargs_count(current_graph->kwonlyargs_count());
  1646. // Copy all default values
  1647. for (auto &d : current_graph->parameter_default_value()) {
  1648. func_graph->set_param_default_value(d.first, CopyNodesFromParamDefaultValue(func_graph, d.second));
  1649. }
  1650. // cell_obj
  1651. MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell));
  1652. parse::UpdateFuncGraphFlags(cell, func_graph);
  1653. // top graph's construct flag
  1654. if (py::hasattr(cell, "construct")) {
  1655. parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph);
  1656. }
  1657. auto unpacking = func_graph->has_vararg() || func_graph->has_kwarg();
  1658. if (!unpacking) {
  1659. std::vector<AnfNodePtr> inputs;
  1660. inputs.emplace_back(NewValueNode(cell_ptr));
  1661. auto &params = func_graph->parameters();
  1662. (void)std::transform(params.begin(), params.end(), std::back_inserter(inputs),
  1663. [](AnfNodePtr node) -> AnfNodePtr { return node; });
  1664. func_graph->set_output(func_graph->NewCNode(inputs));
  1665. } else {
  1666. // ret = cell_obj(*arg, *kwargs)
  1667. auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), func_graph->parameters());
  1668. // return ret
  1669. func_graph->set_output(call_fn);
  1670. }
  1671. return func_graph;
  1672. }
  1673. } // namespace parse
  1674. } // namespace mindspore