You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

anf_ir_utils.cc 73 kB

5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "debug/anf_ir_utils.h"
  17. #include <fstream>
  18. #include <map>
  19. #include <memory>
  20. #include <unordered_map>
  21. #include <unordered_set>
  22. #include <algorithm>
  23. #include "utils/graph_utils.h"
  24. #include "utils/symbolic.h"
  25. #include "ir/meta_func_graph.h"
  26. #include "ir/param_value_py.h"
  27. #include "ir/tensor_py.h"
  28. #include "pipeline/parse/python_adapter.h"
  29. #include "pipeline/parse/resolve.h"
  30. #include "operator/composite/composite.h"
  31. #include "operator/composite/map.h"
  32. #include "utils/ordered_map.h"
  33. #include "utils/ordered_set.h"
  34. #include "utils/utils.h"
  35. #include "debug/trace.h"
  36. #include "debug/label.h"
  37. #include "utils/context/ms_context.h"
  38. #include "operator/ops.h"
  39. using mindspore::tensor::TensorPy;
  40. namespace mindspore {
  // Upper bound on the number of elements in a sequence: 2^24 - 1 (== 0x00FFFFFF).
  const int NUM_MAX_SEQUENCE_ELEMS = (1 << 24) - 1;
  43. // ============================================== MindSpore IR Common ==============================================
  44. // get MindSpore Intermediate Representation Path
  45. std::string GetMsIrPath(void) {
  46. std::string path;
  47. const char *path_ptr = getenv("MS_IR_PATH");
  48. if (path_ptr != nullptr) {
  49. path = path_ptr;
  50. char real_path[PATH_MAX] = {0};
  51. #if defined(_WIN32) || defined(_WIN64)
  52. if (path.size() > PATH_MAX || _fullpath(real_path, path.c_str(), PATH_MAX) == nullptr) {
  53. MS_LOG(EXCEPTION) << "MS IR Path error, " << path_ptr;
  54. }
  55. #else
  56. if (path.size() > PATH_MAX || nullptr == realpath(path.c_str(), real_path)) {
  57. MS_LOG(EXCEPTION) << "MS IR path error, " << path_ptr;
  58. }
  59. #endif
  60. path = real_path;
  61. }
  62. return path;
  63. }
  64. std::string dump_obj(const py::object &obj, const std::string &path) {
  65. py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
  66. py::object name = parse::python_adapter::CallPyModFn(mod, "dump_obj", obj, py::str(path));
  67. return py::str(name);
  68. }
  69. py::object load_obj(const std::string &path) {
  70. py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
  71. py::object obj = parse::python_adapter::CallPyModFn(mod, "load_obj", py::str(path));
  72. return obj;
  73. }
  74. // ============================================= MindSpore IR Exporter =============================================
  75. std::string AnfExporter::GetNodeType(const AnfNodePtr &nd) {
  76. abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape());
  77. TypePtr type = dyn_cast<Type>(nd->Type());
  78. std::ostringstream oss;
  79. if ((nullptr != shape) && (nullptr != type)) {
  80. oss << type->DumpText() << shape->DumpText();
  81. } else if (nullptr != type) {
  82. oss << type->DumpText();
  83. } else {
  84. oss << "Undefined";
  85. }
  86. return oss.str();
  87. }
  88. std::string AnfExporter::DumpObject(const py::object &obj, const std::string &category) const {
  89. std::string pkl_path = GetMsIrPath();
  90. // if not specified env 'MS_IR_PATH', do not create any files
  91. if (pkl_path.empty() || (getenv("MS_IR_FILE") != nullptr)) {
  92. return "null";
  93. }
  94. std::string file_prefix = id_ + "." + category;
  95. std::string file_name = dump_obj(obj, pkl_path + "/" + file_prefix);
  96. return file_prefix + file_name;
  97. }
  98. int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &param, bool throw_excp) {
  99. if (func_graph == nullptr || param == nullptr) {
  100. return -1;
  101. }
  102. FuncGraphPtr fg = func_graph;
  103. while (fg != nullptr) {
  104. if (exported.find(fg) == exported.end()) {
  105. if (!check_integrity_) {
  106. break;
  107. }
  108. MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "." << fg->debug_info()->get_id() << "'";
  109. }
  110. auto param_map = exported[fg];
  111. if (param_map.find(param) != param_map.end()) {
  112. return param_map[param];
  113. }
  114. fg = fg->parent();
  115. }
  116. if (throw_excp) {
  117. MS_LOG(EXCEPTION) << "Can not find index for param '" << param->DumpText() << "' for func graph '"
  118. << func_graph->DumpText() << "." << func_graph->debug_info()->get_id() << "'";
  119. }
  120. return -1;
  121. }
  122. // try to find index of parameter for SymbolicKeyInstance from all exported graphs
  123. // NOTICE: Suppose name of all parameters in SymbolicKeyInstance are different
  124. int AnfExporter::GetParamIndexFromExported(const AnfNodePtr &param) {
  125. if (param == nullptr) {
  126. return -1;
  127. }
  128. int ret = -1;
  129. for (const auto &item : exported) {
  130. auto pram_iter = item.second.find(param);
  131. if (pram_iter != item.second.end()) {
  132. return pram_iter->second;
  133. }
  134. }
  135. return ret;
  136. }
  137. std::string AnfExporter::GetValueNodeText(const FuncGraphPtr &fg, const ValueNodePtr &node) {
  138. MS_EXCEPTION_IF_NULL(node);
  139. return GetValueText(fg, node->value());
  140. }
  141. std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph) {
  142. auto py_funcs = mt_func_graph->GetPyFunctions();
  143. if (py_funcs.empty()) {
  144. return "";
  145. }
  146. std::ostringstream oss;
  147. oss << "{";
  148. bool is_first = true;
  149. for (const auto &py_func : py_funcs) {
  150. if (is_first) {
  151. is_first = false;
  152. } else {
  153. oss << ", ";
  154. }
  155. oss << "(";
  156. for (size_t i = 0; i < py_func.first.size(); ++i) {
  157. if (i > 0) {
  158. oss << ", ";
  159. }
  160. oss << py_func.first[i]->DumpText();
  161. }
  162. oss << ")";
  163. // dump Python Function object
  164. oss << "@" << DumpObject(py_func.second, "F");
  165. }
  166. oss << "}";
  167. return oss.str();
  168. }
  169. /* inherit relation of MetaFuncGraph
  170. *
  171. * MetaGraph
  172. * ├── MultitypeGraph
  173. * ├── HyperMap
  174. * │ └── HyperMapPy
  175. * ├── Map
  176. * │ └── MapPy
  177. * ├── Tail
  178. * ├── MakeTupleGradient
  179. * ├── GradOperation
  180. * └── TupleAdd
  181. */
  182. std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph) {
  183. if (meta_func_graph == nullptr) {
  184. return "";
  185. }
  186. std::ostringstream oss;
  187. oss << meta_func_graph->type_name() << "::" << meta_func_graph->name();
  188. if (meta_func_graph->isa<prim::MultitypeFuncGraph>()) {
  189. prim::MultitypeFuncGraphPtr mt_func_graph = meta_func_graph->cast<prim::MultitypeFuncGraphPtr>();
  190. oss << GetMultitypeFuncGraphText(mt_func_graph);
  191. } else if (meta_func_graph
  192. ->isa<prim::HyperMapPy>()) { // this statement must before 'meta_graph->isa<prim::HyperMap>()'
  193. auto hyper_map = meta_func_graph->cast<prim::HyperMapPyPtr>();
  194. if (hyper_map->GetFnLeaf() != nullptr) {
  195. oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}";
  196. }
  197. } else if (meta_func_graph->isa<prim::HyperMap>()) {
  198. auto hyper_map = meta_func_graph->cast<prim::HyperMapPtr>();
  199. if (hyper_map->GetFnLeaf() != nullptr) {
  200. oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}";
  201. }
  202. } else if (meta_func_graph->isa<prim::MapPy>()) { // this statement must before 'meta_graph->isa<prim::Map>()'
  203. auto map = meta_func_graph->cast<prim::MapPyPtr>();
  204. if (map->GetFnLeaf() != nullptr) {
  205. oss << "{fn_leaf=" << GetMetaFuncGraphText(map->GetFnLeaf()) << "}";
  206. }
  207. } else if (meta_func_graph->isa<prim::Map>()) {
  208. auto map = meta_func_graph->cast<prim::MapPtr>();
  209. if (map->GetFnLeaf() != nullptr) {
  210. oss << "{fn_leaf=" << GetMetaFuncGraphText(map->GetFnLeaf()) << "}";
  211. }
  212. } else if (meta_func_graph->isa<prim::GradOperation>()) {
  213. prim::GradOperationPtr grad_op = meta_func_graph->cast<prim::GradOperationPtr>();
  214. oss << "{get_all=" << grad_op->get_all_ << ", get_by_list=" << grad_op->get_by_list_
  215. << ", sens_param=" << grad_op->sens_param_ << "}";
  216. } else if (meta_func_graph->isa<prim::Tail>()) {
  217. // do nothing
  218. } else if (meta_func_graph->isa<prim::MakeTupleGradient>()) {
  219. // do nothing
  220. } else if (meta_func_graph->isa<prim::TupleAdd>()) {
  221. // do nothing
  222. } else if (meta_func_graph->isa<prim::TupleSlice>()) {
  223. // do nothing
  224. } else if (meta_func_graph->isa<prim::UnpackCall>()) {
  225. // do nothing
  226. } else if (meta_func_graph->isa<prim::ZipOperation>()) {
  227. // do nothing
  228. } else if (meta_func_graph->isa<prim::ListAppend>()) {
  229. // do nothing
  230. } else if (meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>()) {
  231. // do nothing
  232. } else {
  233. MS_LOG(EXCEPTION) << "Unknown MetaFuncGraph type " << meta_func_graph->type_name();
  234. }
  235. return oss.str();
  236. }
  237. std::string AnfExporter::GetPrimitiveText(const PrimitivePtr &prim) {
  238. std::ostringstream oss;
  239. if (prim == nullptr) {
  240. return oss.str();
  241. }
  242. oss << prim->type_name() << "::" << prim->name();
  243. // need to serialize internal python function of PrimitivePy and record its prim_type
  244. if (prim->isa<PrimitivePy>()) {
  245. PrimitivePyPtr primpy = prim->cast<PrimitivePyPtr>();
  246. // dump related function in PrimitivePy
  247. oss << "@" << DumpObject(primpy->GetPyObj(), "P");
  248. // output primitive type
  249. oss << "{prim_type=" << static_cast<int>(prim->prim_type()) << "}";
  250. }
  251. // output primitive attributes
  252. oss << prim->GetAttrsText();
  253. if (prim->isa<prim::DoSignaturePrimitive>()) {
  254. auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim);
  255. auto &func = do_signature->function();
  256. if (func->isa<Primitive>()) {
  257. auto sig_prim = dyn_cast<Primitive>(func);
  258. oss << sig_prim->GetAttrsText();
  259. }
  260. }
  261. return oss.str();
  262. }
  263. std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr &ns) {
  264. std::ostringstream oss;
  265. if (ns == nullptr) {
  266. return oss.str();
  267. }
  268. // dump related module information in Namespace
  269. oss << ns->type_name() << "::" << ns->module() << "@" << DumpObject(ns->obj(), "N");
  270. return oss.str();
  271. }
  272. std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph,
  273. const SymbolicKeyInstancePtr &sym_inst) {
  274. MS_EXCEPTION_IF_NULL(func_graph);
  275. MS_EXCEPTION_IF_NULL(sym_inst);
  276. AnfNodePtr sym_node = sym_inst->node();
  277. MS_EXCEPTION_IF_NULL(sym_node);
  278. std::ostringstream oss;
  279. if (sym_node->isa<Parameter>()) {
  280. int idx = GetParamIndex(func_graph, sym_node, false);
  281. // if can not find SymbolicKeyInstance related parameter from ancestors,
  282. // try to find from all exported graphs
  283. if (idx < 0) {
  284. idx = GetParamIndexFromExported(sym_node);
  285. }
  286. if (idx < 0) {
  287. ParameterPtr p = dyn_cast<Parameter>(sym_node);
  288. if (p == nullptr) {
  289. MS_LOG(EXCEPTION) << "Sym_inst's node could not cast to parameter";
  290. }
  291. MS_LOG(WARNING) << "Can not find SymbolicKeyInstance: " << p->name();
  292. }
  293. oss << "SymInst(%para" << idx << ")";
  294. } else {
  295. MS_LOG(EXCEPTION) << "SymbolicKeyInstance does not embed a parameter: " << sym_node->ToString();
  296. }
  297. return oss.str();
  298. }
  299. std::string AnfExporter::GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
  300. std::ostringstream oss;
  301. // output ValueList, ValueTuple
  302. ValueSequeuePtr seq = dyn_cast<ValueSequeue>(value);
  303. MS_EXCEPTION_IF_NULL(seq);
  304. MS_EXCEPTION_IF_NULL(value);
  305. bool is_tuple = value->isa<ValueTuple>();
  306. oss << (is_tuple ? "(" : "[");
  307. bool first_flag = true;
  308. for (auto elem : seq->value()) {
  309. if (first_flag) {
  310. first_flag = false;
  311. } else {
  312. oss << ", ";
  313. }
  314. oss << GetValueText(func_graph, elem);
  315. }
  316. oss << (is_tuple ? ")" : "]");
  317. return oss.str();
  318. }
  319. std::string AnfExporter::GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
  320. std::ostringstream oss;
  321. ValueDictionaryPtr dict = value->cast<ValueDictionaryPtr>();
  322. oss << "{";
  323. bool first_flag = true;
  324. for (const auto &elem : dict->value()) {
  325. if (first_flag) {
  326. first_flag = false;
  327. } else {
  328. oss << ", ";
  329. }
  330. oss << "\"" << elem.first << "\": " << GetValueText(func_graph, elem.second);
  331. }
  332. oss << "}";
  333. return oss.str();
  334. }
  335. std::string AnfExporter::GetOtherValueText(const FuncGraphPtr &, const ValuePtr &value) {
  336. std::ostringstream oss;
  337. if (check_integrity_) {
  338. MS_LOG(EXCEPTION) << "Need to process type: " << value->type_name() << ", dump text: " << value->DumpText();
  339. }
  340. oss << value->type_name() << "[" << value->DumpText() << "]";
  341. return oss.str();
  342. }
  343. std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
  344. std::ostringstream oss;
  345. bool is_null_ptr = (func_graph == nullptr || value == nullptr);
  346. if (is_null_ptr) {
  347. return oss.str();
  348. }
  349. if (value->isa<Primitive>()) {
  350. oss << GetPrimitiveText(value->cast<PrimitivePtr>());
  351. } else if (value->isa<MetaFuncGraph>()) {
  352. MetaFuncGraphPtr meta_func_graph = value->cast<MetaFuncGraphPtr>();
  353. oss << GetMetaFuncGraphText(meta_func_graph);
  354. } else if (value->isa<SymbolicKeyInstance>()) {
  355. oss << GetSymbolicKeyInstanceText(func_graph, value->cast<SymbolicKeyInstancePtr>());
  356. } else if (value->isa<RefKey>()) {
  357. oss << value->DumpText();
  358. } else if (value->isa<Scalar>() || value->isa<StringImm>()) {
  359. oss << value->DumpText();
  360. } else if (value->isa<tensor::Tensor>()) {
  361. auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
  362. oss << value->DumpText() << "@" << DumpObject(TensorPy::AsNumpy(*tensor_ptr), "T");
  363. } else if (value->isa<parse::Symbol>() || value->isa<None>() || value->isa<Null>()) {
  364. oss << value->DumpText();
  365. } else if (value->isa<ValueSequeue>()) {
  366. oss << GetSequenceText(func_graph, value);
  367. } else if (value->isa<ValueDictionary>()) {
  368. oss << GetDictText(func_graph, value);
  369. } else if (value->isa<ValueSlice>()) {
  370. ValueSlicePtr slice = value->cast<ValueSlicePtr>();
  371. oss << slice->DumpText();
  372. } else if (value->isa<Type>()) {
  373. oss << value->DumpText();
  374. } else if (value->isa<parse::NameSpace>()) {
  375. oss << GetNameSpaceText(value->cast<parse::NameSpacePtr>());
  376. } else if (value->isa<parse::PyObjectWrapper>()) {
  377. oss << value->type_name();
  378. } else if (value->isa<KeywordArg>()) {
  379. KeywordArgPtr keyword_arg = value->cast<KeywordArgPtr>();
  380. oss << keyword_arg->DumpText();
  381. } else {
  382. return GetOtherValueText(func_graph, value);
  383. }
  384. return oss.str();
  385. }
  386. // this function is used to output node in CNode's inputs
  387. std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  388. const std::map<AnfNodePtr, int> &apply_map) {
  389. std::ostringstream oss;
  390. if (func_graph == nullptr || node == nullptr) {
  391. return oss.str();
  392. }
  393. if (node->isa<CNode>()) {
  394. auto iter = apply_map.find(node);
  395. if (iter == apply_map.end()) {
  396. MS_LOG(EXCEPTION) << "Can not find node '" << node->DumpText() << "' in apply_map";
  397. }
  398. oss << "%" << iter->second;
  399. } else if (node->isa<Parameter>()) {
  400. oss << "%para" << GetParamIndex(func_graph, node, check_integrity_);
  401. } else if (IsValueNode<FuncGraph>(node)) {
  402. FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
  403. oss << fg->type_name() << "::fg_" << fg->debug_info()->get_id();
  404. if (!func_graph_set.contains(fg) && exported.find(fg) == exported.end() && export_used_) {
  405. func_graph_set.add(fg);
  406. }
  407. } else if (node->isa<ValueNode>()) {
  408. oss << GetValueNodeText(func_graph, node->cast<ValueNodePtr>());
  409. } else {
  410. MS_LOG(EXCEPTION) << "Unknown node '" << node->DumpText() << "'";
  411. }
  412. return oss.str();
  413. }
  414. void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> &parameters,
  415. OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map) {
  416. bool first_flag = true;
  417. for (const AnfNodePtr &param : parameters) {
  418. if (first_flag) {
  419. first_flag = false;
  420. ofs << " ";
  421. } else {
  422. ofs << " , ";
  423. }
  424. (*param_map)[param] = param_index;
  425. std::string type_info = GetNodeType(param);
  426. // output parameter and type
  427. if (type_info == "Undefined") {
  428. ofs << "%para" << param_index;
  429. } else {
  430. ofs << "%para" << param_index << " : " << type_info;
  431. }
  432. // dump Default value of parameter if exists
  433. const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
  434. if (param_ptr == nullptr) {
  435. MS_LOG(EXCEPTION) << "Param could not cast to parameter";
  436. }
  437. if (param_ptr->has_default()) {
  438. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param());
  439. ofs << " = @" << DumpObject(param_value->value(), "D");
  440. }
  441. // output comment
  442. ofs << " # " << param->DumpText() << "\n";
  443. param_index += 1;
  444. }
  445. }
  446. void AnfExporter::OutputStatementComment(std::ofstream &ofs, const CNodePtr &node) {
  447. if (node == nullptr) {
  448. return;
  449. }
  450. // output type of each input argument
  451. auto &inputs = node->inputs();
  452. if (inputs.size() > 1) {
  453. ofs << " #(";
  454. for (size_t i = 1; i < inputs.size(); ++i) {
  455. if (i != 1) {
  456. ofs << ", ";
  457. }
  458. AnfNodePtr arg = inputs[i];
  459. ofs << GetNodeType(arg);
  460. }
  461. ofs << ")";
  462. }
  463. // output other comment, map the graph name to original representation(containing unicode character)
  464. std::ostringstream comment;
  465. comment << " #";
  466. bool has_comment = false;
  467. for (size_t i = 0; i < inputs.size(); ++i) {
  468. AnfNodePtr arg = inputs[i];
  469. if (!IsValueNode<FuncGraph>(arg)) {
  470. continue;
  471. }
  472. if (!has_comment) {
  473. has_comment = true;
  474. } else {
  475. comment << ",";
  476. }
  477. FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(arg);
  478. std::string func_graph_id = fg->debug_info()->get_id();
  479. comment << " fg_" << func_graph_id << "=" << fg->ToString() << "." << func_graph_id;
  480. }
  481. if (has_comment) {
  482. ofs << comment.str();
  483. }
  484. ofs << " #scope: " << node->scope()->name();
  485. }
  486. void AnfExporter::OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes,
  487. const FuncGraphPtr &func_graph) {
  488. if (func_graph == nullptr) {
  489. return;
  490. }
  491. int idx = 1;
  492. std::map<AnfNodePtr, int> apply_map;
  493. for (const AnfNodePtr &node : nodes) {
  494. MS_EXCEPTION_IF_NULL(node);
  495. if (!node->isa<CNode>()) {
  496. continue;
  497. }
  498. auto iter = tagged_cnodes_.find(node);
  499. if (iter != tagged_cnodes_.end()) {
  500. ofs << "\n#------------------------> " << iter->second << "\n";
  501. }
  502. auto cnode = node->cast<CNodePtr>();
  503. auto &inputs = cnode->inputs();
  504. std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map);
  505. // non-return node
  506. if (node != func_graph->get_return()) {
  507. int apply_idx = idx++;
  508. apply_map[node] = apply_idx;
  509. std::string type_info = GetNodeType(node);
  510. if (type_info == "Undefined") {
  511. ofs << " %" << apply_idx << " = " << op_text << "(";
  512. } else {
  513. ofs << " %" << apply_idx << " : " << type_info << " = " << op_text << "(";
  514. }
  515. } else {
  516. ofs << " " << op_text << "(";
  517. }
  518. for (size_t i = 1; i < inputs.size(); ++i) {
  519. if (i != 1) {
  520. ofs << ", ";
  521. }
  522. AnfNodePtr arg = inputs[i];
  523. ofs << GetAnfNodeText(func_graph, arg, apply_map);
  524. }
  525. ofs << ")";
  526. // output comment
  527. OutputStatementComment(ofs, cnode);
  528. ofs << "\n";
  529. if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
  530. ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "#"
  531. << label_manage::Label(cnode->debug_info()) << "\n";
  532. } else {
  533. ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "\n";
  534. }
  535. }
  536. }
  537. void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph) {
  538. if (func_graph == nullptr) {
  539. return;
  540. }
  541. std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
  542. std::vector<AnfNodePtr> parameters = func_graph->parameters();
  543. OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map;
  544. ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "."
  545. << func_graph->debug_info()->get_id() << "\n";
  546. if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
  547. ofs << trace::GetDebugInfo(func_graph->debug_info(), "# ", kSourceLineTipDiscard) << "#"
  548. << label_manage::Label(func_graph->debug_info()) << "\n";
  549. } else {
  550. ofs << trace::GetDebugInfo(func_graph->debug_info(), "# ", kSourceLineTipDiscard) << "\n";
  551. }
  552. ofs << "funcgraph fg_" << func_graph->debug_info()->get_id();
  553. // output name of parent of graph if exists
  554. if (func_graph->parent() != nullptr) {
  555. ofs << "[fg_" << func_graph->parent()->debug_info()->get_id() << "]";
  556. }
  557. ofs << "(\n";
  558. OutputParameters(ofs, parameters, &param_map);
  559. exported[func_graph] = param_map;
  560. ofs << (!parameters.empty() ? " " : "") << ") {\n";
  561. OutputCNodes(ofs, nodes, func_graph);
  562. ofs << "}\n";
  563. }
  564. void AnfExporter::ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) {
  565. if (func_graph == nullptr) {
  566. return;
  567. }
  568. std::ofstream ofs(filename);
  569. if (!ofs.is_open()) {
  570. MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
  571. return;
  572. }
  573. param_index = 1;
  574. func_graph_set.add(func_graph);
  575. while (!func_graph_set.empty()) {
  576. FuncGraphPtr fg = *func_graph_set.begin();
  577. ExportOneFuncGraph(ofs, fg);
  578. ofs << "\n\n";
  579. (void)func_graph_set.erase(fg);
  580. }
  581. ofs << "# num of total function graphs: " << exported.size();
  582. ofs.close();
  583. }
  584. void AnfExporter::ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
  585. if (graphs.empty()) {
  586. return;
  587. }
  588. std::ofstream ofs(filename);
  589. if (!ofs.is_open()) {
  590. MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
  591. return;
  592. }
  593. param_index = 1;
  594. for (const auto &tagged_graph : graphs) {
  595. tagged_cnodes_ = tagged_graph.second;
  596. ExportOneFuncGraph(ofs, tagged_graph.first);
  597. tagged_cnodes_.clear();
  598. ofs << "\n\n";
  599. }
  600. ofs << "# num of total function graphs: " << graphs.size();
  601. ofs.close();
  602. }
  603. #ifdef ENABLE_DUMP_IR
  604. void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph) {
  605. if (func_graph == nullptr) {
  606. return;
  607. }
  608. AnfExporter exporter(id);
  609. ChangeFileMode(filename, S_IRWXU);
  610. exporter.ExportFuncGraph(filename, func_graph);
  611. // set file mode to read only by user
  612. ChangeFileMode(filename, S_IRUSR);
  613. }
  614. void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
  615. AnfExporter exporter("", false);
  616. ChangeFileMode(filename, S_IRWXU);
  617. exporter.ExportFuncGraph(filename, graphs);
  618. // set file mode to read only by user
  619. ChangeFileMode(filename, S_IRUSR);
  620. }
  621. #else
  622. void ExportIR(const std::string &, const std::string &, const FuncGraphPtr &) {
  623. static bool already_printed = false;
  624. if (already_printed) {
  625. return;
  626. }
  627. already_printed = true;
  628. MS_LOG(WARNING) << "The functionality of dumping function graph IR is disabled, "
  629. << "please recompile source to enable it. See help of building script.";
  630. }
  631. void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
  632. static bool already_printed = false;
  633. if (already_printed) {
  634. return;
  635. }
  636. already_printed = true;
  637. MS_LOG(WARNING) << "The functionality of dumping function graph IR is disabled, "
  638. << "please recompile source to enable it. See help of building script.";
  639. }
  640. #endif
  641. // ============================================= MindSpore IR Importer =============================================
  642. enum Token : int {
  643. TOK_INVALID = 0, // invalid token
  644. TOK_LPARENTHESIS, // ( left parenthesis
  645. TOK_RPARENTHESIS, // ) right parenthesis
  646. TOK_LBRACKET, // [ left bracket
  647. TOK_RBRACKET, // ] right bracket
  648. TOK_LBRACE, // { left brace
  649. TOK_RBRACE, // } right brace
  650. TOK_COMMA, // , comma
  651. TOK_EQUALITY, // = equality
  652. TOK_COLON, // : colon
  653. TOK_STAR, // * star
  654. TOK_VARIABLE, // variable
  655. TOK_AT_FILE, // @filename
  656. TOK_PARAMETER, // parameter
  657. TOK_IDENTIFIER, // identifier
  658. TOK_FUNCGRAPH, // keyword 'funcgraph'
  659. TOK_RETURN, // id prim::return
  660. TOK_STRING, // string
  661. TOK_NUMBER, // number
  662. TOK_COMMENT, // comment
  663. TOK_EOL, // end of line
  664. TOK_EOF, // end of file
  665. TOK_ERROR // file read error
  666. };
  667. std::map<Token, const char *> token_text = {
  668. {TOK_INVALID, "invalid"}, // invalid token
  669. {TOK_LPARENTHESIS, "("}, // ( left parenthesis
  670. {TOK_RPARENTHESIS, ")"}, // ) right parenthesis
  671. {TOK_LBRACKET, "["}, // [ left bracket
  672. {TOK_RBRACKET, "]"}, // ] right bracket
  673. {TOK_LBRACE, "{"}, // { left brace
  674. {TOK_RBRACE, "}"}, // } right brace
  675. {TOK_COMMA, ","}, // , comma
  676. {TOK_EQUALITY, "="}, // = equality
  677. {TOK_COLON, ":"}, // : colon
  678. {TOK_STAR, "*"}, // * start
  679. {TOK_VARIABLE, nullptr}, // variable
  680. {TOK_AT_FILE, nullptr}, // @file
  681. {TOK_PARAMETER, nullptr}, // parameter
  682. {TOK_IDENTIFIER, nullptr}, // identifier
  683. {TOK_FUNCGRAPH, "funcgraph"}, // keyword 'funcgraph'
  684. {TOK_RETURN, nullptr}, // id prim::return
  685. {TOK_STRING, nullptr}, // string
  686. {TOK_NUMBER, nullptr}, // number
  687. {TOK_COMMENT, nullptr}, // comment
  688. {TOK_EOL, "\n"}, // end of line
  689. {TOK_EOF, ""}, // end of file
  690. {TOK_ERROR, "error"} // file read error
  691. };
  692. class Lexer {
  693. public:
  694. // filename is checked in ImportIR;
  695. explicit Lexer(const char *filename) : fin(filename) {}
  696. ~Lexer() {
  697. try {
  698. if (fin.is_open()) {
  699. fin.close();
  700. }
  701. } catch (const std::exception &e) {
  702. MS_LOG(ERROR) << "Exception when closing file";
  703. } catch (...) {
  704. std::string exName(abi::__cxa_current_exception_type()->name());
  705. MS_LOG(ERROR) << "Error occurred when closing file. Exception name: " << exName;
  706. }
  707. }
  708. bool IsSingleCharToken(char ch, Token *token_ptr) {
  709. // clang-format off
  710. std::unordered_map<char, Token> char_to_token = {
  711. {'(', TOK_LPARENTHESIS},
  712. {')', TOK_RPARENTHESIS},
  713. {'[', TOK_LBRACKET},
  714. {']', TOK_RBRACKET},
  715. {'{', TOK_LBRACE},
  716. {'}', TOK_RBRACE},
  717. {',', TOK_COMMA},
  718. {'=', TOK_EQUALITY},
  719. {':', TOK_COLON},
  720. {'*', TOK_STAR}};
  721. // clang-format on
  722. auto iter = char_to_token.find(ch);
  723. if (iter == char_to_token.end()) {
  724. return false;
  725. }
  726. if (token_ptr != nullptr) {
  727. *token_ptr = iter->second;
  728. }
  729. return true;
  730. }
  731. Token GetNextToken() {
  732. #ifdef DEBUG
  733. Token token = GetNextTokenInner();
  734. const char *str = token_text[token];
  735. std::string text = (str == nullptr ? GetTokenText() : str);
  736. MS_LOG(DEBUG) << "------Parse token] " << text;
  737. return token;
  738. }
  739. Token GetNextTokenInner() {
  740. #endif
  741. tok_idx = 0;
  742. Token tok = TOK_ERROR;
  743. char ch = SkipTabAndSpace();
  744. if (ch == CODE_EOF) {
  745. return TOK_EOF;
  746. } else if (ch == CODE_ERROR) {
  747. return TOK_ERROR;
  748. } else if (IsSingleCharToken(ch, &tok)) {
  749. return tok;
  750. } else if (ch == '\r') {
  751. char c = GetChar();
  752. if (c == '\n') {
  753. line_++;
  754. return TOK_EOL;
  755. }
  756. UnGetChar(c);
  757. line_++;
  758. return TOK_EOL;
  759. } else if (ch == '\n') {
  760. line_++;
  761. return TOK_EOL;
  762. } else if (ch == '#') {
  763. return ParseComment(ch);
  764. } else if (ch == '"') {
  765. return ParseString();
  766. } else if (ch == '%') {
  767. return ParseVariableOrParameter(ch);
  768. } else if (ch == '@') {
  769. return ParseAtFile();
  770. } else if (IsDigit(ch) || ch == '-') {
  771. return ParseNumber(ch);
  772. } else if (IsAlpha(ch) || ch == '_') {
  773. return ParseIdentifier(ch);
  774. } else {
  775. return TOK_ERROR;
  776. }
  777. }
  778. Token SkipWhiteToken() {
  779. Token tok = GetNextToken();
  780. while (tok == TOK_EOL || tok == TOK_COMMENT) {
  781. tok = GetNextToken();
  782. }
  783. return tok;
  784. }
  785. std::string GetTokenText() const { return std::string(tok_buf); }
  786. int GetLineNo() const { return line_; }
  787. private:
  788. Token ParseComment(char ch) {
  789. char c = GetChar();
  790. while (c != '\r' && c != '\n' && c != CODE_EOF) {
  791. c = GetChar();
  792. }
  793. if (ch != CODE_EOF) {
  794. UnGetChar(c);
  795. }
  796. tok_buf[0] = '#';
  797. tok_buf[1] = '\0';
  798. return TOK_COMMENT;
  799. }
  800. Token ParseString() {
  801. tok_idx = 0;
  802. char c = GetChar();
  803. while (c != '"') {
  804. if (tok_idx >= BUF_SIZE) {
  805. MS_LOG(EXCEPTION) << "Length of token which is " << tok_idx << " exceeds " << BUF_SIZE;
  806. }
  807. if (c == '\r' || c == '\n') {
  808. MS_LOG(EXCEPTION) << "Literal newline characters are not allowed within the quote at line " << line_;
  809. }
  810. if (c == CODE_EOF) {
  811. MS_LOG(EXCEPTION) << "Encounter EOF within the quote at line " << line_;
  812. }
  813. tok_buf[tok_idx++] = c;
  814. c = GetChar();
  815. }
  816. tok_buf[tok_idx] = '\0';
  817. return TOK_STRING;
  818. }
  819. Token ParseVariableOrParameter(char ch) {
  820. tok_idx = 0;
  821. tok_buf[tok_idx++] = ch;
  822. char c = GetChar();
  823. while (IsAlphaNumeric(c)) {
  824. if (tok_idx >= BUF_SIZE) {
  825. MS_LOG(EXCEPTION) << "Length of token which is " << tok_idx << " exceeds " << BUF_SIZE;
  826. }
  827. tok_buf[tok_idx++] = c;
  828. c = GetChar();
  829. }
  830. tok_buf[tok_idx] = '\0';
  831. UnGetChar(c);
  832. // judge parameter: %para[0-9]+
  833. tok_buf[tok_idx] = '\0';
  834. std::string param_key = "%para";
  835. if (strncmp(tok_buf, param_key.c_str(), param_key.size()) == 0) {
  836. if (tok_idx <= param_key.size()) {
  837. return TOK_ERROR;
  838. }
  839. for (auto i = static_cast<unsigned>(param_key.size()); i < tok_idx; ++i) {
  840. if (!IsDigit(tok_buf[i])) {
  841. return TOK_ERROR;
  842. }
  843. }
  844. return TOK_PARAMETER;
  845. }
  846. // judge local variable: %[0-9]+
  847. if (tok_idx == 1) {
  848. return TOK_ERROR;
  849. }
  850. for (unsigned i = 1; i < tok_idx; ++i) {
  851. if (!IsDigit(tok_buf[i])) {
  852. return TOK_ERROR;
  853. }
  854. }
  855. return TOK_VARIABLE;
  856. }
  857. Token ParseAtFile() {
  858. tok_idx = 0;
  859. char c = GetChar();
  860. while (IsAlphaNumeric(c) || c == '_' || c == '.') {
  861. if (tok_idx >= BUF_SIZE) {
  862. MS_LOG(EXCEPTION) << "Length of token which is " << tok_idx << " exceeds " << BUF_SIZE;
  863. }
  864. tok_buf[tok_idx++] = c;
  865. c = GetChar();
  866. }
  867. tok_buf[tok_idx] = '\0';
  868. UnGetChar(c);
  869. if (tok_idx == 0) {
  870. return TOK_ERROR;
  871. }
  872. return TOK_AT_FILE;
  873. }
  874. Token ParseNumber(char ch) {
  875. tok_buf[tok_idx++] = ch;
  876. char c = GetChar();
  877. // parse number, e.g. 10, 15.6, 1e-5
  878. while (IsDigit(c) || c == '.' || c == 'e' || c == '-') {
  879. if (tok_idx >= BUF_SIZE) {
  880. MS_LOG(EXCEPTION) << "Length of token which is " << tok_idx << " exceeds " << BUF_SIZE;
  881. }
  882. tok_buf[tok_idx++] = c;
  883. c = GetChar();
  884. }
  885. UnGetChar(c);
  886. tok_buf[tok_idx] = '\0';
  887. return TOK_NUMBER;
  888. }
  889. Token ParseIdentifier(char ch) {
  890. tok_idx = 0;
  891. tok_buf[tok_idx++] = ch;
  892. char c = GetChar();
  893. while (IsAlphaNumeric(c) || c == '.' || c == ':' || c == '_') {
  894. if (tok_idx >= BUF_SIZE) {
  895. MS_LOG(EXCEPTION) << "Length of token which is " << tok_idx << " exceeds " << BUF_SIZE;
  896. }
  897. tok_buf[tok_idx++] = c;
  898. c = GetChar();
  899. }
  900. UnGetChar(c);
  901. tok_buf[tok_idx] = '\0';
  902. if (strcmp(tok_buf, "funcgraph") == 0) {
  903. return TOK_FUNCGRAPH;
  904. }
  905. if (strcmp(tok_buf, "Primitive::return") == 0) {
  906. return TOK_RETURN;
  907. }
  908. return TOK_IDENTIFIER;
  909. }
  910. // Suppose the file only contain ASCII character
  911. char GetChar() {
  912. if (ungot_char != UNGOT_CHAR) {
  913. char ch = ungot_char;
  914. ungot_char = UNGOT_CHAR;
  915. return ch;
  916. }
  917. if (idx >= cnt) {
  918. if (fin.eof()) {
  919. return CODE_EOF;
  920. }
  921. cnt = fin.read(buffer, BUF_SIZE).gcount();
  922. if ((fin.bad() || fin.fail()) && !fin.eof()) {
  923. MS_LOG(EXCEPTION) << "Read file error!";
  924. }
  925. idx = 0;
  926. }
  927. return buffer[idx++];
  928. }
  929. void UnGetChar(char ch) {
  930. if (ungot_char == UNGOT_CHAR) {
  931. ungot_char = ch;
  932. }
  933. }
  934. static bool IsTabOrSpace(char ch) { return ch == ' ' || ch == '\t'; }
  935. static bool IsDigit(char ch) { return ch >= '0' && ch <= '9'; }
  936. static bool IsAlpha(char ch) { return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z'); }
  937. static bool IsAlphaNumeric(char ch) { return IsDigit(ch) || IsAlpha(ch); }
  938. // skip whitespace(including comment) to read a valid character
  939. char SkipTabAndSpace() {
  940. char ch = GetChar();
  941. while (IsTabOrSpace(ch)) {
  942. ch = GetChar();
  943. }
  944. return ch;
  945. }
  946. std::ifstream fin;
  947. static const unsigned BUF_SIZE = 4096; // lexer buffer size
  948. char buffer[BUF_SIZE + 1] = {0}; // buffer for holding text read from text
  949. std::streamsize cnt = 0; // number of valid characters in the buffer
  950. unsigned idx = 0; // index of next char the lexer to read from
  951. char tok_buf[BUF_SIZE + 1] = {0}; // token buffer
  952. unsigned tok_idx = 0; // token buffer index
  953. char ungot_char = UNGOT_CHAR; // store ungot char
  954. static const int CODE_EOF = -1; // return code of GetChar
  955. static const int CODE_ERROR = -2; // read file error
  956. static const char UNGOT_CHAR = -3; // value of ungot char
  957. int line_ = 1; // current line number
  958. };
  959. const unsigned Lexer::BUF_SIZE;
  960. class IrParser {
  961. public:
  962. explicit IrParser(const char *filename) : lexer_(filename) {}
  963. ~IrParser() {}
  964. py::object LoadObject(const std::string &file_name) const {
  965. std::string pkl_path = GetMsIrPath();
  966. py::object default_obj = load_obj(pkl_path + "/" + file_name);
  967. return default_obj;
  968. }
  969. void ParseFile() {
  970. FuncGraphPtr func_graph = ParseFuncGraph();
  971. while (func_graph != nullptr) {
  972. func_graphs_.push_back(func_graph);
  973. func_graph = ParseFuncGraph();
  974. }
  975. if (error_flag_) {
  976. MS_LOG(EXCEPTION) << "Parse Error at line: " << lexer_.GetLineNo();
  977. }
  978. MS_LOG(INFO) << "Total graphs: " << func_graphs_.size();
  979. }
  980. Token ParseParent(FuncGraphPtr *const parent_ptr) {
  981. if (lexer_.GetNextToken() != TOK_IDENTIFIER) {
  982. return TOK_ERROR;
  983. }
  984. std::string parent_name = lexer_.GetTokenText();
  985. // NOTICE: require definition of parent graph must before child graph
  986. auto iter = func_graphs_map_.find(parent_name);
  987. if (iter == func_graphs_map_.end()) {
  988. MS_LOG(EXCEPTION) << "Can not find definition of parent func graph '" << parent_name << "' at line "
  989. << lexer_.GetLineNo();
  990. }
  991. if (parent_ptr != nullptr) {
  992. *parent_ptr = iter->second;
  993. }
  994. if (lexer_.GetNextToken() != TOK_RBRACKET) {
  995. return TOK_ERROR;
  996. }
  997. return lexer_.GetNextToken();
  998. }
  999. FuncGraphPtr ParseFuncGraph() {
  1000. cnodes_.clear();
  1001. Token tok = lexer_.SkipWhiteToken();
  1002. if (tok != TOK_FUNCGRAPH) {
  1003. error_flag_ = tok != TOK_EOF;
  1004. return nullptr;
  1005. }
  1006. if (lexer_.GetNextToken() != TOK_IDENTIFIER) {
  1007. error_flag_ = true;
  1008. return nullptr;
  1009. }
  1010. std::string func_graph_name = lexer_.GetTokenText();
  1011. if (func_graphs_map_.find(func_graph_name) == func_graphs_map_.end()) {
  1012. func_graphs_map_[func_graph_name] = std::make_shared<FuncGraph>();
  1013. }
  1014. FuncGraphPtr func_graph = func_graphs_map_[func_graph_name];
  1015. MS_EXCEPTION_IF_NULL(func_graph);
  1016. MS_EXCEPTION_IF_NULL(func_graph->debug_info());
  1017. func_graph->debug_info()->set_name(func_graph_name); // for debugging
  1018. FuncGraphPtr parent = nullptr;
  1019. tok = lexer_.GetNextToken();
  1020. if (tok == TOK_LBRACKET) {
  1021. tok = ParseParent(&parent);
  1022. if (parent != nullptr) {
  1023. parents_map_[func_graph] = parent;
  1024. }
  1025. }
  1026. if (tok != TOK_LPARENTHESIS) {
  1027. error_flag_ = true;
  1028. return nullptr;
  1029. }
  1030. if (ParseParameters(func_graph) == nullptr) {
  1031. error_flag_ = true;
  1032. return nullptr;
  1033. }
  1034. if (lexer_.SkipWhiteToken() != TOK_LBRACE) {
  1035. error_flag_ = true;
  1036. return nullptr;
  1037. }
  1038. // parse statements
  1039. if (ParseStatements(func_graph) == nullptr) {
  1040. error_flag_ = true;
  1041. return nullptr;
  1042. }
  1043. func_graphs_map_[func_graph_name] = func_graph;
  1044. return func_graph;
  1045. }
  1046. FuncGraphPtr ParseStatements(const FuncGraphPtr &func_graph) {
  1047. Token tok = lexer_.SkipWhiteToken();
  1048. while (tok == TOK_VARIABLE) {
  1049. if (ParseStatement(func_graph) == nullptr) {
  1050. return nullptr;
  1051. }
  1052. tok = lexer_.SkipWhiteToken();
  1053. }
  1054. if (tok == TOK_RETURN) {
  1055. return ParseReturn(func_graph);
  1056. }
  1057. return nullptr;
  1058. }
  1059. FuncGraphPtr ParseStatement(FuncGraphPtr func_graph) {
  1060. std::string var_name = lexer_.GetTokenText();
  1061. Token tok = lexer_.GetNextToken();
  1062. AbstractBasePtr type = nullptr;
  1063. if (tok == TOK_COLON) {
  1064. tok = ParseType(func_graph, &type);
  1065. }
  1066. if (tok != TOK_EQUALITY) {
  1067. return nullptr;
  1068. }
  1069. std::vector<AnfNodePtr> inputs;
  1070. AnfNodePtr node = nullptr;
  1071. ValuePtr val = nullptr;
  1072. tok = ParseItem(func_graph, &node, &val);
  1073. if (tok != TOK_LPARENTHESIS) {
  1074. return nullptr;
  1075. }
  1076. inputs.push_back(node);
  1077. int lineno = lexer_.GetLineNo();
  1078. if (ParseArguments(func_graph, &inputs) == nullptr) {
  1079. return nullptr;
  1080. }
  1081. tok = lexer_.GetNextToken();
  1082. if (tok == TOK_COMMENT) {
  1083. tok = lexer_.GetNextToken();
  1084. }
  1085. if (tok != TOK_EOL) {
  1086. return nullptr;
  1087. }
  1088. MS_EXCEPTION_IF_NULL(func_graph);
  1089. cnodes_[var_name] = func_graph->NewCNode(inputs);
  1090. MS_EXCEPTION_IF_NULL(cnodes_[var_name]);
  1091. cnodes_[var_name]->set_debug_info(std::make_shared<NodeDebugInfo>(var_name + "@" + std::to_string(lineno)));
  1092. return func_graph;
  1093. }
  1094. FuncGraphPtr ParseReturn(FuncGraphPtr func_graph) {
  1095. if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
  1096. return nullptr;
  1097. }
  1098. AnfNodePtr input1 = nullptr;
  1099. ValuePtr value = nullptr;
  1100. Token tok = ParseItem(func_graph, &input1, &value, lexer_.GetNextToken());
  1101. int lineno = lexer_.GetLineNo();
  1102. if (tok != TOK_RPARENTHESIS) {
  1103. return nullptr;
  1104. }
  1105. tok = lexer_.GetNextToken();
  1106. if (tok == TOK_COMMENT) {
  1107. tok = lexer_.GetNextToken();
  1108. }
  1109. if (tok != TOK_EOL) {
  1110. return nullptr;
  1111. }
  1112. if (lexer_.SkipWhiteToken() != TOK_RBRACE) {
  1113. return nullptr;
  1114. }
  1115. PrimitivePtr prim = std::make_shared<Primitive>("return");
  1116. ValueNodePtr input0 = std::make_shared<ValueNode>(prim);
  1117. std::vector<AnfNodePtr> inputs;
  1118. inputs.push_back(input0);
  1119. inputs.push_back(input1);
  1120. MS_EXCEPTION_IF_NULL(func_graph);
  1121. CNodePtr ret = func_graph->NewCNode(inputs);
  1122. MS_EXCEPTION_IF_NULL(ret);
  1123. ret->set_debug_info(std::make_shared<NodeDebugInfo>(std::string("ret@") + std::to_string(lineno)));
  1124. func_graph->set_return(ret);
  1125. return func_graph;
  1126. }
  1127. void SetBasicType(TypePtr *ptr, const TypePtr &dtype) const {
  1128. if (ptr == nullptr) {
  1129. return;
  1130. }
  1131. *ptr = dtype;
  1132. }
  1133. void SetTupleType(TypePtr *ptr) {
  1134. if (ptr == nullptr) {
  1135. return;
  1136. }
  1137. *ptr = std::make_shared<Tuple>();
  1138. }
  1139. void SetTupleType(TypePtr *ptr, const TypePtrList &elems) {
  1140. if (ptr == nullptr) {
  1141. return;
  1142. }
  1143. *ptr = std::make_shared<Tuple>(elems);
  1144. }
  1145. void SetArrayType(TypePtr *const ptr, const TypePtr &elem_type, const std::vector<int> &) {
  1146. if (ptr == nullptr) {
  1147. return;
  1148. }
  1149. *ptr = std::make_shared<TensorType>(elem_type);
  1150. }
  1151. void SetListType(TypePtr *ptr) {
  1152. if (ptr == nullptr) {
  1153. return;
  1154. }
  1155. *ptr = std::make_shared<List>();
  1156. }
  1157. void SetListType(TypePtr *ptr, const TypePtrList &elems) {
  1158. if (ptr == nullptr) {
  1159. return;
  1160. }
  1161. *ptr = std::make_shared<List>(elems);
  1162. }
  1163. void SetJTaggedType(TypePtr *ptr, const TypePtr &elem) {
  1164. if (ptr == nullptr) {
  1165. return;
  1166. }
  1167. *ptr = std::make_shared<JTagged>(elem);
  1168. }
  1169. void SetBasicType(AbstractBasePtr *ptr, const TypePtr &dtype) const {
  1170. if (ptr == nullptr) {
  1171. return;
  1172. }
  1173. *ptr = std::make_shared<abstract::AbstractScalar>(dtype);
  1174. }
  1175. // void SetBasicType(AbstractBasePtr *ptr, const SymbolicKeyTypePtr& dtype) {}
  1176. void SetBasicType(AbstractBasePtr *const ptr, const TypeNonePtr &) const {
  1177. if (ptr == nullptr) {
  1178. return;
  1179. }
  1180. *ptr = std::make_shared<abstract::AbstractNone>();
  1181. }
  1182. void SetBasicType(AbstractBasePtr *, const FunctionPtr &) const {}
  1183. void SetBasicType(AbstractBasePtr *, const TensorTypePtr &) const {}
  1184. void SetTupleType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) {
  1185. if (ptr == nullptr) {
  1186. return;
  1187. }
  1188. // if one of elems is nullptr, just return
  1189. if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) {
  1190. return;
  1191. }
  1192. *ptr = std::make_shared<abstract::AbstractTuple>(elems);
  1193. }
  1194. void SetArrayType(AbstractBasePtr *const ptr, const TypePtr &elem_type, const std::vector<int> &shape) {
  1195. if (ptr == nullptr) {
  1196. return;
  1197. }
  1198. *ptr = std::make_shared<abstract::AbstractTensor>(elem_type, shape);
  1199. }
  1200. void SetListType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) {
  1201. if (ptr == nullptr) {
  1202. return;
  1203. }
  1204. if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) {
  1205. return;
  1206. }
  1207. *ptr = std::make_shared<abstract::AbstractList>(elems);
  1208. }
  1209. void SetJTaggedType(AbstractBasePtr *const ptr, const AbstractBasePtr &elem) {
  1210. if (ptr == nullptr) {
  1211. return;
  1212. }
  1213. *ptr = std::make_shared<abstract::AbstractJTagged>(elem);
  1214. }
  1215. template <typename T>
  1216. Token ParseTypeVector(const FuncGraphPtr &func_graph, Token tok, const std::string &type, T *const ptr = nullptr) {
  1217. if (tok != TOK_LBRACKET) {
  1218. MS_LOG(EXCEPTION) << "Illegal case, , wrong token start symbol.";
  1219. return tok;
  1220. }
  1221. bool first_flag = true;
  1222. std::vector<T> elems;
  1223. do {
  1224. tok = lexer_.GetNextToken();
  1225. if (first_flag) {
  1226. if (tok == TOK_RBRACKET) {
  1227. return lexer_.GetNextToken();
  1228. }
  1229. first_flag = false;
  1230. }
  1231. T elem = nullptr;
  1232. tok = ParseOneType(func_graph, tok, &elem);
  1233. elems.push_back(elem);
  1234. if (tok == TOK_STAR) {
  1235. if (lexer_.GetNextToken() != TOK_NUMBER) {
  1236. return TOK_ERROR;
  1237. }
  1238. int num_elems = StringToScalar<int>(lexer_.GetTokenText());
  1239. if (num_elems < 1 || num_elems > NUM_MAX_SEQUENCE_ELEMS) {
  1240. MS_LOG(EXCEPTION) << "Number of elements " << num_elems << " is out of range [1, " << NUM_MAX_SEQUENCE_ELEMS
  1241. << "]";
  1242. }
  1243. for (int i = 0; i < num_elems - 1; ++i) {
  1244. elems.push_back(elem);
  1245. }
  1246. tok = lexer_.GetNextToken();
  1247. }
  1248. } while (tok == TOK_COMMA);
  1249. if (tok != TOK_RBRACKET) {
  1250. return TOK_ERROR;
  1251. }
  1252. if (type == "Tuple") {
  1253. SetTupleType(ptr, elems);
  1254. } else if (type == "List") {
  1255. SetListType(ptr, elems);
  1256. } else {
  1257. MS_LOG(EXCEPTION) << "This method does not support " << type << " parse.";
  1258. }
  1259. return lexer_.GetNextToken();
  1260. }
  1261. template <typename T>
  1262. Token ParseTypeArray(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) {
  1263. if (tok != TOK_LPARENTHESIS) {
  1264. if (ptr != nullptr) {
  1265. SetBasicType(ptr, std::make_shared<TensorType>());
  1266. }
  1267. return tok;
  1268. }
  1269. // process Array element type
  1270. TypePtr elem_type = nullptr;
  1271. std::vector<int> shape;
  1272. tok = ParseOneType(func_graph, lexer_.GetNextToken(), &elem_type);
  1273. if (tok != TOK_RPARENTHESIS) {
  1274. return TOK_ERROR;
  1275. }
  1276. tok = lexer_.GetNextToken();
  1277. if (tok != TOK_LBRACKET) {
  1278. // NOTICE: if shape.size == 0, is this ok?
  1279. SetArrayType(ptr, elem_type, shape);
  1280. return tok;
  1281. }
  1282. // process Array shape
  1283. do {
  1284. tok = lexer_.GetNextToken();
  1285. // case: Array(I32)[]
  1286. if (tok != TOK_NUMBER) {
  1287. break;
  1288. }
  1289. shape.push_back(StringToScalar<int>(lexer_.GetTokenText()));
  1290. tok = lexer_.GetNextToken();
  1291. } while (tok == TOK_COMMA);
  1292. if (tok != TOK_RBRACKET) {
  1293. return TOK_ERROR;
  1294. }
  1295. SetArrayType(ptr, elem_type, shape);
  1296. return lexer_.GetNextToken();
  1297. }
  1298. bool IsNumberType(const std::string &type, TypeId *typeid_ptr) {
  1299. // clang-format off
  1300. static std::unordered_map<std::string, TypeId> basic_types = {
  1301. {"Bool", kNumberTypeBool},
  1302. {"I8", kNumberTypeInt8},
  1303. {"I16", kNumberTypeInt16},
  1304. {"I32", kNumberTypeInt32},
  1305. {"I64", kNumberTypeInt64},
  1306. {"U8", kNumberTypeUInt8},
  1307. {"U16", kNumberTypeUInt16},
  1308. {"U32", kNumberTypeUInt32},
  1309. {"U64", kNumberTypeUInt64},
  1310. {"F16", kNumberTypeFloat16},
  1311. {"F32", kNumberTypeFloat32},
  1312. {"F64", kNumberTypeFloat64},
  1313. {"Int", kNumberTypeInt},
  1314. {"UInt", kNumberTypeUInt},
  1315. {"Float", kNumberTypeFloat},
  1316. {"Number", kObjectTypeNumber}};
  1317. // clang-format on
  1318. auto iter = basic_types.find(type);
  1319. if (iter == basic_types.end()) {
  1320. return false;
  1321. }
  1322. if (typeid_ptr != nullptr) {
  1323. *typeid_ptr = iter->second;
  1324. }
  1325. return true;
  1326. }
  1327. template <typename T>
  1328. void ParseNumberType(const std::string &type, TypeId typeId, T *const ptr = nullptr) {
  1329. TypePtr dtype = nullptr;
  1330. std::unordered_map<int, TypePtr> type_map = {
  1331. {static_cast<int>(kNumberTypeBool), std::make_shared<Bool>()}, // Bool
  1332. {static_cast<int>(kNumberTypeInt8), std::make_shared<Int>(8)}, // Int8
  1333. {static_cast<int>(kNumberTypeInt16), std::make_shared<Int>(16)}, // Int16
  1334. {static_cast<int>(kNumberTypeInt32), std::make_shared<Int>(32)}, // Int32
  1335. {static_cast<int>(kNumberTypeInt64), std::make_shared<Int>(64)}, // Int64
  1336. {static_cast<int>(kNumberTypeUInt8), std::make_shared<UInt>(8)}, // UInt8
  1337. {static_cast<int>(kNumberTypeUInt16), std::make_shared<UInt>(16)}, // UInt16
  1338. {static_cast<int>(kNumberTypeUInt32), std::make_shared<UInt>(32)}, // UInt32
  1339. {static_cast<int>(kNumberTypeUInt64), std::make_shared<UInt>(64)}, // UInt64
  1340. {static_cast<int>(kNumberTypeFloat16), std::make_shared<Float>(16)}, // Float16
  1341. {static_cast<int>(kNumberTypeFloat32), std::make_shared<Float>(32)}, // Float32
  1342. {static_cast<int>(kNumberTypeFloat64), std::make_shared<Float>(64)}, // Float64
  1343. {static_cast<int>(kNumberTypeInt), std::make_shared<Int>()}, // Int
  1344. {static_cast<int>(kNumberTypeUInt), std::make_shared<UInt>()}, // UInt
  1345. {static_cast<int>(kNumberTypeFloat), std::make_shared<Float>()}, // Float
  1346. {static_cast<int>(kObjectTypeNumber), std::make_shared<Number>()}, // Number
  1347. };
  1348. auto iter = type_map.find(static_cast<int>(typeId));
  1349. if (iter != type_map.end()) {
  1350. dtype = iter->second;
  1351. } else {
  1352. MS_LOG(EXCEPTION) << "Unknown number type " << type;
  1353. }
  1354. SetBasicType(ptr, dtype);
  1355. }
  1356. template <typename T>
  1357. Token ParseTrivalType(const std::string &type, T *const ptr = nullptr) {
  1358. if (type == "NoneType") {
  1359. SetBasicType(ptr, std::make_shared<TypeNone>());
  1360. return lexer_.GetNextToken();
  1361. } else if (type == "ProblemType") {
  1362. SetBasicType(ptr, std::make_shared<Problem>());
  1363. return lexer_.GetNextToken();
  1364. } else if (type == "ExternalType") {
  1365. SetBasicType(ptr, std::make_shared<External>());
  1366. return lexer_.GetNextToken();
  1367. } else if (type == "AnythingType") {
  1368. SetBasicType(ptr, kAnyType);
  1369. return lexer_.GetNextToken();
  1370. } else if (type == "TypeType") {
  1371. SetBasicType(ptr, std::make_shared<TypeType>());
  1372. return lexer_.GetNextToken();
  1373. } else {
  1374. MS_LOG(EXCEPTION) << "Unknown type error at line " << lexer_.GetLineNo();
  1375. }
  1376. }
  1377. template <typename T>
  1378. Token ParseOneType(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) {
  1379. if (tok != TOK_IDENTIFIER) {
  1380. return TOK_ERROR;
  1381. }
  1382. std::string type = lexer_.GetTokenText();
  1383. TypeId typeId = kTypeUnknown;
  1384. if (IsNumberType(type, &typeId)) {
  1385. ParseNumberType(type, typeId, ptr);
  1386. return lexer_.GetNextToken();
  1387. } else if (type == "Tuple") {
  1388. return ParseTypeVector(func_graph, lexer_.GetNextToken(), type, ptr);
  1389. } else if (type == "Tensor") {
  1390. return ParseTypeArray(func_graph, lexer_.GetNextToken(), ptr);
  1391. } else if (type == "List") {
  1392. return ParseTypeVector(func_graph, lexer_.GetNextToken(), type, ptr);
  1393. } else if (type == "Func") {
  1394. tok = lexer_.GetNextToken();
  1395. if (tok != TOK_LBRACKET) {
  1396. SetBasicType(ptr, std::make_shared<Function>());
  1397. return tok;
  1398. }
  1399. MS_LOG(EXCEPTION) << "Need to process function parameter types at line " << lexer_.GetLineNo();
  1400. } else if (type == "JT") {
  1401. tok = lexer_.GetNextToken();
  1402. if (tok != TOK_LBRACKET) {
  1403. return tok;
  1404. }
  1405. T elem = nullptr;
  1406. tok = ParseOneType(func_graph, lexer_.GetNextToken(), &elem);
  1407. SetJTaggedType(ptr, elem);
  1408. if (tok != TOK_RBRACKET) {
  1409. return TOK_ERROR;
  1410. }
  1411. return lexer_.GetNextToken();
  1412. } else if (type == "SymType") {
  1413. SetBasicType(ptr, std::make_shared<SymbolicKeyType>());
  1414. return lexer_.GetNextToken();
  1415. } else if (type == "EnvType") {
  1416. SetBasicType(ptr, std::make_shared<EnvType>());
  1417. return lexer_.GetNextToken();
  1418. } else if (Match(type, "Cls.")) {
  1419. MS_LOG(EXCEPTION) << "Need to do class type at line " << lexer_.GetLineNo();
  1420. } else {
  1421. return ParseTrivalType(type, ptr);
  1422. }
  1423. }
  1424. Token ParseType(const FuncGraphPtr &func_graph, AbstractBasePtr *const abstract = nullptr) {
  1425. return ParseOneType(func_graph, lexer_.GetNextToken(), abstract);
  1426. }
  1427. Token ParseAttributes(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) {
  1428. Token tok = ParseAttribute(func_graph, prim);
  1429. while (tok == TOK_COMMA) {
  1430. tok = ParseAttribute(func_graph, prim);
  1431. }
  1432. if (tok != TOK_RBRACKET) {
  1433. return TOK_ERROR;
  1434. }
  1435. return lexer_.GetNextToken();
  1436. }
  1437. Token ParseAttribute(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) {
  1438. Token tok = lexer_.GetNextToken();
  1439. if (tok != TOK_IDENTIFIER) {
  1440. return TOK_ERROR;
  1441. }
  1442. std::string attr_name = lexer_.GetTokenText();
  1443. if (lexer_.GetNextToken() != TOK_EQUALITY) {
  1444. return TOK_ERROR;
  1445. }
  1446. ValuePtr value = nullptr;
  1447. tok = ParseValue(func_graph, lexer_.GetNextToken(), &value);
  1448. if (prim != nullptr) {
  1449. prim->set_attr(attr_name, value);
  1450. } else {
  1451. MS_LOG(EXCEPTION) << "Non primitive obj has attributes";
  1452. }
  1453. return tok;
  1454. }
  1455. FuncGraphPtr ParseParameters(FuncGraphPtr func_graph) {
  1456. Token tok = lexer_.SkipWhiteToken();
  1457. while (tok == TOK_PARAMETER) {
  1458. ParameterPtr param = std::make_shared<Parameter>(func_graph);
  1459. param->set_name(lexer_.GetTokenText());
  1460. param_nodes_[lexer_.GetTokenText()] = param;
  1461. int lineno = lexer_.GetLineNo();
  1462. param->set_debug_info(std::make_shared<NodeDebugInfo>(lexer_.GetTokenText() + "@" + std::to_string(lineno)));
  1463. func_graph->add_parameter(param);
  1464. tok = lexer_.GetNextToken();
  1465. // parse type
  1466. if (tok == TOK_COLON) {
  1467. AbstractBasePtr type = nullptr;
  1468. tok = ParseType(func_graph, &type);
  1469. }
  1470. // parse default value
  1471. if (tok == TOK_EQUALITY) {
  1472. if (lexer_.GetNextToken() != TOK_AT_FILE) {
  1473. MS_LOG(EXCEPTION) << "Expect @file at line " << lexer_.GetLineNo();
  1474. }
  1475. // load parameter default value from serialized file
  1476. py::object default_obj = LoadObject(lexer_.GetTokenText());
  1477. auto param_value_new = std::make_shared<ParamValuePy>(default_obj);
  1478. param->set_default_param(param_value_new);
  1479. tok = lexer_.GetNextToken();
  1480. }
  1481. if (tok == TOK_COMMENT || tok == TOK_EOL) {
  1482. tok = lexer_.SkipWhiteToken();
  1483. }
  1484. Token next = tok;
  1485. if (next == TOK_RPARENTHESIS) {
  1486. return func_graph;
  1487. } else if (next == TOK_COMMA) {
  1488. tok = lexer_.SkipWhiteToken();
  1489. } else {
  1490. return nullptr;
  1491. }
  1492. }
  1493. return tok == TOK_RPARENTHESIS ? func_graph : nullptr;
  1494. }
  1495. FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector<AnfNodePtr> *const inputs_ptr) {
  1496. Token tok = ParseArgument(func_graph, inputs_ptr);
  1497. while (tok == TOK_COMMA) {
  1498. tok = ParseArgument(func_graph, inputs_ptr);
  1499. }
  1500. if (tok != TOK_RPARENTHESIS) {
  1501. return nullptr;
  1502. }
  1503. return func_graph;
  1504. }
  1505. AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string &param_name) {
  1506. while (func_graph != nullptr) {
  1507. for (auto &ptr : func_graph->parameters()) {
  1508. MS_EXCEPTION_IF_NULL(ptr);
  1509. ParameterPtr param = ptr->cast<ParameterPtr>();
  1510. MS_EXCEPTION_IF_NULL(param);
  1511. if (param->name() == param_name) {
  1512. return ptr;
  1513. }
  1514. }
  1515. auto iter = parents_map_.find(func_graph);
  1516. if (iter == parents_map_.end()) {
  1517. break;
  1518. }
  1519. func_graph = iter->second;
  1520. }
  1521. return nullptr;
  1522. }
  1523. bool Match(const std::string &str, const std::string &pattern) const {
  1524. return strncmp(str.c_str(), pattern.c_str(), pattern.length()) == 0;
  1525. }
  1526. template <typename T, typename V>
  1527. Token ParseScalar(ValuePtr *const val_ptr) {
  1528. if (lexer_.GetNextToken() != TOK_NUMBER) {
  1529. return TOK_ERROR;
  1530. }
  1531. std::stringstream ss;
  1532. ss << lexer_.GetTokenText();
  1533. if (lexer_.GetNextToken() != TOK_RPARENTHESIS) {
  1534. return TOK_ERROR;
  1535. }
  1536. V val;
  1537. ss >> val;
  1538. *val_ptr = std::make_shared<T>(val);
  1539. return lexer_.GetNextToken();
  1540. }
  1541. template <typename VT, typename V, typename T>
  1542. Token ParseScalar(ValuePtr *const val_ptr, Token tok) {
  1543. if (tok != TOK_LPARENTHESIS) {
  1544. *val_ptr = std::make_shared<T>();
  1545. return tok;
  1546. }
  1547. return ParseScalar<VT, V>(val_ptr);
  1548. }
  1549. template <typename VT, typename V, typename T, const unsigned nbits>
  1550. Token ParseScalar(ValuePtr *const val_ptr, Token tok) {
  1551. if (tok != TOK_LPARENTHESIS) {
  1552. *val_ptr = std::make_shared<T>(nbits);
  1553. return tok;
  1554. }
  1555. return ParseScalar<VT, V>(val_ptr);
  1556. }
  // Converts `text` to a T using stream extraction.
  // `value` is value-initialised: when extraction cannot start (e.g. empty text)
  // operator>> leaves its target untouched, and returning an uninitialised
  // local would be undefined behaviour. In that case T{} is returned.
  template <typename T>
  T StringToScalar(const std::string &text) {
    T value{};
    std::istringstream stream(text);
    stream >> value;
    return value;
  }
  1565. Token ParseTensor(ValuePtr *const val_ptr) {
  1566. // parse type
  1567. TypeId type;
  1568. if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
  1569. return TOK_ERROR;
  1570. }
  1571. if (lexer_.GetNextToken() != TOK_NUMBER) {
  1572. return TOK_ERROR;
  1573. }
  1574. type = static_cast<TypeId>(StringToScalar<int>(lexer_.GetTokenText()));
  1575. if (lexer_.GetNextToken() != TOK_RPARENTHESIS) {
  1576. return TOK_ERROR;
  1577. }
  1578. // parse shape
  1579. std::vector<int> shape;
  1580. Token tok = lexer_.GetNextToken();
  1581. if (tok != TOK_LBRACKET) {
  1582. return TOK_ERROR;
  1583. }
  1584. do {
  1585. tok = lexer_.GetNextToken();
  1586. // consider case: Tensor(23)[]
  1587. if (tok != TOK_NUMBER) {
  1588. break;
  1589. }
  1590. shape.push_back(StringToScalar<int>(lexer_.GetTokenText()));
  1591. tok = lexer_.GetNextToken();
  1592. } while (tok == TOK_COMMA);
  1593. if (tok != TOK_RBRACKET) {
  1594. return TOK_ERROR;
  1595. }
  1596. if (lexer_.GetNextToken() != TOK_AT_FILE) {
  1597. return TOK_ERROR;
  1598. }
  1599. py::object tensor_obj = LoadObject(lexer_.GetTokenText());
  1600. py::array tensor_data = py::cast<py::array>(tensor_obj);
  1601. if (tensor_data == nullptr) {
  1602. return TOK_ERROR;
  1603. }
  1604. *val_ptr = TensorPy::MakeTensor(tensor_data, TypeIdToType(type));
  1605. return lexer_.GetNextToken();
  1606. }
  1607. Token ParsePrimType(Token tok, PrimType *prim_type_ptr) {
  1608. if (tok != TOK_LBRACE) {
  1609. return tok;
  1610. }
  1611. if (lexer_.GetNextToken() != TOK_IDENTIFIER) {
  1612. return TOK_ERROR;
  1613. }
  1614. if (lexer_.GetTokenText() != "prim_type") {
  1615. return TOK_ERROR;
  1616. }
  1617. if (lexer_.GetNextToken() != TOK_EQUALITY) {
  1618. return TOK_ERROR;
  1619. }
  1620. if (lexer_.GetNextToken() != TOK_NUMBER) {
  1621. return TOK_ERROR;
  1622. }
  1623. int val = 0;
  1624. std::stringstream ss;
  1625. ss << lexer_.GetTokenText();
  1626. ss >> val;
  1627. *prim_type_ptr = PrimType(val);
  1628. if (lexer_.GetNextToken() != TOK_RBRACE) {
  1629. return TOK_ERROR;
  1630. }
  1631. return lexer_.GetNextToken();
  1632. }
  1633. Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) {
  1634. if (tok != TOK_LPARENTHESIS) {
  1635. return TOK_ERROR;
  1636. }
  1637. TypePtrList type_list;
  1638. do {
  1639. TypePtr type = nullptr;
  1640. tok = ParseOneType(nullptr, lexer_.GetNextToken(), &type);
  1641. type_list.push_back(type);
  1642. } while (tok == TOK_COMMA);
  1643. if (tok != TOK_RPARENTHESIS) {
  1644. return TOK_ERROR;
  1645. }
  1646. if (lexer_.GetNextToken() != TOK_AT_FILE) {
  1647. return TOK_ERROR;
  1648. }
  1649. // load Python function from serialized file
  1650. py::object py_func = LoadObject(lexer_.GetTokenText());
  1651. MS_EXCEPTION_IF_NULL(mt_func_graph);
  1652. mt_func_graph->Register(type_list, py::function(py_func));
  1653. return lexer_.GetNextToken();
  1654. }
  1655. Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) {
  1656. if (tok != TOK_LBRACE) {
  1657. return tok;
  1658. }
  1659. do {
  1660. tok = ParseMultitypeFuncGraphItem(mt_func_graph, lexer_.GetNextToken());
  1661. } while (tok == TOK_COMMA);
  1662. if (tok != TOK_RBRACE) {
  1663. return TOK_ERROR;
  1664. }
  1665. return lexer_.GetNextToken();
  1666. }
  1667. Token ParseBoolValue(const std::string &key, bool *val_ptr) {
  1668. if (lexer_.GetNextToken() != TOK_IDENTIFIER || lexer_.GetTokenText() != key) {
  1669. return TOK_ERROR;
  1670. }
  1671. if (lexer_.GetNextToken() != TOK_EQUALITY) {
  1672. return TOK_ERROR;
  1673. }
  1674. if (lexer_.GetNextToken() != TOK_NUMBER) {
  1675. return TOK_ERROR;
  1676. }
  1677. bool value = false;
  1678. {
  1679. std::stringstream ss;
  1680. ss << lexer_.GetTokenText();
  1681. ss >> value;
  1682. }
  1683. if (val_ptr != nullptr) {
  1684. *val_ptr = value;
  1685. }
  1686. return lexer_.GetNextToken();
  1687. }
  1688. Token ParseValueGradOperation(const std::string &name, ValuePtr *const val_ptr) {
  1689. if (lexer_.GetNextToken() != TOK_LBRACE) {
  1690. return TOK_ERROR;
  1691. }
  1692. // get_all=0, get_by_list=1, sens_param=1
  1693. bool get_all = false;
  1694. Token tok = ParseBoolValue("get_all", &get_all);
  1695. if (tok != TOK_COMMA) {
  1696. return TOK_ERROR;
  1697. }
  1698. bool get_by_list = false;
  1699. tok = ParseBoolValue("get_by_list", &get_by_list);
  1700. if (tok != TOK_COMMA) {
  1701. return TOK_ERROR;
  1702. }
  1703. bool sens_param = false;
  1704. tok = ParseBoolValue("sens_param", &sens_param);
  1705. if (tok != TOK_RBRACE) {
  1706. return TOK_ERROR;
  1707. }
  1708. *val_ptr = std::make_shared<prim::GradOperation>(name, get_all, get_by_list, sens_param);
  1709. return lexer_.GetNextToken();
  1710. }
  1711. Token ParseSymbolicKeyInstance(const FuncGraphPtr &func_graph, AnfNodePtr *const node_ptr = nullptr) {
  1712. if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
  1713. return TOK_ERROR;
  1714. }
  1715. if (lexer_.GetNextToken() != TOK_PARAMETER) {
  1716. return TOK_ERROR;
  1717. }
  1718. std::string param_name = lexer_.GetTokenText();
  1719. if (lexer_.GetNextToken() != TOK_RPARENTHESIS) {
  1720. return TOK_ERROR;
  1721. }
  1722. auto iter = param_nodes_.find(param_name);
  1723. if (iter == param_nodes_.end()) {
  1724. MS_LOG(EXCEPTION) << "Can not find param '" << param_name << "' for SymbolicKeyInstance at line "
  1725. << lexer_.GetLineNo();
  1726. }
  1727. PrimitivePtr embed = std::make_shared<Primitive>("embed");
  1728. std::vector<AnfNodePtr> inputs;
  1729. inputs.push_back(std::make_shared<ValueNode>(embed));
  1730. inputs.push_back(iter->second);
  1731. if (node_ptr != nullptr) {
  1732. MS_EXCEPTION_IF_NULL(func_graph);
  1733. *node_ptr = func_graph->NewCNode(inputs);
  1734. } else {
  1735. MS_LOG(EXCEPTION) << "Not processed SymbolicKeyInstance '" << param_name << "' at line " << lexer_.GetLineNo()
  1736. << ".";
  1737. }
  1738. return lexer_.GetNextToken();
  1739. }
  1740. Token ParsePrimitivePy(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *const val_ptr) {
  1741. if (lexer_.GetNextToken() != TOK_AT_FILE) {
  1742. return TOK_ERROR;
  1743. }
  1744. // restore python function of PrimitivePy from serialized file
  1745. py::object py_obj = LoadObject(lexer_.GetTokenText());
  1746. PrimitivePyPtr ptr = nullptr;
  1747. if (py::hasattr(py_obj, "__setattr_flag__") && py::hasattr(py_obj, "_clone")) {
  1748. auto clone_fn = py_obj.attr("_clone");
  1749. py::object new_obj = clone_fn();
  1750. ptr = new_obj.cast<PrimitivePyPtr>();
  1751. if (ptr == nullptr) {
  1752. MS_LOG(EXCEPTION) << "Cast to type 'PrimitivePyPtr' error";
  1753. }
  1754. } else {
  1755. auto len = strlen("PrimitivePy::");
  1756. if (id.size() < len) {
  1757. return TOK_ERROR;
  1758. }
  1759. ptr = std::make_shared<PrimitivePy>(id.substr(len), py_obj);
  1760. }
  1761. *val_ptr = ptr;
  1762. PrimType prim_type = kPrimTypeUnknown;
  1763. Token next = ParsePrimType(lexer_.GetNextToken(), &prim_type);
  1764. if (prim_type != kPrimTypeUnknown) {
  1765. ptr->set_prim_type(prim_type);
  1766. }
  1767. if (next != TOK_LBRACKET) {
  1768. return next;
  1769. }
  1770. // parse attributes
  1771. next = ParseAttributes(func_graph, ptr);
  1772. return next;
  1773. }
  1774. Token ParseValueGraphAndNamespace(const std::string &id, ValuePtr *const val_ptr) {
  1775. if (Match(id, "MultitypeFuncGraph::")) {
  1776. std::string name = id.substr(strlen("MultitypeFuncGraph::"));
  1777. auto mt_func_graph = std::make_shared<prim::MultitypeFuncGraph>(name);
  1778. *val_ptr = mt_func_graph;
  1779. Token next = ParseMultitypeFuncGraph(mt_func_graph, lexer_.GetNextToken());
  1780. return next;
  1781. } else if (Match(id, "HyperMapPy::")) {
  1782. *val_ptr = std::make_shared<prim::HyperMapPy>();
  1783. Token next = lexer_.GetNextToken();
  1784. // process case: fn_leaf is not null
  1785. if (next == TOK_LBRACE) {
  1786. MS_LOG(EXCEPTION) << "Need to process fn_leaf at line " << lexer_.GetLineNo();
  1787. }
  1788. return next;
  1789. } else if (Match(id, "FuncGraph::")) {
  1790. std::string func_graph_name = id.substr(strlen("FuncGraph::"));
  1791. // if the graph does not exist, create a null graph, then fill the graph when encounter the definition
  1792. // of the graph
  1793. if (func_graphs_map_.find(func_graph_name) == func_graphs_map_.end()) {
  1794. func_graphs_map_[func_graph_name] = std::make_shared<FuncGraph>();
  1795. }
  1796. *val_ptr = func_graphs_map_[func_graph_name];
  1797. return lexer_.GetNextToken();
  1798. } else if (Match(id, "NameSpace::")) {
  1799. std::string module_name = id.substr(strlen("NameSpace::"));
  1800. if (lexer_.GetNextToken() != TOK_AT_FILE) {
  1801. MS_LOG(ERROR) << "Expect TOK_AT_FILE at line " << lexer_.GetLineNo();
  1802. return TOK_ERROR;
  1803. }
  1804. // load Python module information from serialized file
  1805. py::object py_obj = LoadObject(lexer_.GetTokenText());
  1806. *val_ptr = std::make_shared<parse::NameSpace>(module_name, py_obj);
  1807. return lexer_.GetNextToken();
  1808. } else {
  1809. MS_LOG(EXCEPTION) << "Unknown id " << id << " at line " << lexer_.GetLineNo();
  1810. }
  1811. }
  1812. Token ParseValueBasic(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *const val_ptr,
  1813. AnfNodePtr *const node_ptr = nullptr) {
  1814. if (id == "None") {
  1815. *val_ptr = std::make_shared<None>();
  1816. return lexer_.GetNextToken();
  1817. } else if (id == "Bool") {
  1818. return ParseScalar<BoolImm, bool, Bool>(val_ptr, lexer_.GetNextToken());
  1819. } else if (id == "I8") {
  1820. return ParseScalar<Int8Imm, int8_t, Int, 8>(val_ptr, lexer_.GetNextToken());
  1821. } else if (id == "I16") {
  1822. return ParseScalar<Int16Imm, int16_t, Int, 16>(val_ptr, lexer_.GetNextToken());
  1823. } else if (id == "I32") {
  1824. return ParseScalar<Int32Imm, int32_t, Int, 32>(val_ptr, lexer_.GetNextToken());
  1825. } else if (id == "I64") {
  1826. return ParseScalar<Int64Imm, int64_t, Int, 64>(val_ptr, lexer_.GetNextToken());
  1827. } else if (id == "U8") {
  1828. return ParseScalar<UInt8Imm, uint8_t, UInt, 8>(val_ptr, lexer_.GetNextToken());
  1829. } else if (id == "U16") {
  1830. return ParseScalar<UInt16Imm, uint16_t, UInt, 16>(val_ptr, lexer_.GetNextToken());
  1831. } else if (id == "U32") {
  1832. return ParseScalar<UInt32Imm, uint32_t, UInt, 32>(val_ptr, lexer_.GetNextToken());
  1833. } else if (id == "U64") {
  1834. return ParseScalar<UInt64Imm, uint64_t, UInt, 64>(val_ptr, lexer_.GetNextToken());
  1835. } else if (id == "F16") {
  1836. // Notice: Since there is no basic data type for storing fp16, just use float instead
  1837. return ParseScalar<FP32Imm, float, Float, 16>(val_ptr, lexer_.GetNextToken());
  1838. } else if (id == "F32") {
  1839. return ParseScalar<FP32Imm, float, Float, 32>(val_ptr, lexer_.GetNextToken());
  1840. } else if (id == "F64") {
  1841. return ParseScalar<FP64Imm, double, Float, 64>(val_ptr, lexer_.GetNextToken());
  1842. } else if (id == "Tensor") {
  1843. return ParseTensor(val_ptr);
  1844. } else if (id == "SymInst") {
  1845. return ParseSymbolicKeyInstance(func_graph, node_ptr);
  1846. } else if (id == "Array") {
  1847. TypePtr type = nullptr;
  1848. Token ret = ParseTypeArray(func_graph, lexer_.GetNextToken(), &type);
  1849. *val_ptr = type;
  1850. return ret;
  1851. } else if (Match(id, "PrimitivePy::")) {
  1852. return ParsePrimitivePy(func_graph, id, val_ptr);
  1853. } else if (Match(id, "Primitive::")) {
  1854. *val_ptr = std::make_shared<Primitive>(id.substr(strlen("Primitive::")));
  1855. return lexer_.GetNextToken();
  1856. } else if (Match(id, "GradOperation::")) {
  1857. return ParseValueGradOperation(id.substr(strlen("GradOperation::")), val_ptr);
  1858. } else {
  1859. return ParseValueGraphAndNamespace(id, val_ptr);
  1860. }
  1861. }
  1862. Token SetListOrTupleValue(const FuncGraphPtr &func_graph, Token left_tok, Token next, bool node_is_valid,
  1863. const std::vector<ValuePtr> &elems, const std::vector<AnfNodePtr> &nodes,
  1864. ValuePtr *const val_ptr, AnfNodePtr *node_ptr) {
  1865. if (left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) {
  1866. if (node_is_valid && node_ptr != nullptr) {
  1867. MS_EXCEPTION_IF_NULL(func_graph);
  1868. *node_ptr = func_graph->NewCNode(nodes);
  1869. } else {
  1870. *val_ptr = std::make_shared<ValueTuple>(elems);
  1871. }
  1872. return lexer_.GetNextToken();
  1873. } else if (left_tok == TOK_LBRACKET && next == TOK_RBRACKET) {
  1874. if (node_is_valid && node_ptr != nullptr) {
  1875. MS_LOG(EXCEPTION) << "Encounter valid node in value list";
  1876. }
  1877. *val_ptr = std::make_shared<ValueList>(elems);
  1878. return lexer_.GetNextToken();
  1879. } else {
  1880. return TOK_ERROR;
  1881. }
  1882. }
  1883. Token ParseListOrTupleValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr,
  1884. AnfNodePtr *node_ptr = nullptr) {
  1885. Token left_tok = tok;
  1886. std::vector<ValuePtr> elems;
  1887. std::vector<AnfNodePtr> nodes;
  1888. nodes.push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>("make_tuple")));
  1889. ValuePtr elem = nullptr;
  1890. AnfNodePtr node = nullptr;
  1891. bool node_is_valid = false;
  1892. bool first_flag = true;
  1893. Token next = TOK_ERROR;
  1894. do {
  1895. next = lexer_.GetNextToken();
  1896. if (first_flag) {
  1897. first_flag = false;
  1898. // case (), zero elements
  1899. if ((left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) ||
  1900. (left_tok == TOK_LBRACKET && next == TOK_RBRACKET)) {
  1901. if (left_tok == TOK_LPARENTHESIS) {
  1902. *val_ptr = std::make_shared<ValueTuple>(elems);
  1903. } else {
  1904. *val_ptr = std::make_shared<ValueList>(elems);
  1905. }
  1906. return lexer_.GetNextToken();
  1907. }
  1908. }
  1909. node = nullptr;
  1910. next = ParseValue(func_graph, next, &elem, &node);
  1911. elems.push_back(elem);
  1912. if (node != nullptr) {
  1913. nodes.push_back(node);
  1914. node_is_valid = true;
  1915. } else {
  1916. nodes.push_back(std::make_shared<ValueNode>(elem));
  1917. }
  1918. } while (next == TOK_COMMA);
  1919. return SetListOrTupleValue(func_graph, left_tok, next, node_is_valid, elems, nodes, val_ptr, node_ptr);
  1920. }
  1921. Token ParseValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, AnfNodePtr *node_ptr = nullptr) {
  1922. // tuple or list
  1923. if (tok == TOK_LPARENTHESIS || tok == TOK_LBRACKET) {
  1924. return ParseListOrTupleValue(func_graph, tok, val_ptr, node_ptr);
  1925. } else if (tok == TOK_IDENTIFIER) {
  1926. return ParseValueBasic(func_graph, lexer_.GetTokenText(), val_ptr, node_ptr);
  1927. } else if (tok == TOK_STRING) {
  1928. *val_ptr = std::make_shared<StringImm>(lexer_.GetTokenText());
  1929. return lexer_.GetNextToken();
  1930. }
  1931. MS_LOG(ERROR) << "Parse error!";
  1932. return TOK_ERROR;
  1933. }
  1934. Token ParseItem(const FuncGraphPtr &func_graph, AnfNodePtr *node_ptr, ValuePtr *const val_ptr,
  1935. Token tok = TOK_INVALID) {
  1936. if (tok == TOK_INVALID) {
  1937. tok = lexer_.GetNextToken();
  1938. }
  1939. if (tok == TOK_VARIABLE) {
  1940. auto iter = cnodes_.find(lexer_.GetTokenText());
  1941. if (iter == cnodes_.end()) {
  1942. MS_LOG(EXCEPTION) << "Can not find definition of '" << lexer_.GetTokenText() << "'";
  1943. }
  1944. *node_ptr = iter->second;
  1945. } else if (tok == TOK_PARAMETER) {
  1946. AnfNodePtr param = FindParameter(func_graph, lexer_.GetTokenText());
  1947. if (param == nullptr) {
  1948. MS_LOG(EXCEPTION) << "Can not find definition of '" << lexer_.GetTokenText() << "' at line "
  1949. << lexer_.GetLineNo();
  1950. }
  1951. *node_ptr = param;
  1952. } else if (tok == TOK_IDENTIFIER || tok == TOK_LPARENTHESIS || tok == TOK_STRING) {
  1953. ValuePtr value;
  1954. AnfNodePtr node;
  1955. tok = ParseValue(func_graph, tok, &value, &node);
  1956. if (tok == TOK_ERROR) {
  1957. MS_LOG(ERROR) << "Parse value error!";
  1958. return tok;
  1959. }
  1960. if (node == nullptr) {
  1961. *val_ptr = value;
  1962. *node_ptr = std::make_shared<ValueNode>(value);
  1963. } else {
  1964. *node_ptr = node;
  1965. }
  1966. return tok;
  1967. } else {
  1968. MS_LOG(EXCEPTION) << "tok_type = " << tok;
  1969. }
  1970. return lexer_.GetNextToken();
  1971. }
  1972. Token ParseArgument(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *const inputs_ptr) {
  1973. Token tok = lexer_.GetNextToken();
  1974. if (tok == TOK_RPARENTHESIS) {
  1975. return tok;
  1976. }
  1977. AnfNodePtr node = nullptr;
  1978. ValuePtr value = nullptr;
  1979. tok = ParseItem(func_graph, &node, &value, tok);
  1980. if (tok != TOK_ERROR) {
  1981. MS_EXCEPTION_IF_NULL(inputs_ptr);
  1982. inputs_ptr->push_back(node);
  1983. }
  1984. return tok;
  1985. }
  1986. const std::vector<FuncGraphPtr> &GetFuncGraphs() const { return func_graphs_; }
  1987. private:
  1988. Lexer lexer_;
  1989. std::vector<FuncGraphPtr> func_graphs_;
  1990. bool error_flag_ = false;
  1991. // store all parsed graphs
  1992. std::map<std::string, FuncGraphPtr> func_graphs_map_;
  1993. // map from child to parent, consider adding a 'parent' field in class Graph
  1994. std::map<FuncGraphPtr, FuncGraphPtr> parents_map_;
  1995. // map for buffering cnodes when parsing a graph
  1996. std::map<std::string, CNodePtr> cnodes_;
  1997. std::map<std::string, ParameterPtr> param_nodes_; // map parameter name to parameter
  1998. };
  1999. std::vector<FuncGraphPtr> ImportIR(const std::string &filename) {
  2000. IrParser parser(filename.c_str());
  2001. parser.ParseFile();
  2002. return parser.GetFuncGraphs();
  2003. }
  2004. #ifdef ENABLE_DUMP_IR
  2005. void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) {
  2006. if (func_graph == nullptr) {
  2007. MS_LOG(ERROR) << "Func graph is nullptr";
  2008. return;
  2009. }
  2010. auto ms_context = MsContext::GetInstance();
  2011. if (ms_context == nullptr) {
  2012. MS_LOG(ERROR) << "ms_context is nullptr";
  2013. return;
  2014. }
  2015. auto save_graphs_path = ms_context->save_graphs_path();
  2016. if (save_graphs_path.empty()) {
  2017. save_graphs_path = ".";
  2018. }
  2019. std::string file_path = save_graphs_path + "/" + "ms_output_" + suffix + ".pb";
  2020. if (file_path.size() > PATH_MAX) {
  2021. MS_LOG(ERROR) << "File path " << file_path << " is too long.";
  2022. return;
  2023. }
  2024. char real_path[PATH_MAX] = {0};
  2025. char *real_path_ret = nullptr;
  2026. #if defined(_WIN32) || defined(_WIN64)
  2027. real_path_ret = _fullpath(real_path, file_path.c_str(), PATH_MAX);
  2028. #else
  2029. real_path_ret = realpath(file_path.c_str(), real_path);
  2030. #endif
  2031. if (nullptr == real_path_ret) {
  2032. MS_LOG(DEBUG) << "dir " << file_path << " does not exit.";
  2033. } else {
  2034. std::string path_string = real_path;
  2035. if (chmod(common::SafeCStr(path_string), S_IRUSR | S_IWUSR) == -1) {
  2036. MS_LOG(ERROR) << "Modify file:" << real_path << " to rw fail.";
  2037. return;
  2038. }
  2039. }
  2040. // write to pb file
  2041. std::ofstream ofs(real_path);
  2042. if (!ofs.is_open()) {
  2043. MS_LOG(ERROR) << "Open file '" << real_path << "' failed!";
  2044. return;
  2045. }
  2046. ofs << GetFuncGraphProtoString(func_graph);
  2047. ofs.close();
  2048. // set file mode to read only by user
  2049. ChangeFileMode(file_path, S_IRUSR);
  2050. }
  2051. #else
  2052. void DumpIRProto(const FuncGraphPtr &, const std::string &) {
  2053. static bool already_printed = false;
  2054. if (already_printed) {
  2055. return;
  2056. }
  2057. already_printed = true;
  2058. MS_LOG(WARNING) << "The functionality of dumping function graph IR in protobuf format is disabled, "
  2059. << "please recompile source to enable it. See help of building script.";
  2060. }
  2061. #endif
  2062. } // namespace mindspore