You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

anf_ir_utils.cc 73 kB

5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "debug/anf_ir_utils.h"
  17. #include <fstream>
  18. #include <map>
  19. #include <memory>
  20. #include <unordered_map>
  21. #include <algorithm>
  22. #include <iomanip>
  23. #include "ir/graph_utils.h"
  24. #include "utils/symbolic.h"
  25. #include "ir/meta_func_graph.h"
  26. #include "ir/param_info.h"
  27. #include "pybind_api/ir/tensor_py.h"
  28. #include "pipeline/jit/parse/python_adapter.h"
  29. #include "pipeline/jit/parse/resolve.h"
  30. #include "frontend/operator/composite/composite.h"
  31. #include "frontend/operator/composite/map.h"
  32. #include "utils/ordered_map.h"
  33. #include "utils/ordered_set.h"
  34. #include "utils/utils.h"
  35. #include "utils/shape_utils.h"
  36. #include "debug/trace.h"
  37. #include "utils/label.h"
  38. #include "utils/ms_context.h"
  39. #include "frontend/operator/ops.h"
  40. #include "pipeline/jit/base.h"
  41. using mindspore::tensor::TensorPy;
  42. namespace mindspore {
  43. // max number of elements in sequence
  44. const int NUM_MAX_SEQUENCE_ELEMS = 0x00FFFFFF;
  45. // ============================================== MindSpore IR Common ==============================================
  46. // get MindSpore Intermediate Representation Path
  47. std::string GetMsIrPath(void) {
  48. std::string path;
  49. const char *path_ptr = getenv("MS_IR_PATH");
  50. if (path_ptr != nullptr) {
  51. path = path_ptr;
  52. char real_path[PATH_MAX] = {0};
  53. #if defined(_WIN32) || defined(_WIN64)
  54. if (path.size() > PATH_MAX || _fullpath(real_path, path.c_str(), PATH_MAX) == nullptr) {
  55. MS_LOG(EXCEPTION) << "MS IR Path error, " << path_ptr;
  56. }
  57. #else
  58. if (path.size() > PATH_MAX || nullptr == realpath(path.c_str(), real_path)) {
  59. MS_LOG(EXCEPTION) << "MS IR path error, " << path_ptr;
  60. }
  61. #endif
  62. path = real_path;
  63. }
  64. return path;
  65. }
  66. std::string dump_obj(const py::object &obj, const std::string &path) {
  67. py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
  68. py::object name = parse::python_adapter::CallPyModFn(mod, "dump_obj", obj, py::str(path));
  69. return py::str(name);
  70. }
  71. py::object load_obj(const std::string &path) {
  72. py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
  73. py::object obj = parse::python_adapter::CallPyModFn(mod, "load_obj", py::str(path));
  74. return obj;
  75. }
  76. // ============================================= MindSpore IR Exporter =============================================
  77. std::string AnfExporter::GetNodeType(const AnfNodePtr &nd) {
  78. abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape());
  79. TypePtr type = dyn_cast<Type>(nd->Type());
  80. std::ostringstream oss;
  81. if ((nullptr != shape) && (nullptr != type)) {
  82. oss << type->DumpText() << shape->DumpText();
  83. } else if (nullptr != type) {
  84. oss << type->DumpText();
  85. } else {
  86. oss << "Undefined";
  87. }
  88. return oss.str();
  89. }
  90. std::string AnfExporter::DumpObject(const py::object &obj, const std::string &category) const {
  91. std::string pkl_path = GetMsIrPath();
  92. // if not specified env 'MS_IR_PATH', do not create any files
  93. if (pkl_path.empty() || (getenv("MS_IR_FILE") != nullptr)) {
  94. return "null";
  95. }
  96. std::string file_prefix = id_ + "." + category;
  97. std::string file_name = dump_obj(obj, pkl_path + "/" + file_prefix);
  98. return file_prefix + file_name;
  99. }
  100. int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &param, bool throw_excp) {
  101. if (func_graph == nullptr || param == nullptr) {
  102. return -1;
  103. }
  104. FuncGraphPtr fg = func_graph;
  105. while (fg != nullptr) {
  106. if (exported.find(fg) == exported.end()) {
  107. if (!check_integrity_) {
  108. break;
  109. }
  110. MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "." << fg->debug_info()->get_id() << "'";
  111. }
  112. auto param_map = exported[fg];
  113. if (param_map.find(param) != param_map.end()) {
  114. return param_map[param];
  115. }
  116. fg = fg->parent();
  117. }
  118. if (throw_excp) {
  119. MS_LOG(EXCEPTION) << "Can not find index for param '" << param->DumpText() << "' for func graph '"
  120. << func_graph->DumpText() << "." << func_graph->debug_info()->get_id() << "'";
  121. }
  122. return -1;
  123. }
  124. // try to find index of parameter for SymbolicKeyInstance from all exported graphs
  125. // NOTICE: Suppose name of all parameters in SymbolicKeyInstance are different
  126. int AnfExporter::GetParamIndexFromExported(const AnfNodePtr &param) {
  127. if (param == nullptr) {
  128. return -1;
  129. }
  130. int ret = -1;
  131. for (const auto &item : exported) {
  132. auto pram_iter = item.second.find(param);
  133. if (pram_iter != item.second.end()) {
  134. return pram_iter->second;
  135. }
  136. }
  137. return ret;
  138. }
  139. std::string AnfExporter::GetValueNodeText(const FuncGraphPtr &fg, const ValueNodePtr &node) {
  140. MS_EXCEPTION_IF_NULL(node);
  141. return GetValueText(fg, node->value());
  142. }
  143. std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph) {
  144. auto py_funcs = mt_func_graph->GetPyFunctions();
  145. if (py_funcs.empty()) {
  146. return "";
  147. }
  148. std::ostringstream oss;
  149. oss << "{";
  150. bool is_first = true;
  151. for (const auto &py_func : py_funcs) {
  152. if (is_first) {
  153. is_first = false;
  154. } else {
  155. oss << ", ";
  156. }
  157. oss << "(";
  158. for (size_t i = 0; i < py_func.first.size(); ++i) {
  159. if (i > 0) {
  160. oss << ", ";
  161. }
  162. oss << py_func.first[i]->DumpText();
  163. }
  164. oss << ")";
  165. // dump Python Function object
  166. oss << "@" << DumpObject(py_func.second, "F");
  167. }
  168. oss << "}";
  169. return oss.str();
  170. }
  171. /* inherit relation of MetaFuncGraph
  172. *
  173. * MetaGraph
  174. * ├── MultitypeGraph
  175. * ├── HyperMap
  176. * │ └── HyperMapPy
  177. * ├── Map
  178. * │ └── MapPy
  179. * ├── Tail
  180. * ├── MakeTupleGradient
  181. * ├── MakeListGradient
  182. * ├── GradOperation
  183. * └── TupleAdd
  184. */
  185. std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph) {
  186. if (meta_func_graph == nullptr) {
  187. return "";
  188. }
  189. std::ostringstream oss;
  190. oss << meta_func_graph->type_name() << "::" << meta_func_graph->name();
  191. if (meta_func_graph->isa<prim::MultitypeFuncGraph>()) {
  192. prim::MultitypeFuncGraphPtr mt_func_graph = meta_func_graph->cast<prim::MultitypeFuncGraphPtr>();
  193. oss << GetMultitypeFuncGraphText(mt_func_graph);
  194. } else if (meta_func_graph
  195. ->isa<prim::HyperMapPy>()) { // this statement must before 'meta_graph->isa<prim::HyperMap>()'
  196. auto hyper_map = meta_func_graph->cast<prim::HyperMapPyPtr>();
  197. if (hyper_map->GetFnLeaf() != nullptr) {
  198. oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}";
  199. }
  200. } else if (meta_func_graph->isa<prim::HyperMap>()) {
  201. auto hyper_map = meta_func_graph->cast<prim::HyperMapPtr>();
  202. if (hyper_map->GetFnLeaf() != nullptr) {
  203. oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}";
  204. }
  205. } else if (meta_func_graph->isa<prim::MapPy>()) { // this statement must before 'meta_graph->isa<prim::Map>()'
  206. auto map = meta_func_graph->cast<prim::MapPyPtr>();
  207. if (map->GetFnLeaf() != nullptr) {
  208. oss << "{fn_leaf=" << GetMetaFuncGraphText(map->GetFnLeaf()) << "}";
  209. }
  210. } else if (meta_func_graph->isa<prim::Map>()) {
  211. auto map = meta_func_graph->cast<prim::MapPtr>();
  212. if (map->GetFnLeaf() != nullptr) {
  213. oss << "{fn_leaf=" << GetMetaFuncGraphText(map->GetFnLeaf()) << "}";
  214. }
  215. } else if (meta_func_graph->isa<prim::GradOperation>()) {
  216. prim::GradOperationPtr grad_op = meta_func_graph->cast<prim::GradOperationPtr>();
  217. oss << "{get_all=" << grad_op->get_all_ << ", get_by_list=" << grad_op->get_by_list_
  218. << ", sens_param=" << grad_op->sens_param_ << "}";
  219. } else if (meta_func_graph->isa<prim::Tail>()) {
  220. // do nothing
  221. } else if (meta_func_graph->isa<prim::MakeTupleGradient>()) {
  222. // do nothing
  223. } else if (meta_func_graph->isa<prim::MakeListGradient>()) {
  224. // do nothing
  225. } else if (meta_func_graph->isa<prim::TupleAdd>()) {
  226. // do nothing
  227. } else if (meta_func_graph->isa<prim::TupleSlice>()) {
  228. // do nothing
  229. } else if (meta_func_graph->isa<prim::UnpackCall>()) {
  230. // do nothing
  231. } else if (meta_func_graph->isa<prim::ZipOperation>()) {
  232. // do nothing
  233. } else if (meta_func_graph->isa<prim::ListAppend>()) {
  234. // do nothing
  235. } else if (meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>()) {
  236. // do nothing
  237. } else {
  238. MS_LOG(EXCEPTION) << "Unknown MetaFuncGraph type " << meta_func_graph->type_name();
  239. }
  240. return oss.str();
  241. }
  242. std::string AnfExporter::GetPrimitiveText(const PrimitivePtr &prim) {
  243. std::ostringstream oss;
  244. if (prim == nullptr) {
  245. return oss.str();
  246. }
  247. oss << prim->type_name() << "::" << prim->name();
  248. // need to serialize internal python function of PrimitivePy and record its prim_type
  249. if (prim->isa<PrimitivePy>()) {
  250. PrimitivePyPtr primpy = prim->cast<PrimitivePyPtr>();
  251. // dump related function in PrimitivePy
  252. oss << "@" << DumpObject(primpy->GetPyObj(), "P");
  253. // output primitive type
  254. oss << "{prim_type=" << static_cast<int>(prim->prim_type()) << "}";
  255. }
  256. // output primitive attributes
  257. oss << prim->GetAttrsText();
  258. if (prim->isa<prim::DoSignaturePrimitive>()) {
  259. auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim);
  260. auto &func = do_signature->function();
  261. if (func->isa<Primitive>()) {
  262. auto sig_prim = dyn_cast<Primitive>(func);
  263. oss << sig_prim->GetAttrsText();
  264. }
  265. }
  266. return oss.str();
  267. }
  268. std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr &ns) {
  269. std::ostringstream oss;
  270. if (ns == nullptr) {
  271. return oss.str();
  272. }
  273. // dump related module information in Namespace
  274. oss << ns->type_name() << "::" << ns->module() << "@" << DumpObject(ns->obj(), "N");
  275. return oss.str();
  276. }
  277. std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph,
  278. const SymbolicKeyInstancePtr &sym_inst) {
  279. MS_EXCEPTION_IF_NULL(func_graph);
  280. MS_EXCEPTION_IF_NULL(sym_inst);
  281. AnfNodePtr sym_node = sym_inst->node();
  282. MS_EXCEPTION_IF_NULL(sym_node);
  283. std::ostringstream oss;
  284. if (sym_node->isa<Parameter>()) {
  285. int idx = GetParamIndex(func_graph, sym_node, false);
  286. // if can not find SymbolicKeyInstance related parameter from ancestors,
  287. // try to find from all exported graphs
  288. if (idx < 0) {
  289. idx = GetParamIndexFromExported(sym_node);
  290. }
  291. if (idx < 0) {
  292. ParameterPtr p = dyn_cast<Parameter>(sym_node);
  293. if (p == nullptr) {
  294. MS_LOG(EXCEPTION) << "Sym_inst's node could not cast to parameter";
  295. }
  296. MS_LOG(WARNING) << "Can not find SymbolicKeyInstance: " << p->name();
  297. }
  298. oss << "SymInst(%para" << idx << ")";
  299. } else {
  300. MS_LOG(WARNING) << "SymbolicKeyInstance does not embed a parameter: " << sym_node->ToString();
  301. oss << "SymInst(cnode_" << sym_node->ToString() << ")";
  302. }
  303. return oss.str();
  304. }
  305. std::string AnfExporter::GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
  306. std::ostringstream oss;
  307. // output ValueList, ValueTuple
  308. ValueSequeuePtr seq = dyn_cast<ValueSequeue>(value);
  309. MS_EXCEPTION_IF_NULL(seq);
  310. MS_EXCEPTION_IF_NULL(value);
  311. bool is_tuple = value->isa<ValueTuple>();
  312. oss << (is_tuple ? "(" : "[");
  313. bool first_flag = true;
  314. for (auto elem : seq->value()) {
  315. if (first_flag) {
  316. first_flag = false;
  317. } else {
  318. oss << ", ";
  319. }
  320. oss << GetValueText(func_graph, elem);
  321. }
  322. oss << (is_tuple ? ")" : "]");
  323. return oss.str();
  324. }
  325. std::string AnfExporter::GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
  326. std::ostringstream oss;
  327. ValueDictionaryPtr dict = value->cast<ValueDictionaryPtr>();
  328. oss << "{";
  329. bool first_flag = true;
  330. for (const auto &elem : dict->value()) {
  331. if (first_flag) {
  332. first_flag = false;
  333. } else {
  334. oss << ", ";
  335. }
  336. oss << "\"" << elem.first << "\": " << GetValueText(func_graph, elem.second);
  337. }
  338. oss << "}";
  339. return oss.str();
  340. }
  341. std::string AnfExporter::GetOtherValueText(const FuncGraphPtr &, const ValuePtr &value) {
  342. std::ostringstream oss;
  343. if (check_integrity_) {
  344. MS_LOG(EXCEPTION) << "Need to process type: " << value->type_name() << ", dump text: " << value->DumpText();
  345. }
  346. oss << value->type_name() << "[" << value->DumpText() << "]";
  347. return oss.str();
  348. }
  349. std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
  350. std::ostringstream oss;
  351. bool is_null_ptr = (func_graph == nullptr || value == nullptr);
  352. if (is_null_ptr) {
  353. return oss.str();
  354. }
  355. if (value->isa<Primitive>()) {
  356. oss << GetPrimitiveText(value->cast<PrimitivePtr>());
  357. } else if (value->isa<MetaFuncGraph>()) {
  358. MetaFuncGraphPtr meta_func_graph = value->cast<MetaFuncGraphPtr>();
  359. oss << GetMetaFuncGraphText(meta_func_graph);
  360. } else if (value->isa<SymbolicKeyInstance>()) {
  361. oss << GetSymbolicKeyInstanceText(func_graph, value->cast<SymbolicKeyInstancePtr>());
  362. } else if (value->isa<RefKey>()) {
  363. oss << value->DumpText();
  364. } else if (value->isa<Scalar>() || value->isa<StringImm>()) {
  365. oss << value->DumpText();
  366. } else if (value->isa<tensor::Tensor>()) {
  367. auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
  368. oss << value->DumpText() << "@" << DumpObject(TensorPy::AsNumpy(*tensor_ptr), "T");
  369. } else if (value->isa<parse::Symbol>() || value->isa<None>() || value->isa<Null>()) {
  370. oss << value->DumpText();
  371. } else if (value->isa<ValueSequeue>()) {
  372. oss << GetSequenceText(func_graph, value);
  373. } else if (value->isa<ValueDictionary>()) {
  374. oss << GetDictText(func_graph, value);
  375. } else if (value->isa<ValueSlice>()) {
  376. ValueSlicePtr slice = value->cast<ValueSlicePtr>();
  377. oss << slice->DumpText();
  378. } else if (value->isa<Type>()) {
  379. oss << value->DumpText();
  380. } else if (value->isa<parse::NameSpace>()) {
  381. oss << GetNameSpaceText(value->cast<parse::NameSpacePtr>());
  382. } else if (value->isa<parse::PyObjectWrapper>()) {
  383. oss << value->type_name();
  384. } else if (value->isa<KeywordArg>()) {
  385. KeywordArgPtr keyword_arg = value->cast<KeywordArgPtr>();
  386. oss << keyword_arg->DumpText();
  387. } else {
  388. return GetOtherValueText(func_graph, value);
  389. }
  390. return oss.str();
  391. }
  392. // this function is used to output node in CNode's inputs
  393. std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  394. const std::map<AnfNodePtr, int> &apply_map) {
  395. std::ostringstream oss;
  396. if (func_graph == nullptr || node == nullptr) {
  397. return oss.str();
  398. }
  399. if (node->isa<CNode>()) {
  400. auto iter = apply_map.find(node);
  401. if (iter == apply_map.end()) {
  402. MS_LOG(EXCEPTION) << "Can not find node '" << node->DumpText() << "' in apply_map";
  403. }
  404. oss << "%" << iter->second;
  405. } else if (node->isa<Parameter>()) {
  406. // Parameter maybe a free variable, so check it in its own funcgraph.
  407. oss << "%para" << GetParamIndex(node->func_graph(), node, check_integrity_);
  408. } else if (IsValueNode<FuncGraph>(node)) {
  409. FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
  410. oss << fg->type_name() << "::fg_" << fg->debug_info()->get_id();
  411. if (!func_graph_set.contains(fg) && exported.find(fg) == exported.end() && export_used_) {
  412. func_graph_set.add(fg);
  413. }
  414. } else if (node->isa<ValueNode>()) {
  415. oss << GetValueNodeText(func_graph, node->cast<ValueNodePtr>());
  416. } else {
  417. MS_LOG(EXCEPTION) << "Unknown node '" << node->DumpText() << "'";
  418. }
  419. return oss.str();
  420. }
  421. void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> &parameters,
  422. OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map) {
  423. bool first_flag = true;
  424. for (const AnfNodePtr &param : parameters) {
  425. if (first_flag) {
  426. first_flag = false;
  427. ofs << " ";
  428. } else {
  429. ofs << " , ";
  430. }
  431. (*param_map)[param] = param_index;
  432. std::string type_info = GetNodeType(param);
  433. // output parameter and type
  434. if (type_info == "Undefined") {
  435. ofs << "%para" << param_index;
  436. } else {
  437. ofs << "%para" << param_index << " : " << type_info;
  438. }
  439. // dump Default value of parameter if exists
  440. const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
  441. if (param_ptr == nullptr) {
  442. MS_LOG(EXCEPTION) << "Param could not cast to parameter";
  443. }
  444. if (param_ptr->has_default()) {
  445. auto param_value = param_ptr->default_param();
  446. ofs << " = @" << DumpObject(py::cast(param_value), "D");
  447. }
  448. // output comment
  449. ofs << " # " << param->DumpText() << "\n";
  450. param_index += 1;
  451. }
  452. }
  453. void AnfExporter::OutputStatementComment(std::ofstream &ofs, const CNodePtr &node) {
  454. if (node == nullptr) {
  455. return;
  456. }
  457. // output type of each input argument
  458. auto &inputs = node->inputs();
  459. if (inputs.size() > 1) {
  460. ofs << " #(";
  461. for (size_t i = 1; i < inputs.size(); ++i) {
  462. if (i != 1) {
  463. ofs << ", ";
  464. }
  465. AnfNodePtr arg = inputs[i];
  466. ofs << GetNodeType(arg);
  467. }
  468. ofs << ")";
  469. }
  470. // output other comment, map the graph name to original representation(containing unicode character)
  471. std::ostringstream comment;
  472. comment << " #";
  473. bool has_comment = false;
  474. for (size_t i = 0; i < inputs.size(); ++i) {
  475. AnfNodePtr arg = inputs[i];
  476. if (!IsValueNode<FuncGraph>(arg)) {
  477. continue;
  478. }
  479. if (!has_comment) {
  480. has_comment = true;
  481. } else {
  482. comment << ",";
  483. }
  484. FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(arg);
  485. std::string func_graph_id = fg->debug_info()->get_id();
  486. comment << " fg_" << func_graph_id << "=" << fg->ToString() << "." << func_graph_id;
  487. }
  488. if (has_comment) {
  489. ofs << comment.str();
  490. }
  491. ofs << " #scope: " << node->scope()->name();
  492. }
  493. void AnfExporter::OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes,
  494. const FuncGraphPtr &func_graph) {
  495. if (func_graph == nullptr) {
  496. return;
  497. }
  498. int idx = 1;
  499. std::map<AnfNodePtr, int> apply_map;
  500. for (const AnfNodePtr &node : nodes) {
  501. MS_EXCEPTION_IF_NULL(node);
  502. if (!node->isa<CNode>()) {
  503. continue;
  504. }
  505. auto iter = tagged_cnodes_.find(node);
  506. if (iter != tagged_cnodes_.end()) {
  507. ofs << "\n#------------------------> " << iter->second << "\n";
  508. }
  509. auto cnode = node->cast<CNodePtr>();
  510. auto &inputs = cnode->inputs();
  511. std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map);
  512. // non-return node
  513. if (node != func_graph->get_return()) {
  514. int apply_idx = idx++;
  515. apply_map[node] = apply_idx;
  516. std::string type_info = GetNodeType(node);
  517. if (type_info == "Undefined") {
  518. ofs << " %" << apply_idx << " = " << op_text << "(";
  519. } else {
  520. ofs << " %" << apply_idx << " : " << type_info << " = " << op_text << "(";
  521. }
  522. } else {
  523. ofs << " " << op_text << "(";
  524. }
  525. for (size_t i = 1; i < inputs.size(); ++i) {
  526. if (i != 1) {
  527. ofs << ", ";
  528. }
  529. AnfNodePtr arg = inputs[i];
  530. ofs << GetAnfNodeText(func_graph, arg, apply_map);
  531. }
  532. ofs << ")";
  533. // output comment
  534. OutputStatementComment(ofs, cnode);
  535. ofs << "\n";
  536. if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
  537. ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "#"
  538. << label_manage::Label(cnode->debug_info()) << "\n";
  539. } else {
  540. ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "#" << cnode->ToString()
  541. << "\n";
  542. }
  543. }
  544. }
  545. void AnfExporter::OutputOrderList(std::ofstream &ofs, const FuncGraphPtr &func_graph) {
  546. auto &order_list = func_graph->order_list();
  547. if (order_list.empty()) {
  548. return;
  549. }
  550. constexpr int width = 4;
  551. ofs << "# order:\n";
  552. int i = 1;
  553. auto &isolate_nodes = func_graph->isolate_nodes();
  554. for (auto &node : order_list) {
  555. bool is_isolate = (isolate_nodes.find(node) != isolate_nodes.end());
  556. const std::string isolate_str = (is_isolate ? " # isolate" : "");
  557. ofs << '#' << std::setw(width) << i << ": " << node->DebugString() << isolate_str << '\n';
  558. ++i;
  559. }
  560. }
  561. void AnfExporter::OutputIsolateNodes(std::ofstream &ofs, const FuncGraphPtr &func_graph) {
  562. auto &isolate_nodes = func_graph->isolate_nodes();
  563. if (isolate_nodes.empty()) {
  564. return;
  565. }
  566. constexpr int width = 4;
  567. ofs << "# isolate nodes:\n";
  568. int i = 1;
  569. for (auto &node : isolate_nodes) {
  570. ofs << '#' << std::setw(width) << i << ": " << node->DebugString() << '\n';
  571. ++i;
  572. }
  573. }
  574. void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph) {
  575. if (func_graph == nullptr) {
  576. return;
  577. }
  578. std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
  579. std::vector<AnfNodePtr> parameters = func_graph->parameters();
  580. OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map;
  581. if (*(func_graph->switch_layer_input())) {
  582. ofs << "switch_layer_input: " << *(func_graph->switch_layer_input()) << "\n";
  583. }
  584. ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "."
  585. << func_graph->debug_info()->get_id() << "\n";
  586. if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
  587. ofs << trace::GetDebugInfo(func_graph->debug_info(), "# ", kSourceLineTipDiscard) << "#"
  588. << label_manage::Label(func_graph->debug_info()) << "\n";
  589. } else {
  590. ofs << trace::GetDebugInfo(func_graph->debug_info(), "# ", kSourceLineTipDiscard) << "\n";
  591. }
  592. ofs << "funcgraph fg_" << func_graph->debug_info()->get_id();
  593. // output name of parent of graph if exists
  594. if (func_graph->parent() != nullptr) {
  595. ofs << "[fg_" << func_graph->parent()->debug_info()->get_id() << "]";
  596. }
  597. ofs << "(\n";
  598. OutputParameters(ofs, parameters, &param_map);
  599. exported[func_graph] = param_map;
  600. ofs << (!parameters.empty() ? " " : "") << ") {\n";
  601. OutputCNodes(ofs, nodes, func_graph);
  602. ofs << "}\n";
  603. OutputOrderList(ofs, func_graph);
  604. OutputIsolateNodes(ofs, func_graph);
  605. }
  606. void AnfExporter::ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) {
  607. if (func_graph == nullptr) {
  608. return;
  609. }
  610. std::ofstream ofs(filename);
  611. if (!ofs.is_open()) {
  612. MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
  613. return;
  614. }
  615. param_index = 1;
  616. func_graph_set.add(func_graph);
  617. while (!func_graph_set.empty()) {
  618. FuncGraphPtr fg = *func_graph_set.begin();
  619. ExportOneFuncGraph(ofs, fg);
  620. ofs << "\n\n";
  621. (void)func_graph_set.erase(fg);
  622. }
  623. ofs << "# num of total function graphs: " << exported.size();
  624. ofs.close();
  625. }
  626. void AnfExporter::ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
  627. if (graphs.empty()) {
  628. return;
  629. }
  630. std::ofstream ofs(filename);
  631. if (!ofs.is_open()) {
  632. MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
  633. return;
  634. }
  635. param_index = 1;
  636. for (const auto &tagged_graph : graphs) {
  637. tagged_cnodes_ = tagged_graph.second;
  638. ExportOneFuncGraph(ofs, tagged_graph.first);
  639. tagged_cnodes_.clear();
  640. ofs << "\n\n";
  641. }
  642. ofs << "# num of total function graphs: " << graphs.size();
  643. ofs.close();
  644. }
  645. #ifdef ENABLE_DUMP_IR
  646. void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph) {
  647. if (func_graph == nullptr) {
  648. return;
  649. }
  650. auto real_filename = pipeline::GetSaveGraphsPathName(filename);
  651. AnfExporter exporter(id);
  652. ChangeFileMode(real_filename, S_IRWXU);
  653. exporter.ExportFuncGraph(real_filename, func_graph);
  654. // set file mode to read only by user
  655. ChangeFileMode(real_filename, S_IRUSR);
  656. }
  657. void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
  658. auto real_filename = pipeline::GetSaveGraphsPathName(filename);
  659. AnfExporter exporter("", false);
  660. ChangeFileMode(real_filename, S_IRWXU);
  661. exporter.ExportFuncGraph(real_filename, graphs);
  662. // set file mode to read only by user
  663. ChangeFileMode(real_filename, S_IRUSR);
  664. }
  665. #else
  666. void ExportIR(const std::string &, const std::string &, const FuncGraphPtr &) {
  667. static bool already_printed = false;
  668. if (already_printed) {
  669. return;
  670. }
  671. already_printed = true;
  672. MS_LOG(WARNING) << "The functionality of dumping function graph IR is disabled, "
  673. << "please recompile source to enable it. See help of building script.";
  674. }
  675. void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
  676. static bool already_printed = false;
  677. if (already_printed) {
  678. return;
  679. }
  680. already_printed = true;
  681. MS_LOG(WARNING) << "The functionality of dumping function graph IR is disabled, "
  682. << "please recompile source to enable it. See help of building script.";
  683. }
  684. #endif
  685. // ============================================= MindSpore IR Importer =============================================
// Kinds of tokens produced by Lexer when reading the textual IR.
// The enumerators use implicit, sequential values starting at TOK_INVALID == 0, so their order
// matters. token_text maps each kind to its fixed spelling (or nullptr when the text varies).
enum Token : int {
  TOK_INVALID = 0,   // invalid token
  TOK_LPARENTHESIS,  // ( left parenthesis
  TOK_RPARENTHESIS,  // ) right parenthesis
  TOK_LBRACKET,      // [ left bracket
  TOK_RBRACKET,      // ] right bracket
  TOK_LBRACE,        // { left brace
  TOK_RBRACE,        // } right brace
  TOK_COMMA,         // , comma
  TOK_EQUALITY,      // = equality
  TOK_COLON,         // : colon
  TOK_STAR,          // * star
  TOK_VARIABLE,      // variable: %<digits>
  TOK_AT_FILE,       // @filename
  TOK_PARAMETER,     // parameter: %para<digits>
  TOK_IDENTIFIER,    // identifier
  TOK_FUNCGRAPH,     // keyword 'funcgraph'
  TOK_RETURN,        // id Primitive::return
  TOK_STRING,        // double-quoted string
  TOK_NUMBER,        // number
  TOK_COMMENT,       // '#' comment up to end of line
  TOK_EOL,           // end of line
  TOK_EOF,           // end of file
  TOK_ERROR          // file read error or unrecognized input
};
// Fixed spelling of each Token, used by the DEBUG-only trace in Lexer::GetNextToken. An entry is
// nullptr when the token's text varies; the lexer's current token text is printed instead.
std::map<Token, const char *> token_text = {
  {TOK_INVALID, "invalid"},      // invalid token
  {TOK_LPARENTHESIS, "("},       // ( left parenthesis
  {TOK_RPARENTHESIS, ")"},       // ) right parenthesis
  {TOK_LBRACKET, "["},           // [ left bracket
  {TOK_RBRACKET, "]"},           // ] right bracket
  {TOK_LBRACE, "{"},             // { left brace
  {TOK_RBRACE, "}"},             // } right brace
  {TOK_COMMA, ","},              // , comma
  {TOK_EQUALITY, "="},           // = equality
  {TOK_COLON, ":"},              // : colon
  {TOK_STAR, "*"},               // * star
  {TOK_VARIABLE, nullptr},       // variable
  {TOK_AT_FILE, nullptr},        // @file
  {TOK_PARAMETER, nullptr},      // parameter
  {TOK_IDENTIFIER, nullptr},     // identifier
  {TOK_FUNCGRAPH, "funcgraph"},  // keyword 'funcgraph'
  {TOK_RETURN, nullptr},         // id prim::return
  {TOK_STRING, nullptr},         // string
  {TOK_NUMBER, nullptr},         // number
  {TOK_COMMENT, nullptr},        // comment
  {TOK_EOL, "\n"},               // end of line
  {TOK_EOF, ""},                 // end of file
  {TOK_ERROR, "error"}           // file read error
};
  736. class Lexer {
  737. public:
  738. // filename is checked in ImportIR;
  739. explicit Lexer(const char *filename) : fin(filename) {}
  740. ~Lexer() {
  741. try {
  742. if (fin.is_open()) {
  743. fin.close();
  744. }
  745. } catch (const std::exception &e) {
  746. MS_LOG(ERROR) << "Exception when closing file";
  747. } catch (...) {
  748. std::string exName(abi::__cxa_current_exception_type()->name());
  749. MS_LOG(ERROR) << "Error occurred when closing file. Exception name: " << exName;
  750. }
  751. }
  752. bool IsSingleCharToken(char ch, Token *token_ptr) {
  753. // clang-format off
  754. std::unordered_map<char, Token> char_to_token = {
  755. {'(', TOK_LPARENTHESIS},
  756. {')', TOK_RPARENTHESIS},
  757. {'[', TOK_LBRACKET},
  758. {']', TOK_RBRACKET},
  759. {'{', TOK_LBRACE},
  760. {'}', TOK_RBRACE},
  761. {',', TOK_COMMA},
  762. {'=', TOK_EQUALITY},
  763. {':', TOK_COLON},
  764. {'*', TOK_STAR}};
  765. // clang-format on
  766. auto iter = char_to_token.find(ch);
  767. if (iter == char_to_token.end()) {
  768. return false;
  769. }
  770. if (token_ptr != nullptr) {
  771. *token_ptr = iter->second;
  772. }
  773. return true;
  774. }
  775. Token GetNextToken() {
  776. #ifdef DEBUG
  777. Token token = GetNextTokenInner();
  778. const char *str = token_text[token];
  779. std::string text = (str == nullptr ? GetTokenText() : str);
  780. MS_LOG(DEBUG) << "------Parse token] " << text;
  781. return token;
  782. }
  783. Token GetNextTokenInner() {
  784. #endif
  785. tok_idx = 0;
  786. Token tok = TOK_ERROR;
  787. char ch = SkipTabAndSpace();
  788. if (ch == CODE_EOF) {
  789. return TOK_EOF;
  790. } else if (ch == CODE_ERROR) {
  791. return TOK_ERROR;
  792. } else if (IsSingleCharToken(ch, &tok)) {
  793. return tok;
  794. } else if (ch == '\r') {
  795. char c = GetChar();
  796. if (c == '\n') {
  797. line_++;
  798. return TOK_EOL;
  799. }
  800. UnGetChar(c);
  801. line_++;
  802. return TOK_EOL;
  803. } else if (ch == '\n') {
  804. line_++;
  805. return TOK_EOL;
  806. } else if (ch == '#') {
  807. return ParseComment(ch);
  808. } else if (ch == '"') {
  809. return ParseString();
  810. } else if (ch == '%') {
  811. return ParseVariableOrParameter(ch);
  812. } else if (ch == '@') {
  813. return ParseAtFile();
  814. } else if (IsDigit(ch) || ch == '-') {
  815. return ParseNumber(ch);
  816. } else if (IsAlpha(ch) || ch == '_') {
  817. return ParseIdentifier(ch);
  818. } else {
  819. return TOK_ERROR;
  820. }
  821. }
  822. Token SkipWhiteToken() {
  823. Token tok = GetNextToken();
  824. while (tok == TOK_EOL || tok == TOK_COMMENT) {
  825. tok = GetNextToken();
  826. }
  827. return tok;
  828. }
  829. std::string GetTokenText() const { return std::string(tok_buf); }
  830. int GetLineNo() const { return line_; }
  831. private:
  832. Token ParseComment(char ch) {
  833. char c = GetChar();
  834. while (c != '\r' && c != '\n' && c != CODE_EOF) {
  835. c = GetChar();
  836. }
  837. if (ch != CODE_EOF) {
  838. UnGetChar(c);
  839. }
  840. tok_buf[0] = '#';
  841. tok_buf[1] = '\0';
  842. return TOK_COMMENT;
  843. }
  844. Token ParseString() {
  845. tok_idx = 0;
  846. char c = GetChar();
  847. while (c != '"') {
  848. if (tok_idx >= BUF_SIZE) {
  849. MS_LOG(EXCEPTION) << "Length of token which is " << tok_idx << " exceeds " << BUF_SIZE;
  850. }
  851. if (c == '\r' || c == '\n') {
  852. MS_LOG(EXCEPTION) << "Literal newline characters are not allowed within the quote at line " << line_;
  853. }
  854. if (c == CODE_EOF) {
  855. MS_LOG(EXCEPTION) << "Encounter EOF within the quote at line " << line_;
  856. }
  857. tok_buf[tok_idx++] = c;
  858. c = GetChar();
  859. }
  860. tok_buf[tok_idx] = '\0';
  861. return TOK_STRING;
  862. }
  863. Token ParseVariableOrParameter(char ch) {
  864. tok_idx = 0;
  865. tok_buf[tok_idx++] = ch;
  866. char c = GetChar();
  867. while (IsAlphaNumeric(c)) {
  868. if (tok_idx >= BUF_SIZE) {
  869. MS_LOG(EXCEPTION) << "Length of token which is " << tok_idx << " exceeds " << BUF_SIZE;
  870. }
  871. tok_buf[tok_idx++] = c;
  872. c = GetChar();
  873. }
  874. tok_buf[tok_idx] = '\0';
  875. UnGetChar(c);
  876. // judge parameter: %para[0-9]+
  877. tok_buf[tok_idx] = '\0';
  878. std::string param_key = "%para";
  879. if (strncmp(tok_buf, param_key.c_str(), param_key.size()) == 0) {
  880. if (tok_idx <= param_key.size()) {
  881. return TOK_ERROR;
  882. }
  883. for (auto i = static_cast<unsigned>(param_key.size()); i < tok_idx; ++i) {
  884. if (!IsDigit(tok_buf[i])) {
  885. return TOK_ERROR;
  886. }
  887. }
  888. return TOK_PARAMETER;
  889. }
  890. // judge local variable: %[0-9]+
  891. if (tok_idx == 1) {
  892. return TOK_ERROR;
  893. }
  894. for (unsigned i = 1; i < tok_idx; ++i) {
  895. if (!IsDigit(tok_buf[i])) {
  896. return TOK_ERROR;
  897. }
  898. }
  899. return TOK_VARIABLE;
  900. }
  901. Token ParseAtFile() {
  902. tok_idx = 0;
  903. char c = GetChar();
  904. while (IsAlphaNumeric(c) || c == '_' || c == '.') {
  905. if (tok_idx >= BUF_SIZE) {
  906. MS_LOG(EXCEPTION) << "Length of token which is " << tok_idx << " exceeds " << BUF_SIZE;
  907. }
  908. tok_buf[tok_idx++] = c;
  909. c = GetChar();
  910. }
  911. tok_buf[tok_idx] = '\0';
  912. UnGetChar(c);
  913. if (tok_idx == 0) {
  914. return TOK_ERROR;
  915. }
  916. return TOK_AT_FILE;
  917. }
  918. Token ParseNumber(char ch) {
  919. tok_buf[tok_idx++] = ch;
  920. char c = GetChar();
  921. // parse number, e.g. 10, 15.6, 1e-5
  922. while (IsDigit(c) || c == '.' || c == 'e' || c == '-') {
  923. if (tok_idx >= BUF_SIZE) {
  924. MS_LOG(EXCEPTION) << "Length of token which is " << tok_idx << " exceeds " << BUF_SIZE;
  925. }
  926. tok_buf[tok_idx++] = c;
  927. c = GetChar();
  928. }
  929. UnGetChar(c);
  930. tok_buf[tok_idx] = '\0';
  931. return TOK_NUMBER;
  932. }
  933. Token ParseIdentifier(char ch) {
  934. tok_idx = 0;
  935. tok_buf[tok_idx++] = ch;
  936. char c = GetChar();
  937. while (IsAlphaNumeric(c) || c == '.' || c == ':' || c == '_') {
  938. if (tok_idx >= BUF_SIZE) {
  939. MS_LOG(EXCEPTION) << "Length of token which is " << tok_idx << " exceeds " << BUF_SIZE;
  940. }
  941. tok_buf[tok_idx++] = c;
  942. c = GetChar();
  943. }
  944. UnGetChar(c);
  945. tok_buf[tok_idx] = '\0';
  946. if (strcmp(tok_buf, "funcgraph") == 0) {
  947. return TOK_FUNCGRAPH;
  948. }
  949. if (strcmp(tok_buf, "Primitive::return") == 0) {
  950. return TOK_RETURN;
  951. }
  952. return TOK_IDENTIFIER;
  953. }
  954. // Suppose the file only contain ASCII character
  955. char GetChar() {
  956. if (ungot_char != UNGOT_CHAR) {
  957. char ch = ungot_char;
  958. ungot_char = UNGOT_CHAR;
  959. return ch;
  960. }
  961. if (idx >= cnt) {
  962. if (fin.eof()) {
  963. return CODE_EOF;
  964. }
  965. cnt = fin.read(buffer, BUF_SIZE).gcount();
  966. if ((fin.bad() || fin.fail()) && !fin.eof()) {
  967. MS_LOG(EXCEPTION) << "Read file error!";
  968. }
  969. idx = 0;
  970. }
  971. return buffer[idx++];
  972. }
  973. void UnGetChar(char ch) {
  974. if (ungot_char == UNGOT_CHAR) {
  975. ungot_char = ch;
  976. }
  977. }
  978. static bool IsTabOrSpace(char ch) { return ch == ' ' || ch == '\t'; }
  979. static bool IsDigit(char ch) { return ch >= '0' && ch <= '9'; }
  980. static bool IsAlpha(char ch) { return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z'); }
  981. static bool IsAlphaNumeric(char ch) { return IsDigit(ch) || IsAlpha(ch); }
  982. // skip whitespace(including comment) to read a valid character
  983. char SkipTabAndSpace() {
  984. char ch = GetChar();
  985. while (IsTabOrSpace(ch)) {
  986. ch = GetChar();
  987. }
  988. return ch;
  989. }
  990. std::ifstream fin;
  991. static const unsigned BUF_SIZE = 4096; // lexer buffer size
  992. char buffer[BUF_SIZE + 1] = {0}; // buffer for holding text read from text
  993. std::streamsize cnt = 0; // number of valid characters in the buffer
  994. unsigned idx = 0; // index of next char the lexer to read from
  995. char tok_buf[BUF_SIZE + 1] = {0}; // token buffer
  996. unsigned tok_idx = 0; // token buffer index
  997. char ungot_char = UNGOT_CHAR; // store ungot char
  998. static const int CODE_EOF = -1; // return code of GetChar
  999. static const int CODE_ERROR = -2; // read file error
  1000. static const char UNGOT_CHAR = -3; // value of ungot char
  1001. int line_ = 1; // current line number
  1002. };
  1003. const unsigned Lexer::BUF_SIZE;
  1004. class IrParser {
  1005. public:
  1006. explicit IrParser(const char *filename) : lexer_(filename) {}
  1007. ~IrParser() {}
  1008. py::object LoadObject(const std::string &file_name) const {
  1009. std::string pkl_path = GetMsIrPath();
  1010. py::object default_obj = load_obj(pkl_path + "/" + file_name);
  1011. return default_obj;
  1012. }
  1013. void ParseFile() {
  1014. FuncGraphPtr func_graph = ParseFuncGraph();
  1015. while (func_graph != nullptr) {
  1016. func_graphs_.push_back(func_graph);
  1017. func_graph = ParseFuncGraph();
  1018. }
  1019. if (error_flag_) {
  1020. MS_LOG(EXCEPTION) << "Parse Error at line: " << lexer_.GetLineNo();
  1021. }
  1022. MS_LOG(INFO) << "Total graphs: " << func_graphs_.size();
  1023. }
  1024. Token ParseParent(FuncGraphPtr *const parent_ptr) {
  1025. if (lexer_.GetNextToken() != TOK_IDENTIFIER) {
  1026. return TOK_ERROR;
  1027. }
  1028. std::string parent_name = lexer_.GetTokenText();
  1029. // NOTICE: require definition of parent graph must before child graph
  1030. auto iter = func_graphs_map_.find(parent_name);
  1031. if (iter == func_graphs_map_.end()) {
  1032. MS_LOG(EXCEPTION) << "Can not find definition of parent func graph '" << parent_name << "' at line "
  1033. << lexer_.GetLineNo();
  1034. }
  1035. if (parent_ptr != nullptr) {
  1036. *parent_ptr = iter->second;
  1037. }
  1038. if (lexer_.GetNextToken() != TOK_RBRACKET) {
  1039. return TOK_ERROR;
  1040. }
  1041. return lexer_.GetNextToken();
  1042. }
  1043. FuncGraphPtr ParseFuncGraph() {
  1044. cnodes_.clear();
  1045. Token tok = lexer_.SkipWhiteToken();
  1046. if (tok != TOK_FUNCGRAPH) {
  1047. error_flag_ = tok != TOK_EOF;
  1048. return nullptr;
  1049. }
  1050. if (lexer_.GetNextToken() != TOK_IDENTIFIER) {
  1051. error_flag_ = true;
  1052. return nullptr;
  1053. }
  1054. std::string func_graph_name = lexer_.GetTokenText();
  1055. if (func_graphs_map_.find(func_graph_name) == func_graphs_map_.end()) {
  1056. func_graphs_map_[func_graph_name] = std::make_shared<FuncGraph>();
  1057. }
  1058. FuncGraphPtr func_graph = func_graphs_map_[func_graph_name];
  1059. MS_EXCEPTION_IF_NULL(func_graph);
  1060. MS_EXCEPTION_IF_NULL(func_graph->debug_info());
  1061. func_graph->debug_info()->set_name(func_graph_name); // for debugging
  1062. FuncGraphPtr parent = nullptr;
  1063. tok = lexer_.GetNextToken();
  1064. if (tok == TOK_LBRACKET) {
  1065. tok = ParseParent(&parent);
  1066. if (parent != nullptr) {
  1067. parents_map_[func_graph] = parent;
  1068. }
  1069. }
  1070. if (tok != TOK_LPARENTHESIS) {
  1071. error_flag_ = true;
  1072. return nullptr;
  1073. }
  1074. if (ParseParameters(func_graph) == nullptr) {
  1075. error_flag_ = true;
  1076. return nullptr;
  1077. }
  1078. if (lexer_.SkipWhiteToken() != TOK_LBRACE) {
  1079. error_flag_ = true;
  1080. return nullptr;
  1081. }
  1082. // parse statements
  1083. if (ParseStatements(func_graph) == nullptr) {
  1084. error_flag_ = true;
  1085. return nullptr;
  1086. }
  1087. func_graphs_map_[func_graph_name] = func_graph;
  1088. return func_graph;
  1089. }
  1090. FuncGraphPtr ParseStatements(const FuncGraphPtr &func_graph) {
  1091. Token tok = lexer_.SkipWhiteToken();
  1092. while (tok == TOK_VARIABLE) {
  1093. if (ParseStatement(func_graph) == nullptr) {
  1094. return nullptr;
  1095. }
  1096. tok = lexer_.SkipWhiteToken();
  1097. }
  1098. if (tok == TOK_RETURN) {
  1099. return ParseReturn(func_graph);
  1100. }
  1101. return nullptr;
  1102. }
  1103. FuncGraphPtr ParseStatement(FuncGraphPtr func_graph) {
  1104. std::string var_name = lexer_.GetTokenText();
  1105. Token tok = lexer_.GetNextToken();
  1106. AbstractBasePtr type = nullptr;
  1107. if (tok == TOK_COLON) {
  1108. tok = ParseType(func_graph, &type);
  1109. }
  1110. if (tok != TOK_EQUALITY) {
  1111. return nullptr;
  1112. }
  1113. std::vector<AnfNodePtr> inputs;
  1114. AnfNodePtr node = nullptr;
  1115. ValuePtr val = nullptr;
  1116. tok = ParseItem(func_graph, &node, &val);
  1117. if (tok != TOK_LPARENTHESIS) {
  1118. return nullptr;
  1119. }
  1120. inputs.push_back(node);
  1121. int lineno = lexer_.GetLineNo();
  1122. if (ParseArguments(func_graph, &inputs) == nullptr) {
  1123. return nullptr;
  1124. }
  1125. tok = lexer_.GetNextToken();
  1126. if (tok == TOK_COMMENT) {
  1127. tok = lexer_.GetNextToken();
  1128. }
  1129. if (tok != TOK_EOL) {
  1130. return nullptr;
  1131. }
  1132. MS_EXCEPTION_IF_NULL(func_graph);
  1133. cnodes_[var_name] = func_graph->NewCNode(inputs);
  1134. MS_EXCEPTION_IF_NULL(cnodes_[var_name]);
  1135. cnodes_[var_name]->set_debug_info(std::make_shared<NodeDebugInfo>(var_name + "@" + std::to_string(lineno)));
  1136. return func_graph;
  1137. }
  1138. FuncGraphPtr ParseReturn(FuncGraphPtr func_graph) {
  1139. if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
  1140. return nullptr;
  1141. }
  1142. AnfNodePtr input1 = nullptr;
  1143. ValuePtr value = nullptr;
  1144. Token tok = ParseItem(func_graph, &input1, &value, lexer_.GetNextToken());
  1145. int lineno = lexer_.GetLineNo();
  1146. if (tok != TOK_RPARENTHESIS) {
  1147. return nullptr;
  1148. }
  1149. tok = lexer_.GetNextToken();
  1150. if (tok == TOK_COMMENT) {
  1151. tok = lexer_.GetNextToken();
  1152. }
  1153. if (tok != TOK_EOL) {
  1154. return nullptr;
  1155. }
  1156. if (lexer_.SkipWhiteToken() != TOK_RBRACE) {
  1157. return nullptr;
  1158. }
  1159. PrimitivePtr prim = std::make_shared<Primitive>("return");
  1160. ValueNodePtr input0 = std::make_shared<ValueNode>(prim);
  1161. std::vector<AnfNodePtr> inputs;
  1162. inputs.push_back(input0);
  1163. inputs.push_back(input1);
  1164. MS_EXCEPTION_IF_NULL(func_graph);
  1165. CNodePtr ret = func_graph->NewCNode(inputs);
  1166. MS_EXCEPTION_IF_NULL(ret);
  1167. ret->set_debug_info(std::make_shared<NodeDebugInfo>(std::string("ret@") + std::to_string(lineno)));
  1168. func_graph->set_return(ret);
  1169. return func_graph;
  1170. }
  1171. void SetBasicType(TypePtr *ptr, const TypePtr &dtype) const {
  1172. if (ptr == nullptr) {
  1173. return;
  1174. }
  1175. *ptr = dtype;
  1176. }
  1177. void SetTupleType(TypePtr *ptr) {
  1178. if (ptr == nullptr) {
  1179. return;
  1180. }
  1181. *ptr = std::make_shared<Tuple>();
  1182. }
  1183. void SetTupleType(TypePtr *ptr, const TypePtrList &elems) {
  1184. if (ptr == nullptr) {
  1185. return;
  1186. }
  1187. *ptr = std::make_shared<Tuple>(elems);
  1188. }
  1189. void SetArrayType(TypePtr *const ptr, const TypePtr &elem_type, const ShapeVector &) {
  1190. if (ptr == nullptr) {
  1191. return;
  1192. }
  1193. *ptr = std::make_shared<TensorType>(elem_type);
  1194. }
  1195. void SetListType(TypePtr *ptr) {
  1196. if (ptr == nullptr) {
  1197. return;
  1198. }
  1199. *ptr = std::make_shared<List>();
  1200. }
  1201. void SetListType(TypePtr *ptr, const TypePtrList &elems) {
  1202. if (ptr == nullptr) {
  1203. return;
  1204. }
  1205. *ptr = std::make_shared<List>(elems);
  1206. }
  1207. void SetJTaggedType(TypePtr *ptr, const TypePtr &elem) {
  1208. if (ptr == nullptr) {
  1209. return;
  1210. }
  1211. *ptr = std::make_shared<JTagged>(elem);
  1212. }
  1213. void SetBasicType(AbstractBasePtr *ptr, const TypePtr &dtype) const {
  1214. if (ptr == nullptr) {
  1215. return;
  1216. }
  1217. *ptr = std::make_shared<abstract::AbstractScalar>(dtype);
  1218. }
  1219. // void SetBasicType(AbstractBasePtr *ptr, const SymbolicKeyTypePtr& dtype) {}
  1220. void SetBasicType(AbstractBasePtr *const ptr, const TypeNonePtr &) const {
  1221. if (ptr == nullptr) {
  1222. return;
  1223. }
  1224. *ptr = std::make_shared<abstract::AbstractNone>();
  1225. }
  1226. void SetBasicType(AbstractBasePtr *, const FunctionPtr &) const {}
  1227. void SetBasicType(AbstractBasePtr *, const TensorTypePtr &) const {}
  1228. void SetTupleType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) {
  1229. if (ptr == nullptr) {
  1230. return;
  1231. }
  1232. // if one of elems is nullptr, just return
  1233. if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) {
  1234. return;
  1235. }
  1236. *ptr = std::make_shared<abstract::AbstractTuple>(elems);
  1237. }
  1238. void SetArrayType(AbstractBasePtr *const ptr, const TypePtr &elem_type, const ShapeVector &shape) {
  1239. if (ptr == nullptr) {
  1240. return;
  1241. }
  1242. *ptr = std::make_shared<abstract::AbstractTensor>(elem_type, shape);
  1243. }
  1244. void SetListType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) {
  1245. if (ptr == nullptr) {
  1246. return;
  1247. }
  1248. if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) {
  1249. return;
  1250. }
  1251. *ptr = std::make_shared<abstract::AbstractList>(elems);
  1252. }
  1253. void SetJTaggedType(AbstractBasePtr *const ptr, const AbstractBasePtr &elem) {
  1254. if (ptr == nullptr) {
  1255. return;
  1256. }
  1257. *ptr = std::make_shared<abstract::AbstractJTagged>(elem);
  1258. }
  1259. template <typename T>
  1260. Token ParseTypeVector(const FuncGraphPtr &func_graph, Token tok, const std::string &type, T *const ptr = nullptr) {
  1261. if (tok != TOK_LBRACKET) {
  1262. MS_LOG(EXCEPTION) << "Illegal case, , wrong token start symbol.";
  1263. return tok;
  1264. }
  1265. bool first_flag = true;
  1266. std::vector<T> elems;
  1267. do {
  1268. tok = lexer_.GetNextToken();
  1269. if (first_flag) {
  1270. if (tok == TOK_RBRACKET) {
  1271. return lexer_.GetNextToken();
  1272. }
  1273. first_flag = false;
  1274. }
  1275. T elem = nullptr;
  1276. tok = ParseOneType(func_graph, tok, &elem);
  1277. elems.push_back(elem);
  1278. if (tok == TOK_STAR) {
  1279. if (lexer_.GetNextToken() != TOK_NUMBER) {
  1280. return TOK_ERROR;
  1281. }
  1282. int num_elems = StringToScalar<int>(lexer_.GetTokenText());
  1283. if (num_elems < 1 || num_elems > NUM_MAX_SEQUENCE_ELEMS) {
  1284. MS_LOG(EXCEPTION) << "Number of elements " << num_elems << " is out of range [1, " << NUM_MAX_SEQUENCE_ELEMS
  1285. << "]";
  1286. }
  1287. for (int i = 0; i < num_elems - 1; ++i) {
  1288. elems.push_back(elem);
  1289. }
  1290. tok = lexer_.GetNextToken();
  1291. }
  1292. } while (tok == TOK_COMMA);
  1293. if (tok != TOK_RBRACKET) {
  1294. return TOK_ERROR;
  1295. }
  1296. if (type == "Tuple") {
  1297. SetTupleType(ptr, elems);
  1298. } else if (type == "List") {
  1299. SetListType(ptr, elems);
  1300. } else {
  1301. MS_LOG(EXCEPTION) << "This method does not support " << type << " parse.";
  1302. }
  1303. return lexer_.GetNextToken();
  1304. }
  1305. template <typename T>
  1306. Token ParseTypeArray(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) {
  1307. if (tok != TOK_LPARENTHESIS) {
  1308. if (ptr != nullptr) {
  1309. SetBasicType(ptr, std::make_shared<TensorType>());
  1310. }
  1311. return tok;
  1312. }
  1313. // process Array element type
  1314. TypePtr elem_type = nullptr;
  1315. ShapeVector shape;
  1316. tok = ParseOneType(func_graph, lexer_.GetNextToken(), &elem_type);
  1317. if (tok != TOK_RPARENTHESIS) {
  1318. return TOK_ERROR;
  1319. }
  1320. tok = lexer_.GetNextToken();
  1321. if (tok != TOK_LBRACKET) {
  1322. // NOTICE: if shape.size == 0, is this ok?
  1323. SetArrayType(ptr, elem_type, shape);
  1324. return tok;
  1325. }
  1326. // process Array shape
  1327. do {
  1328. tok = lexer_.GetNextToken();
  1329. // case: Array(I32)[]
  1330. if (tok != TOK_NUMBER) {
  1331. break;
  1332. }
  1333. shape.push_back(StringToScalar<int>(lexer_.GetTokenText()));
  1334. tok = lexer_.GetNextToken();
  1335. } while (tok == TOK_COMMA);
  1336. if (tok != TOK_RBRACKET) {
  1337. return TOK_ERROR;
  1338. }
  1339. SetArrayType(ptr, elem_type, shape);
  1340. return lexer_.GetNextToken();
  1341. }
  1342. bool IsNumberType(const std::string &type, TypeId *typeid_ptr) {
  1343. // clang-format off
  1344. static std::unordered_map<std::string, TypeId> basic_types = {
  1345. {"Bool", kNumberTypeBool},
  1346. {"I8", kNumberTypeInt8},
  1347. {"I16", kNumberTypeInt16},
  1348. {"I32", kNumberTypeInt32},
  1349. {"I64", kNumberTypeInt64},
  1350. {"U8", kNumberTypeUInt8},
  1351. {"U16", kNumberTypeUInt16},
  1352. {"U32", kNumberTypeUInt32},
  1353. {"U64", kNumberTypeUInt64},
  1354. {"F16", kNumberTypeFloat16},
  1355. {"F32", kNumberTypeFloat32},
  1356. {"F64", kNumberTypeFloat64},
  1357. {"Int", kNumberTypeInt},
  1358. {"UInt", kNumberTypeUInt},
  1359. {"Float", kNumberTypeFloat},
  1360. {"Number", kObjectTypeNumber}};
  1361. // clang-format on
  1362. auto iter = basic_types.find(type);
  1363. if (iter == basic_types.end()) {
  1364. return false;
  1365. }
  1366. if (typeid_ptr != nullptr) {
  1367. *typeid_ptr = iter->second;
  1368. }
  1369. return true;
  1370. }
  1371. template <typename T>
  1372. void ParseNumberType(const std::string &type, TypeId typeId, T *const ptr = nullptr) {
  1373. TypePtr dtype = nullptr;
  1374. std::unordered_map<int, TypePtr> type_map = {
  1375. {static_cast<int>(kNumberTypeBool), std::make_shared<Bool>()}, // Bool
  1376. {static_cast<int>(kNumberTypeInt8), std::make_shared<Int>(8)}, // Int8
  1377. {static_cast<int>(kNumberTypeInt16), std::make_shared<Int>(16)}, // Int16
  1378. {static_cast<int>(kNumberTypeInt32), std::make_shared<Int>(32)}, // Int32
  1379. {static_cast<int>(kNumberTypeInt64), std::make_shared<Int>(64)}, // Int64
  1380. {static_cast<int>(kNumberTypeUInt8), std::make_shared<UInt>(8)}, // UInt8
  1381. {static_cast<int>(kNumberTypeUInt16), std::make_shared<UInt>(16)}, // UInt16
  1382. {static_cast<int>(kNumberTypeUInt32), std::make_shared<UInt>(32)}, // UInt32
  1383. {static_cast<int>(kNumberTypeUInt64), std::make_shared<UInt>(64)}, // UInt64
  1384. {static_cast<int>(kNumberTypeFloat16), std::make_shared<Float>(16)}, // Float16
  1385. {static_cast<int>(kNumberTypeFloat32), std::make_shared<Float>(32)}, // Float32
  1386. {static_cast<int>(kNumberTypeFloat64), std::make_shared<Float>(64)}, // Float64
  1387. {static_cast<int>(kNumberTypeInt), std::make_shared<Int>()}, // Int
  1388. {static_cast<int>(kNumberTypeUInt), std::make_shared<UInt>()}, // UInt
  1389. {static_cast<int>(kNumberTypeFloat), std::make_shared<Float>()}, // Float
  1390. {static_cast<int>(kObjectTypeNumber), std::make_shared<Number>()}, // Number
  1391. };
  1392. auto iter = type_map.find(static_cast<int>(typeId));
  1393. if (iter != type_map.end()) {
  1394. dtype = iter->second;
  1395. } else {
  1396. MS_LOG(EXCEPTION) << "Unknown number type " << type;
  1397. }
  1398. SetBasicType(ptr, dtype);
  1399. }
  1400. template <typename T>
  1401. Token ParseTrivalType(const std::string &type, T *const ptr = nullptr) {
  1402. if (type == "NoneType") {
  1403. SetBasicType(ptr, std::make_shared<TypeNone>());
  1404. return lexer_.GetNextToken();
  1405. } else if (type == "ProblemType") {
  1406. SetBasicType(ptr, std::make_shared<Problem>());
  1407. return lexer_.GetNextToken();
  1408. } else if (type == "ExternalType") {
  1409. SetBasicType(ptr, std::make_shared<External>());
  1410. return lexer_.GetNextToken();
  1411. } else if (type == "AnythingType") {
  1412. SetBasicType(ptr, kAnyType);
  1413. return lexer_.GetNextToken();
  1414. } else if (type == "TypeType") {
  1415. SetBasicType(ptr, std::make_shared<TypeType>());
  1416. return lexer_.GetNextToken();
  1417. } else {
  1418. MS_LOG(EXCEPTION) << "Unknown type error at line " << lexer_.GetLineNo();
  1419. }
  1420. }
  1421. template <typename T>
  1422. Token ParseOneType(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) {
  1423. if (tok != TOK_IDENTIFIER) {
  1424. return TOK_ERROR;
  1425. }
  1426. std::string type = lexer_.GetTokenText();
  1427. TypeId typeId = kTypeUnknown;
  1428. if (IsNumberType(type, &typeId)) {
  1429. ParseNumberType(type, typeId, ptr);
  1430. return lexer_.GetNextToken();
  1431. } else if (type == "Tuple") {
  1432. return ParseTypeVector(func_graph, lexer_.GetNextToken(), type, ptr);
  1433. } else if (type == "Tensor") {
  1434. return ParseTypeArray(func_graph, lexer_.GetNextToken(), ptr);
  1435. } else if (type == "List") {
  1436. return ParseTypeVector(func_graph, lexer_.GetNextToken(), type, ptr);
  1437. } else if (type == "Func") {
  1438. tok = lexer_.GetNextToken();
  1439. if (tok != TOK_LBRACKET) {
  1440. SetBasicType(ptr, std::make_shared<Function>());
  1441. return tok;
  1442. }
  1443. MS_LOG(EXCEPTION) << "Need to process function parameter types at line " << lexer_.GetLineNo();
  1444. } else if (type == "JT") {
  1445. tok = lexer_.GetNextToken();
  1446. if (tok != TOK_LBRACKET) {
  1447. return tok;
  1448. }
  1449. T elem = nullptr;
  1450. tok = ParseOneType(func_graph, lexer_.GetNextToken(), &elem);
  1451. SetJTaggedType(ptr, elem);
  1452. if (tok != TOK_RBRACKET) {
  1453. return TOK_ERROR;
  1454. }
  1455. return lexer_.GetNextToken();
  1456. } else if (type == "SymType") {
  1457. SetBasicType(ptr, std::make_shared<SymbolicKeyType>());
  1458. return lexer_.GetNextToken();
  1459. } else if (type == "EnvType") {
  1460. SetBasicType(ptr, std::make_shared<EnvType>());
  1461. return lexer_.GetNextToken();
  1462. } else if (Match(type, "Cls.")) {
  1463. MS_LOG(EXCEPTION) << "Need to do class type at line " << lexer_.GetLineNo();
  1464. } else {
  1465. return ParseTrivalType(type, ptr);
  1466. }
  1467. }
  1468. Token ParseType(const FuncGraphPtr &func_graph, AbstractBasePtr *const abstract = nullptr) {
  1469. return ParseOneType(func_graph, lexer_.GetNextToken(), abstract);
  1470. }
  1471. Token ParseAttributes(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) {
  1472. Token tok = ParseAttribute(func_graph, prim);
  1473. while (tok == TOK_COMMA) {
  1474. tok = ParseAttribute(func_graph, prim);
  1475. }
  1476. if (tok != TOK_RBRACKET) {
  1477. return TOK_ERROR;
  1478. }
  1479. return lexer_.GetNextToken();
  1480. }
  1481. Token ParseAttribute(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) {
  1482. Token tok = lexer_.GetNextToken();
  1483. if (tok != TOK_IDENTIFIER) {
  1484. return TOK_ERROR;
  1485. }
  1486. std::string attr_name = lexer_.GetTokenText();
  1487. if (lexer_.GetNextToken() != TOK_EQUALITY) {
  1488. return TOK_ERROR;
  1489. }
  1490. ValuePtr value = nullptr;
  1491. tok = ParseValue(func_graph, lexer_.GetNextToken(), &value);
  1492. if (prim != nullptr) {
  1493. prim->set_attr(attr_name, value);
  1494. } else {
  1495. MS_LOG(EXCEPTION) << "Non primitive obj has attributes";
  1496. }
  1497. return tok;
  1498. }
  1499. FuncGraphPtr ParseParameters(FuncGraphPtr func_graph) {
  1500. Token tok = lexer_.SkipWhiteToken();
  1501. while (tok == TOK_PARAMETER) {
  1502. ParameterPtr param = std::make_shared<Parameter>(func_graph);
  1503. param->set_name(lexer_.GetTokenText());
  1504. param_nodes_[lexer_.GetTokenText()] = param;
  1505. int lineno = lexer_.GetLineNo();
  1506. param->set_debug_info(std::make_shared<NodeDebugInfo>(lexer_.GetTokenText() + "@" + std::to_string(lineno)));
  1507. func_graph->add_parameter(param);
  1508. tok = lexer_.GetNextToken();
  1509. // parse type
  1510. if (tok == TOK_COLON) {
  1511. AbstractBasePtr type = nullptr;
  1512. tok = ParseType(func_graph, &type);
  1513. }
  1514. // parse default value
  1515. if (tok == TOK_EQUALITY) {
  1516. if (lexer_.GetNextToken() != TOK_AT_FILE) {
  1517. MS_LOG(EXCEPTION) << "Expect @file at line " << lexer_.GetLineNo();
  1518. }
  1519. // load parameter default value from serialized file
  1520. py::object default_obj = LoadObject(lexer_.GetTokenText());
  1521. auto param_value_new = py::cast<tensor::TensorPtr>(default_obj);
  1522. param->set_default_param(param_value_new);
  1523. tok = lexer_.GetNextToken();
  1524. }
  1525. if (tok == TOK_COMMENT || tok == TOK_EOL) {
  1526. tok = lexer_.SkipWhiteToken();
  1527. }
  1528. Token next = tok;
  1529. if (next == TOK_RPARENTHESIS) {
  1530. return func_graph;
  1531. } else if (next == TOK_COMMA) {
  1532. tok = lexer_.SkipWhiteToken();
  1533. } else {
  1534. return nullptr;
  1535. }
  1536. }
  1537. return tok == TOK_RPARENTHESIS ? func_graph : nullptr;
  1538. }
  1539. FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector<AnfNodePtr> *const inputs_ptr) {
  1540. Token tok = ParseArgument(func_graph, inputs_ptr);
  1541. while (tok == TOK_COMMA) {
  1542. tok = ParseArgument(func_graph, inputs_ptr);
  1543. }
  1544. if (tok != TOK_RPARENTHESIS) {
  1545. return nullptr;
  1546. }
  1547. return func_graph;
  1548. }
  1549. AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string &param_name) {
  1550. while (func_graph != nullptr) {
  1551. for (auto &ptr : func_graph->parameters()) {
  1552. MS_EXCEPTION_IF_NULL(ptr);
  1553. ParameterPtr param = ptr->cast<ParameterPtr>();
  1554. MS_EXCEPTION_IF_NULL(param);
  1555. if (param->name() == param_name) {
  1556. return ptr;
  1557. }
  1558. }
  1559. auto iter = parents_map_.find(func_graph);
  1560. if (iter == parents_map_.end()) {
  1561. break;
  1562. }
  1563. func_graph = iter->second;
  1564. }
  1565. return nullptr;
  1566. }
  1567. bool Match(const std::string &str, const std::string &pattern) const {
  1568. return strncmp(str.c_str(), pattern.c_str(), pattern.length()) == 0;
  1569. }
  1570. template <typename T, typename V>
  1571. Token ParseScalar(ValuePtr *const val_ptr) {
  1572. if (lexer_.GetNextToken() != TOK_NUMBER) {
  1573. return TOK_ERROR;
  1574. }
  1575. std::stringstream ss;
  1576. ss << lexer_.GetTokenText();
  1577. if (lexer_.GetNextToken() != TOK_RPARENTHESIS) {
  1578. return TOK_ERROR;
  1579. }
  1580. V val;
  1581. ss >> val;
  1582. *val_ptr = std::make_shared<T>(val);
  1583. return lexer_.GetNextToken();
  1584. }
  1585. template <typename VT, typename V, typename T>
  1586. Token ParseScalar(ValuePtr *const val_ptr, Token tok) {
  1587. if (tok != TOK_LPARENTHESIS) {
  1588. *val_ptr = std::make_shared<T>();
  1589. return tok;
  1590. }
  1591. return ParseScalar<VT, V>(val_ptr);
  1592. }
  1593. template <typename VT, typename V, typename T, const unsigned nbits>
  1594. Token ParseScalar(ValuePtr *const val_ptr, Token tok) {
  1595. if (tok != TOK_LPARENTHESIS) {
  1596. *val_ptr = std::make_shared<T>(nbits);
  1597. return tok;
  1598. }
  1599. return ParseScalar<VT, V>(val_ptr);
  1600. }
  // Converts `text` to T with operator>> (leading whitespace is skipped, and parsing stops
  // at the first character that cannot belong to a T).
  // Returns T{} (0 for arithmetic types) when nothing can be extracted, e.g. for empty or
  // whitespace-only text. Previously the result was left uninitialised in that case.
  // NOTE: for 8-bit integer T, operator>> extracts a character, not a number.
  template <typename T>
  T StringToScalar(const std::string &text) {
    std::stringstream ss;
    ss << text;
    T value{};
    ss >> value;
    return value;
  }
  1609. Token ParseTensor(ValuePtr *const val_ptr) {
  1610. // parse type
  1611. TypeId type;
  1612. if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
  1613. return TOK_ERROR;
  1614. }
  1615. if (lexer_.GetNextToken() != TOK_NUMBER) {
  1616. return TOK_ERROR;
  1617. }
  1618. type = static_cast<TypeId>(StringToScalar<int>(lexer_.GetTokenText()));
  1619. if (lexer_.GetNextToken() != TOK_RPARENTHESIS) {
  1620. return TOK_ERROR;
  1621. }
  1622. // parse shape
  1623. ShapeVector shape;
  1624. Token tok = lexer_.GetNextToken();
  1625. if (tok != TOK_LBRACKET) {
  1626. return TOK_ERROR;
  1627. }
  1628. do {
  1629. tok = lexer_.GetNextToken();
  1630. // consider case: Tensor(23)[]
  1631. if (tok != TOK_NUMBER) {
  1632. break;
  1633. }
  1634. shape.push_back(StringToScalar<int>(lexer_.GetTokenText()));
  1635. tok = lexer_.GetNextToken();
  1636. } while (tok == TOK_COMMA);
  1637. if (tok != TOK_RBRACKET) {
  1638. return TOK_ERROR;
  1639. }
  1640. if (lexer_.GetNextToken() != TOK_AT_FILE) {
  1641. return TOK_ERROR;
  1642. }
  1643. py::object tensor_obj = LoadObject(lexer_.GetTokenText());
  1644. py::array tensor_data = py::cast<py::array>(tensor_obj);
  1645. if (!tensor_data) {
  1646. return TOK_ERROR;
  1647. }
  1648. *val_ptr = TensorPy::MakeTensor(tensor_data, TypeIdToType(type));
  1649. return lexer_.GetNextToken();
  1650. }
  1651. Token ParsePrimType(Token tok, PrimType *prim_type_ptr) {
  1652. if (tok != TOK_LBRACE) {
  1653. return tok;
  1654. }
  1655. if (lexer_.GetNextToken() != TOK_IDENTIFIER) {
  1656. return TOK_ERROR;
  1657. }
  1658. if (lexer_.GetTokenText() != "prim_type") {
  1659. return TOK_ERROR;
  1660. }
  1661. if (lexer_.GetNextToken() != TOK_EQUALITY) {
  1662. return TOK_ERROR;
  1663. }
  1664. if (lexer_.GetNextToken() != TOK_NUMBER) {
  1665. return TOK_ERROR;
  1666. }
  1667. int val = 0;
  1668. std::stringstream ss;
  1669. ss << lexer_.GetTokenText();
  1670. ss >> val;
  1671. *prim_type_ptr = PrimType(val);
  1672. if (lexer_.GetNextToken() != TOK_RBRACE) {
  1673. return TOK_ERROR;
  1674. }
  1675. return lexer_.GetNextToken();
  1676. }
  1677. Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) {
  1678. if (tok != TOK_LPARENTHESIS) {
  1679. return TOK_ERROR;
  1680. }
  1681. TypePtrList type_list;
  1682. do {
  1683. TypePtr type = nullptr;
  1684. tok = ParseOneType(nullptr, lexer_.GetNextToken(), &type);
  1685. type_list.push_back(type);
  1686. } while (tok == TOK_COMMA);
  1687. if (tok != TOK_RPARENTHESIS) {
  1688. return TOK_ERROR;
  1689. }
  1690. if (lexer_.GetNextToken() != TOK_AT_FILE) {
  1691. return TOK_ERROR;
  1692. }
  1693. // load Python function from serialized file
  1694. py::object py_func = LoadObject(lexer_.GetTokenText());
  1695. MS_EXCEPTION_IF_NULL(mt_func_graph);
  1696. mt_func_graph->Register(type_list, py::function(py_func));
  1697. return lexer_.GetNextToken();
  1698. }
  1699. Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) {
  1700. if (tok != TOK_LBRACE) {
  1701. return tok;
  1702. }
  1703. do {
  1704. tok = ParseMultitypeFuncGraphItem(mt_func_graph, lexer_.GetNextToken());
  1705. } while (tok == TOK_COMMA);
  1706. if (tok != TOK_RBRACE) {
  1707. return TOK_ERROR;
  1708. }
  1709. return lexer_.GetNextToken();
  1710. }
  1711. Token ParseBoolValue(const std::string &key, bool *val_ptr) {
  1712. if (lexer_.GetNextToken() != TOK_IDENTIFIER || lexer_.GetTokenText() != key) {
  1713. return TOK_ERROR;
  1714. }
  1715. if (lexer_.GetNextToken() != TOK_EQUALITY) {
  1716. return TOK_ERROR;
  1717. }
  1718. if (lexer_.GetNextToken() != TOK_NUMBER) {
  1719. return TOK_ERROR;
  1720. }
  1721. bool value = false;
  1722. {
  1723. std::stringstream ss;
  1724. ss << lexer_.GetTokenText();
  1725. ss >> value;
  1726. }
  1727. if (val_ptr != nullptr) {
  1728. *val_ptr = value;
  1729. }
  1730. return lexer_.GetNextToken();
  1731. }
  1732. Token ParseValueGradOperation(const std::string &name, ValuePtr *const val_ptr) {
  1733. if (lexer_.GetNextToken() != TOK_LBRACE) {
  1734. return TOK_ERROR;
  1735. }
  1736. // get_all=0, get_by_list=1, sens_param=1
  1737. bool get_all = false;
  1738. Token tok = ParseBoolValue("get_all", &get_all);
  1739. if (tok != TOK_COMMA) {
  1740. return TOK_ERROR;
  1741. }
  1742. bool get_by_list = false;
  1743. tok = ParseBoolValue("get_by_list", &get_by_list);
  1744. if (tok != TOK_COMMA) {
  1745. return TOK_ERROR;
  1746. }
  1747. bool sens_param = false;
  1748. tok = ParseBoolValue("sens_param", &sens_param);
  1749. if (tok != TOK_RBRACE) {
  1750. return TOK_ERROR;
  1751. }
  1752. *val_ptr = std::make_shared<prim::GradOperation>(name, get_all, get_by_list, sens_param);
  1753. return lexer_.GetNextToken();
  1754. }
  1755. Token ParseSymbolicKeyInstance(const FuncGraphPtr &func_graph, AnfNodePtr *const node_ptr = nullptr) {
  1756. if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
  1757. return TOK_ERROR;
  1758. }
  1759. if (lexer_.GetNextToken() != TOK_PARAMETER) {
  1760. return TOK_ERROR;
  1761. }
  1762. std::string param_name = lexer_.GetTokenText();
  1763. if (lexer_.GetNextToken() != TOK_RPARENTHESIS) {
  1764. return TOK_ERROR;
  1765. }
  1766. auto iter = param_nodes_.find(param_name);
  1767. if (iter == param_nodes_.end()) {
  1768. MS_LOG(EXCEPTION) << "Can not find param '" << param_name << "' for SymbolicKeyInstance at line "
  1769. << lexer_.GetLineNo();
  1770. }
  1771. PrimitivePtr embed = std::make_shared<Primitive>("embed");
  1772. std::vector<AnfNodePtr> inputs;
  1773. inputs.push_back(std::make_shared<ValueNode>(embed));
  1774. inputs.push_back(iter->second);
  1775. if (node_ptr != nullptr) {
  1776. MS_EXCEPTION_IF_NULL(func_graph);
  1777. *node_ptr = func_graph->NewCNode(inputs);
  1778. } else {
  1779. MS_LOG(EXCEPTION) << "Not processed SymbolicKeyInstance '" << param_name << "' at line " << lexer_.GetLineNo()
  1780. << ".";
  1781. }
  1782. return lexer_.GetNextToken();
  1783. }
  1784. Token ParsePrimitivePy(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *const val_ptr) {
  1785. if (lexer_.GetNextToken() != TOK_AT_FILE) {
  1786. return TOK_ERROR;
  1787. }
  1788. // restore python function of PrimitivePy from serialized file
  1789. py::object py_obj = LoadObject(lexer_.GetTokenText());
  1790. PrimitivePyPtr ptr = nullptr;
  1791. if (py::hasattr(py_obj, "__setattr_flag__") && py::hasattr(py_obj, "_clone")) {
  1792. auto clone_fn = py_obj.attr("_clone");
  1793. py::object new_obj = clone_fn();
  1794. ptr = new_obj.cast<PrimitivePyPtr>();
  1795. if (ptr == nullptr) {
  1796. MS_LOG(EXCEPTION) << "Cast to type 'PrimitivePyPtr' error";
  1797. }
  1798. } else {
  1799. auto len = strlen("PrimitivePy::");
  1800. if (id.size() < len) {
  1801. return TOK_ERROR;
  1802. }
  1803. ptr = std::make_shared<PrimitivePy>(id.substr(len), py_obj);
  1804. }
  1805. *val_ptr = ptr;
  1806. PrimType prim_type = kPrimTypeUnknown;
  1807. Token next = ParsePrimType(lexer_.GetNextToken(), &prim_type);
  1808. if (prim_type != kPrimTypeUnknown) {
  1809. ptr->set_prim_type(prim_type);
  1810. }
  1811. if (next != TOK_LBRACKET) {
  1812. return next;
  1813. }
  1814. // parse attributes
  1815. next = ParseAttributes(func_graph, ptr);
  1816. return next;
  1817. }
  1818. Token ParseValueGraphAndNamespace(const std::string &id, ValuePtr *const val_ptr) {
  1819. if (Match(id, "MultitypeFuncGraph::")) {
  1820. std::string name = id.substr(strlen("MultitypeFuncGraph::"));
  1821. auto mt_func_graph = std::make_shared<prim::MultitypeFuncGraph>(name);
  1822. *val_ptr = mt_func_graph;
  1823. Token next = ParseMultitypeFuncGraph(mt_func_graph, lexer_.GetNextToken());
  1824. return next;
  1825. } else if (Match(id, "HyperMapPy::")) {
  1826. *val_ptr = std::make_shared<prim::HyperMapPy>();
  1827. Token next = lexer_.GetNextToken();
  1828. // process case: fn_leaf is not null
  1829. if (next == TOK_LBRACE) {
  1830. MS_LOG(EXCEPTION) << "Need to process fn_leaf at line " << lexer_.GetLineNo();
  1831. }
  1832. return next;
  1833. } else if (Match(id, "FuncGraph::")) {
  1834. std::string func_graph_name = id.substr(strlen("FuncGraph::"));
  1835. // if the graph does not exist, create a null graph, then fill the graph when encounter the definition
  1836. // of the graph
  1837. if (func_graphs_map_.find(func_graph_name) == func_graphs_map_.end()) {
  1838. func_graphs_map_[func_graph_name] = std::make_shared<FuncGraph>();
  1839. }
  1840. *val_ptr = func_graphs_map_[func_graph_name];
  1841. return lexer_.GetNextToken();
  1842. } else if (Match(id, "NameSpace::")) {
  1843. std::string module_name = id.substr(strlen("NameSpace::"));
  1844. if (lexer_.GetNextToken() != TOK_AT_FILE) {
  1845. MS_LOG(ERROR) << "Expect TOK_AT_FILE at line " << lexer_.GetLineNo();
  1846. return TOK_ERROR;
  1847. }
  1848. // load Python module information from serialized file
  1849. py::object py_obj = LoadObject(lexer_.GetTokenText());
  1850. *val_ptr = std::make_shared<parse::NameSpace>(module_name, py_obj);
  1851. return lexer_.GetNextToken();
  1852. } else {
  1853. MS_LOG(EXCEPTION) << "Unknown id " << id << " at line " << lexer_.GetLineNo();
  1854. }
  1855. }
  1856. Token ParseValueBasic(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *const val_ptr,
  1857. AnfNodePtr *const node_ptr = nullptr) {
  1858. if (id == "None") {
  1859. *val_ptr = std::make_shared<None>();
  1860. return lexer_.GetNextToken();
  1861. } else if (id == "Bool") {
  1862. return ParseScalar<BoolImm, bool, Bool>(val_ptr, lexer_.GetNextToken());
  1863. } else if (id == "I8") {
  1864. return ParseScalar<Int8Imm, int8_t, Int, 8>(val_ptr, lexer_.GetNextToken());
  1865. } else if (id == "I16") {
  1866. return ParseScalar<Int16Imm, int16_t, Int, 16>(val_ptr, lexer_.GetNextToken());
  1867. } else if (id == "I32") {
  1868. return ParseScalar<Int32Imm, int32_t, Int, 32>(val_ptr, lexer_.GetNextToken());
  1869. } else if (id == "I64") {
  1870. return ParseScalar<Int64Imm, int64_t, Int, 64>(val_ptr, lexer_.GetNextToken());
  1871. } else if (id == "U8") {
  1872. return ParseScalar<UInt8Imm, uint8_t, UInt, 8>(val_ptr, lexer_.GetNextToken());
  1873. } else if (id == "U16") {
  1874. return ParseScalar<UInt16Imm, uint16_t, UInt, 16>(val_ptr, lexer_.GetNextToken());
  1875. } else if (id == "U32") {
  1876. return ParseScalar<UInt32Imm, uint32_t, UInt, 32>(val_ptr, lexer_.GetNextToken());
  1877. } else if (id == "U64") {
  1878. return ParseScalar<UInt64Imm, uint64_t, UInt, 64>(val_ptr, lexer_.GetNextToken());
  1879. } else if (id == "F16") {
  1880. // Notice: Since there is no basic data type for storing fp16, just use float instead
  1881. return ParseScalar<FP32Imm, float, Float, 16>(val_ptr, lexer_.GetNextToken());
  1882. } else if (id == "F32") {
  1883. return ParseScalar<FP32Imm, float, Float, 32>(val_ptr, lexer_.GetNextToken());
  1884. } else if (id == "F64") {
  1885. return ParseScalar<FP64Imm, double, Float, 64>(val_ptr, lexer_.GetNextToken());
  1886. } else if (id == "Tensor") {
  1887. return ParseTensor(val_ptr);
  1888. } else if (id == "SymInst") {
  1889. return ParseSymbolicKeyInstance(func_graph, node_ptr);
  1890. } else if (id == "Array") {
  1891. TypePtr type = nullptr;
  1892. Token ret = ParseTypeArray(func_graph, lexer_.GetNextToken(), &type);
  1893. *val_ptr = type;
  1894. return ret;
  1895. } else if (Match(id, "PrimitivePy::")) {
  1896. return ParsePrimitivePy(func_graph, id, val_ptr);
  1897. } else if (Match(id, "Primitive::")) {
  1898. *val_ptr = std::make_shared<Primitive>(id.substr(strlen("Primitive::")));
  1899. return lexer_.GetNextToken();
  1900. } else if (Match(id, "GradOperation::")) {
  1901. return ParseValueGradOperation(id.substr(strlen("GradOperation::")), val_ptr);
  1902. } else {
  1903. return ParseValueGraphAndNamespace(id, val_ptr);
  1904. }
  1905. }
  1906. Token SetListOrTupleValue(const FuncGraphPtr &func_graph, Token left_tok, Token next, bool node_is_valid,
  1907. const std::vector<ValuePtr> &elems, const std::vector<AnfNodePtr> &nodes,
  1908. ValuePtr *const val_ptr, AnfNodePtr *node_ptr) {
  1909. if (left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) {
  1910. if (node_is_valid && node_ptr != nullptr) {
  1911. MS_EXCEPTION_IF_NULL(func_graph);
  1912. *node_ptr = func_graph->NewCNode(nodes);
  1913. } else {
  1914. *val_ptr = std::make_shared<ValueTuple>(elems);
  1915. }
  1916. return lexer_.GetNextToken();
  1917. } else if (left_tok == TOK_LBRACKET && next == TOK_RBRACKET) {
  1918. if (node_is_valid && node_ptr != nullptr) {
  1919. MS_LOG(EXCEPTION) << "Encounter valid node in value list";
  1920. }
  1921. *val_ptr = std::make_shared<ValueList>(elems);
  1922. return lexer_.GetNextToken();
  1923. } else {
  1924. return TOK_ERROR;
  1925. }
  1926. }
  1927. Token ParseListOrTupleValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr,
  1928. AnfNodePtr *node_ptr = nullptr) {
  1929. Token left_tok = tok;
  1930. std::vector<ValuePtr> elems;
  1931. std::vector<AnfNodePtr> nodes;
  1932. nodes.push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>("make_tuple")));
  1933. ValuePtr elem = nullptr;
  1934. AnfNodePtr node = nullptr;
  1935. bool node_is_valid = false;
  1936. bool first_flag = true;
  1937. Token next = TOK_ERROR;
  1938. do {
  1939. next = lexer_.GetNextToken();
  1940. if (first_flag) {
  1941. first_flag = false;
  1942. // case (), zero elements
  1943. if ((left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) ||
  1944. (left_tok == TOK_LBRACKET && next == TOK_RBRACKET)) {
  1945. if (left_tok == TOK_LPARENTHESIS) {
  1946. *val_ptr = std::make_shared<ValueTuple>(elems);
  1947. } else {
  1948. *val_ptr = std::make_shared<ValueList>(elems);
  1949. }
  1950. return lexer_.GetNextToken();
  1951. }
  1952. }
  1953. node = nullptr;
  1954. next = ParseValue(func_graph, next, &elem, &node);
  1955. elems.push_back(elem);
  1956. if (node != nullptr) {
  1957. nodes.push_back(node);
  1958. node_is_valid = true;
  1959. } else {
  1960. nodes.push_back(std::make_shared<ValueNode>(elem));
  1961. }
  1962. } while (next == TOK_COMMA);
  1963. return SetListOrTupleValue(func_graph, left_tok, next, node_is_valid, elems, nodes, val_ptr, node_ptr);
  1964. }
  1965. Token ParseValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, AnfNodePtr *node_ptr = nullptr) {
  1966. // tuple or list
  1967. if (tok == TOK_LPARENTHESIS || tok == TOK_LBRACKET) {
  1968. return ParseListOrTupleValue(func_graph, tok, val_ptr, node_ptr);
  1969. } else if (tok == TOK_IDENTIFIER) {
  1970. return ParseValueBasic(func_graph, lexer_.GetTokenText(), val_ptr, node_ptr);
  1971. } else if (tok == TOK_STRING) {
  1972. *val_ptr = std::make_shared<StringImm>(lexer_.GetTokenText());
  1973. return lexer_.GetNextToken();
  1974. }
  1975. MS_LOG(ERROR) << "Parse error!";
  1976. return TOK_ERROR;
  1977. }
  1978. Token ParseItem(const FuncGraphPtr &func_graph, AnfNodePtr *node_ptr, ValuePtr *const val_ptr,
  1979. Token tok = TOK_INVALID) {
  1980. if (tok == TOK_INVALID) {
  1981. tok = lexer_.GetNextToken();
  1982. }
  1983. if (tok == TOK_VARIABLE) {
  1984. auto iter = cnodes_.find(lexer_.GetTokenText());
  1985. if (iter == cnodes_.end()) {
  1986. MS_LOG(EXCEPTION) << "Can not find definition of '" << lexer_.GetTokenText() << "'";
  1987. }
  1988. *node_ptr = iter->second;
  1989. } else if (tok == TOK_PARAMETER) {
  1990. AnfNodePtr param = FindParameter(func_graph, lexer_.GetTokenText());
  1991. if (param == nullptr) {
  1992. MS_LOG(EXCEPTION) << "Can not find definition of '" << lexer_.GetTokenText() << "' at line "
  1993. << lexer_.GetLineNo();
  1994. }
  1995. *node_ptr = param;
  1996. } else if (tok == TOK_IDENTIFIER || tok == TOK_LPARENTHESIS || tok == TOK_STRING) {
  1997. ValuePtr value;
  1998. AnfNodePtr node;
  1999. tok = ParseValue(func_graph, tok, &value, &node);
  2000. if (tok == TOK_ERROR) {
  2001. MS_LOG(ERROR) << "Parse value error!";
  2002. return tok;
  2003. }
  2004. if (node == nullptr) {
  2005. *val_ptr = value;
  2006. *node_ptr = std::make_shared<ValueNode>(value);
  2007. } else {
  2008. *node_ptr = node;
  2009. }
  2010. return tok;
  2011. } else {
  2012. MS_LOG(EXCEPTION) << "tok_type = " << tok;
  2013. }
  2014. return lexer_.GetNextToken();
  2015. }
  2016. Token ParseArgument(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *const inputs_ptr) {
  2017. Token tok = lexer_.GetNextToken();
  2018. if (tok == TOK_RPARENTHESIS) {
  2019. return tok;
  2020. }
  2021. AnfNodePtr node = nullptr;
  2022. ValuePtr value = nullptr;
  2023. tok = ParseItem(func_graph, &node, &value, tok);
  2024. if (tok != TOK_ERROR) {
  2025. MS_EXCEPTION_IF_NULL(inputs_ptr);
  2026. inputs_ptr->push_back(node);
  2027. }
  2028. return tok;
  2029. }
  2030. const std::vector<FuncGraphPtr> &GetFuncGraphs() const { return func_graphs_; }
  private:
  // Token source for all Parse* methods.
  Lexer lexer_;
  // Graphs returned by GetFuncGraphs(). NOTE(review): filled outside this view; confirm where.
  std::vector<FuncGraphPtr> func_graphs_;
  // NOTE(review): not referenced in this part of the file.
  bool error_flag_ = false;
  // Store all parsed graphs, keyed by name. ParseValueGraphAndNamespace also inserts empty
  // placeholder graphs here for "FuncGraph::<name>" references seen before the definition.
  std::map<std::string, FuncGraphPtr> func_graphs_map_;
  // Map from child to parent; consider adding a 'parent' field in class Graph.
  // FindParameter walks this chain to resolve parameters of enclosing graphs.
  std::map<FuncGraphPtr, FuncGraphPtr> parents_map_;
  // Map for buffering cnodes when parsing a graph; ParseItem resolves TOK_VARIABLE names here.
  std::map<std::string, CNodePtr> cnodes_;
  // Map parameter name to parameter. Filled by ParseParameters and read by
  // ParseSymbolicKeyInstance. Keyed by name only, so a later parameter with the same name
  // replaces the earlier entry.
  std::map<std::string, ParameterPtr> param_nodes_;
  2042. };
  2043. std::vector<FuncGraphPtr> ImportIR(const std::string &filename) {
  2044. IrParser parser(filename.c_str());
  2045. parser.ParseFile();
  2046. return parser.GetFuncGraphs();
  2047. }
  2048. } // namespace mindspore