You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

anf_ir_utils.cc 74 kB

5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "debug/anf_ir_utils.h"
  17. #include <fstream>
  18. #include <map>
  19. #include <memory>
  20. #include <unordered_map>
  21. #include <algorithm>
  22. #include "ir/graph_utils.h"
  23. #include "utils/symbolic.h"
  24. #include "ir/meta_func_graph.h"
  25. #include "ir/param_info.h"
  26. #include "pybind_api/ir/tensor_py.h"
  27. #include "pipeline/jit/parse/python_adapter.h"
  28. #include "pipeline/jit/parse/resolve.h"
  29. #include "frontend/operator/composite/composite.h"
  30. #include "frontend/operator/composite/map.h"
  31. #include "utils/ordered_map.h"
  32. #include "utils/ordered_set.h"
  33. #include "utils/utils.h"
  34. #include "debug/trace.h"
  35. #include "utils/label.h"
  36. #include "utils/ms_context.h"
  37. #include "frontend/operator/ops.h"
  38. using mindspore::tensor::TensorPy;
  39. namespace mindspore {
  40. // Maximum number of elements in a sequence (0x00FFFFFF == 16777215).
  40. // NOTE(review): not referenced in the visible part of this file - confirm where the limit is enforced.
  41. const int NUM_MAX_SEQUENCE_ELEMS = 0x00FFFFFF;
  42. // ============================================== MindSpore IR Common ==============================================
  43. // get MindSpore Intermediate Representation Path
  44. std::string GetMsIrPath(void) {
  45. std::string path;
  46. const char *path_ptr = getenv("MS_IR_PATH");
  47. if (path_ptr != nullptr) {
  48. path = path_ptr;
  49. char real_path[PATH_MAX] = {0};
  50. #if defined(_WIN32) || defined(_WIN64)
  51. if (path.size() > PATH_MAX || _fullpath(real_path, path.c_str(), PATH_MAX) == nullptr) {
  52. MS_LOG(EXCEPTION) << "MS IR Path error, " << path_ptr;
  53. }
  54. #else
  55. if (path.size() > PATH_MAX || nullptr == realpath(path.c_str(), real_path)) {
  56. MS_LOG(EXCEPTION) << "MS IR path error, " << path_ptr;
  57. }
  58. #endif
  59. path = real_path;
  60. }
  61. return path;
  62. }
  63. std::string dump_obj(const py::object &obj, const std::string &path) {
  64. py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
  65. py::object name = parse::python_adapter::CallPyModFn(mod, "dump_obj", obj, py::str(path));
  66. return py::str(name);
  67. }
  68. py::object load_obj(const std::string &path) {
  69. py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
  70. py::object obj = parse::python_adapter::CallPyModFn(mod, "load_obj", py::str(path));
  71. return obj;
  72. }
  73. // ============================================= MindSpore IR Exporter =============================================
  74. std::string AnfExporter::GetNodeType(const AnfNodePtr &nd) {
  75. abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape());
  76. TypePtr type = dyn_cast<Type>(nd->Type());
  77. std::ostringstream oss;
  78. if ((nullptr != shape) && (nullptr != type)) {
  79. oss << type->DumpText() << shape->DumpText();
  80. } else if (nullptr != type) {
  81. oss << type->DumpText();
  82. } else {
  83. oss << "Undefined";
  84. }
  85. return oss.str();
  86. }
  87. std::string AnfExporter::DumpObject(const py::object &obj, const std::string &category) const {
  88. std::string pkl_path = GetMsIrPath();
  89. // if not specified env 'MS_IR_PATH', do not create any files
  90. if (pkl_path.empty() || (getenv("MS_IR_FILE") != nullptr)) {
  91. return "null";
  92. }
  93. std::string file_prefix = id_ + "." + category;
  94. std::string file_name = dump_obj(obj, pkl_path + "/" + file_prefix);
  95. return file_prefix + file_name;
  96. }
  97. int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &param, bool throw_excp) {
  98. if (func_graph == nullptr || param == nullptr) {
  99. return -1;
  100. }
  101. FuncGraphPtr fg = func_graph;
  102. while (fg != nullptr) {
  103. if (exported.find(fg) == exported.end()) {
  104. if (!check_integrity_) {
  105. break;
  106. }
  107. MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "." << fg->debug_info()->get_id() << "'";
  108. }
  109. auto param_map = exported[fg];
  110. if (param_map.find(param) != param_map.end()) {
  111. return param_map[param];
  112. }
  113. fg = fg->parent();
  114. }
  115. if (throw_excp) {
  116. MS_LOG(EXCEPTION) << "Can not find index for param '" << param->DumpText() << "' for func graph '"
  117. << func_graph->DumpText() << "." << func_graph->debug_info()->get_id() << "'";
  118. }
  119. return -1;
  120. }
  121. // try to find index of parameter for SymbolicKeyInstance from all exported graphs
  122. // NOTICE: Suppose name of all parameters in SymbolicKeyInstance are different
  123. int AnfExporter::GetParamIndexFromExported(const AnfNodePtr &param) {
  124. if (param == nullptr) {
  125. return -1;
  126. }
  127. int ret = -1;
  128. for (const auto &item : exported) {
  129. auto pram_iter = item.second.find(param);
  130. if (pram_iter != item.second.end()) {
  131. return pram_iter->second;
  132. }
  133. }
  134. return ret;
  135. }
  136. std::string AnfExporter::GetValueNodeText(const FuncGraphPtr &fg, const ValueNodePtr &node) {
  137. MS_EXCEPTION_IF_NULL(node);
  138. return GetValueText(fg, node->value());
  139. }
  140. std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph) {
  141. auto py_funcs = mt_func_graph->GetPyFunctions();
  142. if (py_funcs.empty()) {
  143. return "";
  144. }
  145. std::ostringstream oss;
  146. oss << "{";
  147. bool is_first = true;
  148. for (const auto &py_func : py_funcs) {
  149. if (is_first) {
  150. is_first = false;
  151. } else {
  152. oss << ", ";
  153. }
  154. oss << "(";
  155. for (size_t i = 0; i < py_func.first.size(); ++i) {
  156. if (i > 0) {
  157. oss << ", ";
  158. }
  159. oss << py_func.first[i]->DumpText();
  160. }
  161. oss << ")";
  162. // dump Python Function object
  163. oss << "@" << DumpObject(py_func.second, "F");
  164. }
  165. oss << "}";
  166. return oss.str();
  167. }
  168. /* inherit relation of MetaFuncGraph
  169. *
  170. * MetaGraph
  171. * ├── MultitypeGraph
  172. * ├── HyperMap
  173. * │ └── HyperMapPy
  174. * ├── Map
  175. * │ └── MapPy
  176. * ├── Tail
  177. * ├── MakeTupleGradient
  178. * ├── MakeListGradient
  179. * ├── GradOperation
  180. * └── TupleAdd
  181. */
  182. std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph) {
  183. if (meta_func_graph == nullptr) {
  184. return "";
  185. }
  186. std::ostringstream oss;
  187. oss << meta_func_graph->type_name() << "::" << meta_func_graph->name();
  188. if (meta_func_graph->isa<prim::MultitypeFuncGraph>()) {
  189. prim::MultitypeFuncGraphPtr mt_func_graph = meta_func_graph->cast<prim::MultitypeFuncGraphPtr>();
  190. oss << GetMultitypeFuncGraphText(mt_func_graph);
  191. } else if (meta_func_graph
  192. ->isa<prim::HyperMapPy>()) { // this statement must before 'meta_graph->isa<prim::HyperMap>()'
  193. auto hyper_map = meta_func_graph->cast<prim::HyperMapPyPtr>();
  194. if (hyper_map->GetFnLeaf() != nullptr) {
  195. oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}";
  196. }
  197. } else if (meta_func_graph->isa<prim::HyperMap>()) {
  198. auto hyper_map = meta_func_graph->cast<prim::HyperMapPtr>();
  199. if (hyper_map->GetFnLeaf() != nullptr) {
  200. oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}";
  201. }
  202. } else if (meta_func_graph->isa<prim::MapPy>()) { // this statement must before 'meta_graph->isa<prim::Map>()'
  203. auto map = meta_func_graph->cast<prim::MapPyPtr>();
  204. if (map->GetFnLeaf() != nullptr) {
  205. oss << "{fn_leaf=" << GetMetaFuncGraphText(map->GetFnLeaf()) << "}";
  206. }
  207. } else if (meta_func_graph->isa<prim::Map>()) {
  208. auto map = meta_func_graph->cast<prim::MapPtr>();
  209. if (map->GetFnLeaf() != nullptr) {
  210. oss << "{fn_leaf=" << GetMetaFuncGraphText(map->GetFnLeaf()) << "}";
  211. }
  212. } else if (meta_func_graph->isa<prim::GradOperation>()) {
  213. prim::GradOperationPtr grad_op = meta_func_graph->cast<prim::GradOperationPtr>();
  214. oss << "{get_all=" << grad_op->get_all_ << ", get_by_list=" << grad_op->get_by_list_
  215. << ", sens_param=" << grad_op->sens_param_ << "}";
  216. } else if (meta_func_graph->isa<prim::Tail>()) {
  217. // do nothing
  218. } else if (meta_func_graph->isa<prim::MakeTupleGradient>()) {
  219. // do nothing
  220. } else if (meta_func_graph->isa<prim::MakeListGradient>()) {
  221. // do nothing
  222. } else if (meta_func_graph->isa<prim::TupleAdd>()) {
  223. // do nothing
  224. } else if (meta_func_graph->isa<prim::TupleSlice>()) {
  225. // do nothing
  226. } else if (meta_func_graph->isa<prim::UnpackCall>()) {
  227. // do nothing
  228. } else if (meta_func_graph->isa<prim::ZipOperation>()) {
  229. // do nothing
  230. } else if (meta_func_graph->isa<prim::ListAppend>()) {
  231. // do nothing
  232. } else if (meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>()) {
  233. // do nothing
  234. } else {
  235. MS_LOG(EXCEPTION) << "Unknown MetaFuncGraph type " << meta_func_graph->type_name();
  236. }
  237. return oss.str();
  238. }
  239. std::string AnfExporter::GetPrimitiveText(const PrimitivePtr &prim) {
  240. std::ostringstream oss;
  241. if (prim == nullptr) {
  242. return oss.str();
  243. }
  244. oss << prim->type_name() << "::" << prim->name();
  245. // need to serialize internal python function of PrimitivePy and record its prim_type
  246. if (prim->isa<PrimitivePy>()) {
  247. PrimitivePyPtr primpy = prim->cast<PrimitivePyPtr>();
  248. // dump related function in PrimitivePy
  249. oss << "@" << DumpObject(primpy->GetPyObj(), "P");
  250. // output primitive type
  251. oss << "{prim_type=" << static_cast<int>(prim->prim_type()) << "}";
  252. }
  253. // output primitive attributes
  254. oss << prim->GetAttrsText();
  255. if (prim->isa<prim::DoSignaturePrimitive>()) {
  256. auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim);
  257. auto &func = do_signature->function();
  258. if (func->isa<Primitive>()) {
  259. auto sig_prim = dyn_cast<Primitive>(func);
  260. oss << sig_prim->GetAttrsText();
  261. }
  262. }
  263. return oss.str();
  264. }
  265. std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr &ns) {
  266. std::ostringstream oss;
  267. if (ns == nullptr) {
  268. return oss.str();
  269. }
  270. // dump related module information in Namespace
  271. oss << ns->type_name() << "::" << ns->module() << "@" << DumpObject(ns->obj(), "N");
  272. return oss.str();
  273. }
  274. std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph,
  275. const SymbolicKeyInstancePtr &sym_inst) {
  276. MS_EXCEPTION_IF_NULL(func_graph);
  277. MS_EXCEPTION_IF_NULL(sym_inst);
  278. AnfNodePtr sym_node = sym_inst->node();
  279. MS_EXCEPTION_IF_NULL(sym_node);
  280. std::ostringstream oss;
  281. if (sym_node->isa<Parameter>()) {
  282. int idx = GetParamIndex(func_graph, sym_node, false);
  283. // if can not find SymbolicKeyInstance related parameter from ancestors,
  284. // try to find from all exported graphs
  285. if (idx < 0) {
  286. idx = GetParamIndexFromExported(sym_node);
  287. }
  288. if (idx < 0) {
  289. ParameterPtr p = dyn_cast<Parameter>(sym_node);
  290. if (p == nullptr) {
  291. MS_LOG(EXCEPTION) << "Sym_inst's node could not cast to parameter";
  292. }
  293. MS_LOG(WARNING) << "Can not find SymbolicKeyInstance: " << p->name();
  294. }
  295. oss << "SymInst(%para" << idx << ")";
  296. } else {
  297. MS_LOG(WARNING) << "SymbolicKeyInstance does not embed a parameter: " << sym_node->ToString();
  298. oss << "SymInst(cnode_" << sym_node->ToString() << ")";
  299. }
  300. return oss.str();
  301. }
  302. std::string AnfExporter::GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
  303. std::ostringstream oss;
  304. // output ValueList, ValueTuple
  305. ValueSequeuePtr seq = dyn_cast<ValueSequeue>(value);
  306. MS_EXCEPTION_IF_NULL(seq);
  307. MS_EXCEPTION_IF_NULL(value);
  308. bool is_tuple = value->isa<ValueTuple>();
  309. oss << (is_tuple ? "(" : "[");
  310. bool first_flag = true;
  311. for (auto elem : seq->value()) {
  312. if (first_flag) {
  313. first_flag = false;
  314. } else {
  315. oss << ", ";
  316. }
  317. oss << GetValueText(func_graph, elem);
  318. }
  319. oss << (is_tuple ? ")" : "]");
  320. return oss.str();
  321. }
  322. std::string AnfExporter::GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
  323. std::ostringstream oss;
  324. ValueDictionaryPtr dict = value->cast<ValueDictionaryPtr>();
  325. oss << "{";
  326. bool first_flag = true;
  327. for (const auto &elem : dict->value()) {
  328. if (first_flag) {
  329. first_flag = false;
  330. } else {
  331. oss << ", ";
  332. }
  333. oss << "\"" << elem.first << "\": " << GetValueText(func_graph, elem.second);
  334. }
  335. oss << "}";
  336. return oss.str();
  337. }
  338. std::string AnfExporter::GetOtherValueText(const FuncGraphPtr &, const ValuePtr &value) {
  339. std::ostringstream oss;
  340. if (check_integrity_) {
  341. MS_LOG(EXCEPTION) << "Need to process type: " << value->type_name() << ", dump text: " << value->DumpText();
  342. }
  343. oss << value->type_name() << "[" << value->DumpText() << "]";
  344. return oss.str();
  345. }
  346. std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
  347. std::ostringstream oss;
  348. bool is_null_ptr = (func_graph == nullptr || value == nullptr);
  349. if (is_null_ptr) {
  350. return oss.str();
  351. }
  352. if (value->isa<Primitive>()) {
  353. oss << GetPrimitiveText(value->cast<PrimitivePtr>());
  354. } else if (value->isa<MetaFuncGraph>()) {
  355. MetaFuncGraphPtr meta_func_graph = value->cast<MetaFuncGraphPtr>();
  356. oss << GetMetaFuncGraphText(meta_func_graph);
  357. } else if (value->isa<SymbolicKeyInstance>()) {
  358. oss << GetSymbolicKeyInstanceText(func_graph, value->cast<SymbolicKeyInstancePtr>());
  359. } else if (value->isa<RefKey>()) {
  360. oss << value->DumpText();
  361. } else if (value->isa<Scalar>() || value->isa<StringImm>()) {
  362. oss << value->DumpText();
  363. } else if (value->isa<tensor::Tensor>()) {
  364. auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
  365. oss << value->DumpText() << "@" << DumpObject(TensorPy::AsNumpy(*tensor_ptr), "T");
  366. } else if (value->isa<parse::Symbol>() || value->isa<None>() || value->isa<Null>()) {
  367. oss << value->DumpText();
  368. } else if (value->isa<ValueSequeue>()) {
  369. oss << GetSequenceText(func_graph, value);
  370. } else if (value->isa<ValueDictionary>()) {
  371. oss << GetDictText(func_graph, value);
  372. } else if (value->isa<ValueSlice>()) {
  373. ValueSlicePtr slice = value->cast<ValueSlicePtr>();
  374. oss << slice->DumpText();
  375. } else if (value->isa<Type>()) {
  376. oss << value->DumpText();
  377. } else if (value->isa<parse::NameSpace>()) {
  378. oss << GetNameSpaceText(value->cast<parse::NameSpacePtr>());
  379. } else if (value->isa<parse::PyObjectWrapper>()) {
  380. oss << value->type_name();
  381. } else if (value->isa<KeywordArg>()) {
  382. KeywordArgPtr keyword_arg = value->cast<KeywordArgPtr>();
  383. oss << keyword_arg->DumpText();
  384. } else {
  385. return GetOtherValueText(func_graph, value);
  386. }
  387. return oss.str();
  388. }
  389. // this function is used to output node in CNode's inputs
  390. std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
  391. const std::map<AnfNodePtr, int> &apply_map) {
  392. std::ostringstream oss;
  393. if (func_graph == nullptr || node == nullptr) {
  394. return oss.str();
  395. }
  396. if (node->isa<CNode>()) {
  397. auto iter = apply_map.find(node);
  398. if (iter == apply_map.end()) {
  399. MS_LOG(EXCEPTION) << "Can not find node '" << node->DumpText() << "' in apply_map";
  400. }
  401. oss << "%" << iter->second;
  402. } else if (node->isa<Parameter>()) {
  403. oss << "%para" << GetParamIndex(func_graph, node, check_integrity_);
  404. } else if (IsValueNode<FuncGraph>(node)) {
  405. FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
  406. oss << fg->type_name() << "::fg_" << fg->debug_info()->get_id();
  407. if (!func_graph_set.contains(fg) && exported.find(fg) == exported.end() && export_used_) {
  408. func_graph_set.add(fg);
  409. }
  410. } else if (node->isa<ValueNode>()) {
  411. oss << GetValueNodeText(func_graph, node->cast<ValueNodePtr>());
  412. } else {
  413. MS_LOG(EXCEPTION) << "Unknown node '" << node->DumpText() << "'";
  414. }
  415. return oss.str();
  416. }
  417. void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> &parameters,
  418. OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map) {
  419. bool first_flag = true;
  420. for (const AnfNodePtr &param : parameters) {
  421. if (first_flag) {
  422. first_flag = false;
  423. ofs << " ";
  424. } else {
  425. ofs << " , ";
  426. }
  427. (*param_map)[param] = param_index;
  428. std::string type_info = GetNodeType(param);
  429. // output parameter and type
  430. if (type_info == "Undefined") {
  431. ofs << "%para" << param_index;
  432. } else {
  433. ofs << "%para" << param_index << " : " << type_info;
  434. }
  435. // dump Default value of parameter if exists
  436. const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
  437. if (param_ptr == nullptr) {
  438. MS_LOG(EXCEPTION) << "Param could not cast to parameter";
  439. }
  440. if (param_ptr->has_default()) {
  441. auto param_value = param_ptr->default_param();
  442. ofs << " = @" << DumpObject(py::cast(param_value), "D");
  443. }
  444. // output comment
  445. ofs << " # " << param->DumpText() << "\n";
  446. param_index += 1;
  447. }
  448. }
  449. void AnfExporter::OutputStatementComment(std::ofstream &ofs, const CNodePtr &node) {
  450. if (node == nullptr) {
  451. return;
  452. }
  453. // output type of each input argument
  454. auto &inputs = node->inputs();
  455. if (inputs.size() > 1) {
  456. ofs << " #(";
  457. for (size_t i = 1; i < inputs.size(); ++i) {
  458. if (i != 1) {
  459. ofs << ", ";
  460. }
  461. AnfNodePtr arg = inputs[i];
  462. ofs << GetNodeType(arg);
  463. }
  464. ofs << ")";
  465. }
  466. // output other comment, map the graph name to original representation(containing unicode character)
  467. std::ostringstream comment;
  468. comment << " #";
  469. bool has_comment = false;
  470. for (size_t i = 0; i < inputs.size(); ++i) {
  471. AnfNodePtr arg = inputs[i];
  472. if (!IsValueNode<FuncGraph>(arg)) {
  473. continue;
  474. }
  475. if (!has_comment) {
  476. has_comment = true;
  477. } else {
  478. comment << ",";
  479. }
  480. FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(arg);
  481. std::string func_graph_id = fg->debug_info()->get_id();
  482. comment << " fg_" << func_graph_id << "=" << fg->ToString() << "." << func_graph_id;
  483. }
  484. if (has_comment) {
  485. ofs << comment.str();
  486. }
  487. ofs << " #scope: " << node->scope()->name();
  488. }
  489. void AnfExporter::OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes,
  490. const FuncGraphPtr &func_graph) {
  491. if (func_graph == nullptr) {
  492. return;
  493. }
  494. int idx = 1;
  495. std::map<AnfNodePtr, int> apply_map;
  496. for (const AnfNodePtr &node : nodes) {
  497. MS_EXCEPTION_IF_NULL(node);
  498. if (!node->isa<CNode>()) {
  499. continue;
  500. }
  501. auto iter = tagged_cnodes_.find(node);
  502. if (iter != tagged_cnodes_.end()) {
  503. ofs << "\n#------------------------> " << iter->second << "\n";
  504. }
  505. auto cnode = node->cast<CNodePtr>();
  506. auto &inputs = cnode->inputs();
  507. std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map);
  508. // non-return node
  509. if (node != func_graph->get_return()) {
  510. int apply_idx = idx++;
  511. apply_map[node] = apply_idx;
  512. std::string type_info = GetNodeType(node);
  513. if (type_info == "Undefined") {
  514. ofs << " %" << apply_idx << " = " << op_text << "(";
  515. } else {
  516. ofs << " %" << apply_idx << " : " << type_info << " = " << op_text << "(";
  517. }
  518. } else {
  519. ofs << " " << op_text << "(";
  520. }
  521. for (size_t i = 1; i < inputs.size(); ++i) {
  522. if (i != 1) {
  523. ofs << ", ";
  524. }
  525. AnfNodePtr arg = inputs[i];
  526. ofs << GetAnfNodeText(func_graph, arg, apply_map);
  527. }
  528. ofs << ")";
  529. // output comment
  530. OutputStatementComment(ofs, cnode);
  531. ofs << "\n";
  532. if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
  533. ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "#"
  534. << label_manage::Label(cnode->debug_info()) << "\n";
  535. } else {
  536. ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "\n";
  537. }
  538. }
  539. }
  540. void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph) {
  541. if (func_graph == nullptr) {
  542. return;
  543. }
  544. std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
  545. std::vector<AnfNodePtr> parameters = func_graph->parameters();
  546. OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map;
  547. if (*(func_graph->switch_layer_input())) {
  548. ofs << "switch_layer_input: " << *(func_graph->switch_layer_input()) << "\n";
  549. }
  550. ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "."
  551. << func_graph->debug_info()->get_id() << "\n";
  552. if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
  553. ofs << trace::GetDebugInfo(func_graph->debug_info(), "# ", kSourceLineTipDiscard) << "#"
  554. << label_manage::Label(func_graph->debug_info()) << "\n";
  555. } else {
  556. ofs << trace::GetDebugInfo(func_graph->debug_info(), "# ", kSourceLineTipDiscard) << "\n";
  557. }
  558. ofs << "funcgraph fg_" << func_graph->debug_info()->get_id();
  559. // output name of parent of graph if exists
  560. if (func_graph->parent() != nullptr) {
  561. ofs << "[fg_" << func_graph->parent()->debug_info()->get_id() << "]";
  562. }
  563. ofs << "(\n";
  564. OutputParameters(ofs, parameters, &param_map);
  565. exported[func_graph] = param_map;
  566. ofs << (!parameters.empty() ? " " : "") << ") {\n";
  567. OutputCNodes(ofs, nodes, func_graph);
  568. ofs << "}\n";
  569. }
  570. void AnfExporter::ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) {
  571. if (func_graph == nullptr) {
  572. return;
  573. }
  574. std::ofstream ofs(filename);
  575. if (!ofs.is_open()) {
  576. MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
  577. return;
  578. }
  579. param_index = 1;
  580. func_graph_set.add(func_graph);
  581. while (!func_graph_set.empty()) {
  582. FuncGraphPtr fg = *func_graph_set.begin();
  583. ExportOneFuncGraph(ofs, fg);
  584. ofs << "\n\n";
  585. (void)func_graph_set.erase(fg);
  586. }
  587. ofs << "# num of total function graphs: " << exported.size();
  588. ofs.close();
  589. }
  590. void AnfExporter::ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
  591. if (graphs.empty()) {
  592. return;
  593. }
  594. std::ofstream ofs(filename);
  595. if (!ofs.is_open()) {
  596. MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
  597. return;
  598. }
  599. param_index = 1;
  600. for (const auto &tagged_graph : graphs) {
  601. tagged_cnodes_ = tagged_graph.second;
  602. ExportOneFuncGraph(ofs, tagged_graph.first);
  603. tagged_cnodes_.clear();
  604. ofs << "\n\n";
  605. }
  606. ofs << "# num of total function graphs: " << graphs.size();
  607. ofs.close();
  608. }
  609. #ifdef ENABLE_DUMP_IR
  610. void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph) {
  611. if (func_graph == nullptr) {
  612. return;
  613. }
  614. AnfExporter exporter(id);
  615. ChangeFileMode(filename, S_IRWXU);
  616. exporter.ExportFuncGraph(filename, func_graph);
  617. // set file mode to read only by user
  618. ChangeFileMode(filename, S_IRUSR);
  619. }
  620. void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
  621. AnfExporter exporter("", false);
  622. ChangeFileMode(filename, S_IRWXU);
  623. exporter.ExportFuncGraph(filename, graphs);
  624. // set file mode to read only by user
  625. ChangeFileMode(filename, S_IRUSR);
  626. }
  627. #else
  628. void ExportIR(const std::string &, const std::string &, const FuncGraphPtr &) {
  629. static bool already_printed = false;
  630. if (already_printed) {
  631. return;
  632. }
  633. already_printed = true;
  634. MS_LOG(WARNING) << "The functionality of dumping function graph IR is disabled, "
  635. << "please recompile source to enable it. See help of building script.";
  636. }
  637. void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
  638. static bool already_printed = false;
  639. if (already_printed) {
  640. return;
  641. }
  642. already_printed = true;
  643. MS_LOG(WARNING) << "The functionality of dumping function graph IR is disabled, "
  644. << "please recompile source to enable it. See help of building script.";
  645. }
  646. #endif
  647. // ============================================= MindSpore IR Importer =============================================
// Token kinds produced by the IR lexer. Single-character punctuation tokens come
// first, followed by multi-character tokens and the end-of-line/end-of-file/error markers.
enum Token : int {
  TOK_INVALID = 0,   // invalid token
  TOK_LPARENTHESIS,  // ( left parenthesis
  TOK_RPARENTHESIS,  // ) right parenthesis
  TOK_LBRACKET,      // [ left bracket
  TOK_RBRACKET,      // ] right bracket
  TOK_LBRACE,        // { left brace
  TOK_RBRACE,        // } right brace
  TOK_COMMA,         // , comma
  TOK_EQUALITY,      // = equality
  TOK_COLON,         // : colon
  TOK_STAR,          // * star
  TOK_VARIABLE,      // variable
  TOK_AT_FILE,       // @filename
  TOK_PARAMETER,     // parameter
  TOK_IDENTIFIER,    // identifier
  TOK_FUNCGRAPH,     // keyword 'funcgraph'
  TOK_RETURN,        // id prim::return
  TOK_STRING,        // string
  TOK_NUMBER,        // number
  TOK_COMMENT,       // comment
  TOK_EOL,           // end of line
  TOK_EOF,           // end of file
  TOK_ERROR          // file read error
};
// Fixed spelling of each token kind. Entries mapped to nullptr have no fixed text;
// the lexer's debug trace falls back to the lexer's token buffer for those.
std::map<Token, const char *> token_text = {
  {TOK_INVALID, "invalid"},      // invalid token
  {TOK_LPARENTHESIS, "("},       // ( left parenthesis
  {TOK_RPARENTHESIS, ")"},       // ) right parenthesis
  {TOK_LBRACKET, "["},           // [ left bracket
  {TOK_RBRACKET, "]"},           // ] right bracket
  {TOK_LBRACE, "{"},             // { left brace
  {TOK_RBRACE, "}"},             // } right brace
  {TOK_COMMA, ","},              // , comma
  {TOK_EQUALITY, "="},           // = equality
  {TOK_COLON, ":"},              // : colon
  {TOK_STAR, "*"},               // * star
  {TOK_VARIABLE, nullptr},       // variable
  {TOK_AT_FILE, nullptr},        // @file
  {TOK_PARAMETER, nullptr},      // parameter
  {TOK_IDENTIFIER, nullptr},     // identifier
  {TOK_FUNCGRAPH, "funcgraph"},  // keyword 'funcgraph'
  {TOK_RETURN, nullptr},         // id prim::return
  {TOK_STRING, nullptr},         // string
  {TOK_NUMBER, nullptr},         // number
  {TOK_COMMENT, nullptr},        // comment
  {TOK_EOL, "\n"},               // end of line
  {TOK_EOF, ""},                 // end of file
  {TOK_ERROR, "error"}           // file read error
};
  698. class Lexer {
  699. public:
  700. // filename is checked in ImportIR;
  701. explicit Lexer(const char *filename) : fin(filename) {}
  702. ~Lexer() {
  703. try {
  704. if (fin.is_open()) {
  705. fin.close();
  706. }
  707. } catch (const std::exception &e) {
  708. MS_LOG(ERROR) << "Exception when closing file";
  709. } catch (...) {
  710. std::string exName(abi::__cxa_current_exception_type()->name());
  711. MS_LOG(ERROR) << "Error occurred when closing file. Exception name: " << exName;
  712. }
  713. }
  714. bool IsSingleCharToken(char ch, Token *token_ptr) {
  715. // clang-format off
  716. std::unordered_map<char, Token> char_to_token = {
  717. {'(', TOK_LPARENTHESIS},
  718. {')', TOK_RPARENTHESIS},
  719. {'[', TOK_LBRACKET},
  720. {']', TOK_RBRACKET},
  721. {'{', TOK_LBRACE},
  722. {'}', TOK_RBRACE},
  723. {',', TOK_COMMA},
  724. {'=', TOK_EQUALITY},
  725. {':', TOK_COLON},
  726. {'*', TOK_STAR}};
  727. // clang-format on
  728. auto iter = char_to_token.find(ch);
  729. if (iter == char_to_token.end()) {
  730. return false;
  731. }
  732. if (token_ptr != nullptr) {
  733. *token_ptr = iter->second;
  734. }
  735. return true;
  736. }
  737. Token GetNextToken() {
  738. #ifdef DEBUG
  739. Token token = GetNextTokenInner();
  740. const char *str = token_text[token];
  741. std::string text = (str == nullptr ? GetTokenText() : str);
  742. MS_LOG(DEBUG) << "------Parse token] " << text;
  743. return token;
  744. }
  745. Token GetNextTokenInner() {
  746. #endif
  747. tok_idx = 0;
  748. Token tok = TOK_ERROR;
  749. char ch = SkipTabAndSpace();
  750. if (ch == CODE_EOF) {
  751. return TOK_EOF;
  752. } else if (ch == CODE_ERROR) {
  753. return TOK_ERROR;
  754. } else if (IsSingleCharToken(ch, &tok)) {
  755. return tok;
  756. } else if (ch == '\r') {
  757. char c = GetChar();
  758. if (c == '\n') {
  759. line_++;
  760. return TOK_EOL;
  761. }
  762. UnGetChar(c);
  763. line_++;
  764. return TOK_EOL;
  765. } else if (ch == '\n') {
  766. line_++;
  767. return TOK_EOL;
  768. } else if (ch == '#') {
  769. return ParseComment(ch);
  770. } else if (ch == '"') {
  771. return ParseString();
  772. } else if (ch == '%') {
  773. return ParseVariableOrParameter(ch);
  774. } else if (ch == '@') {
  775. return ParseAtFile();
  776. } else if (IsDigit(ch) || ch == '-') {
  777. return ParseNumber(ch);
  778. } else if (IsAlpha(ch) || ch == '_') {
  779. return ParseIdentifier(ch);
  780. } else {
  781. return TOK_ERROR;
  782. }
  783. }
  784. Token SkipWhiteToken() {
  785. Token tok = GetNextToken();
  786. while (tok == TOK_EOL || tok == TOK_COMMENT) {
  787. tok = GetNextToken();
  788. }
  789. return tok;
  790. }
  791. std::string GetTokenText() const { return std::string(tok_buf); }
  792. int GetLineNo() const { return line_; }
  793. private:
  794. Token ParseComment(char ch) {
  795. char c = GetChar();
  796. while (c != '\r' && c != '\n' && c != CODE_EOF) {
  797. c = GetChar();
  798. }
  799. if (ch != CODE_EOF) {
  800. UnGetChar(c);
  801. }
  802. tok_buf[0] = '#';
  803. tok_buf[1] = '\0';
  804. return TOK_COMMENT;
  805. }
  806. Token ParseString() {
  807. tok_idx = 0;
  808. char c = GetChar();
  809. while (c != '"') {
  810. if (tok_idx >= BUF_SIZE) {
  811. MS_LOG(EXCEPTION) << "Length of token which is " << tok_idx << " exceeds " << BUF_SIZE;
  812. }
  813. if (c == '\r' || c == '\n') {
  814. MS_LOG(EXCEPTION) << "Literal newline characters are not allowed within the quote at line " << line_;
  815. }
  816. if (c == CODE_EOF) {
  817. MS_LOG(EXCEPTION) << "Encounter EOF within the quote at line " << line_;
  818. }
  819. tok_buf[tok_idx++] = c;
  820. c = GetChar();
  821. }
  822. tok_buf[tok_idx] = '\0';
  823. return TOK_STRING;
  824. }
  825. Token ParseVariableOrParameter(char ch) {
  826. tok_idx = 0;
  827. tok_buf[tok_idx++] = ch;
  828. char c = GetChar();
  829. while (IsAlphaNumeric(c)) {
  830. if (tok_idx >= BUF_SIZE) {
  831. MS_LOG(EXCEPTION) << "Length of token which is " << tok_idx << " exceeds " << BUF_SIZE;
  832. }
  833. tok_buf[tok_idx++] = c;
  834. c = GetChar();
  835. }
  836. tok_buf[tok_idx] = '\0';
  837. UnGetChar(c);
  838. // judge parameter: %para[0-9]+
  839. tok_buf[tok_idx] = '\0';
  840. std::string param_key = "%para";
  841. if (strncmp(tok_buf, param_key.c_str(), param_key.size()) == 0) {
  842. if (tok_idx <= param_key.size()) {
  843. return TOK_ERROR;
  844. }
  845. for (auto i = static_cast<unsigned>(param_key.size()); i < tok_idx; ++i) {
  846. if (!IsDigit(tok_buf[i])) {
  847. return TOK_ERROR;
  848. }
  849. }
  850. return TOK_PARAMETER;
  851. }
  852. // judge local variable: %[0-9]+
  853. if (tok_idx == 1) {
  854. return TOK_ERROR;
  855. }
  856. for (unsigned i = 1; i < tok_idx; ++i) {
  857. if (!IsDigit(tok_buf[i])) {
  858. return TOK_ERROR;
  859. }
  860. }
  861. return TOK_VARIABLE;
  862. }
  863. Token ParseAtFile() {
  864. tok_idx = 0;
  865. char c = GetChar();
  866. while (IsAlphaNumeric(c) || c == '_' || c == '.') {
  867. if (tok_idx >= BUF_SIZE) {
  868. MS_LOG(EXCEPTION) << "Length of token which is " << tok_idx << " exceeds " << BUF_SIZE;
  869. }
  870. tok_buf[tok_idx++] = c;
  871. c = GetChar();
  872. }
  873. tok_buf[tok_idx] = '\0';
  874. UnGetChar(c);
  875. if (tok_idx == 0) {
  876. return TOK_ERROR;
  877. }
  878. return TOK_AT_FILE;
  879. }
  880. Token ParseNumber(char ch) {
  881. tok_buf[tok_idx++] = ch;
  882. char c = GetChar();
  883. // parse number, e.g. 10, 15.6, 1e-5
  884. while (IsDigit(c) || c == '.' || c == 'e' || c == '-') {
  885. if (tok_idx >= BUF_SIZE) {
  886. MS_LOG(EXCEPTION) << "Length of token which is " << tok_idx << " exceeds " << BUF_SIZE;
  887. }
  888. tok_buf[tok_idx++] = c;
  889. c = GetChar();
  890. }
  891. UnGetChar(c);
  892. tok_buf[tok_idx] = '\0';
  893. return TOK_NUMBER;
  894. }
  895. Token ParseIdentifier(char ch) {
  896. tok_idx = 0;
  897. tok_buf[tok_idx++] = ch;
  898. char c = GetChar();
  899. while (IsAlphaNumeric(c) || c == '.' || c == ':' || c == '_') {
  900. if (tok_idx >= BUF_SIZE) {
  901. MS_LOG(EXCEPTION) << "Length of token which is " << tok_idx << " exceeds " << BUF_SIZE;
  902. }
  903. tok_buf[tok_idx++] = c;
  904. c = GetChar();
  905. }
  906. UnGetChar(c);
  907. tok_buf[tok_idx] = '\0';
  908. if (strcmp(tok_buf, "funcgraph") == 0) {
  909. return TOK_FUNCGRAPH;
  910. }
  911. if (strcmp(tok_buf, "Primitive::return") == 0) {
  912. return TOK_RETURN;
  913. }
  914. return TOK_IDENTIFIER;
  915. }
  916. // Suppose the file only contain ASCII character
  917. char GetChar() {
  918. if (ungot_char != UNGOT_CHAR) {
  919. char ch = ungot_char;
  920. ungot_char = UNGOT_CHAR;
  921. return ch;
  922. }
  923. if (idx >= cnt) {
  924. if (fin.eof()) {
  925. return CODE_EOF;
  926. }
  927. cnt = fin.read(buffer, BUF_SIZE).gcount();
  928. if ((fin.bad() || fin.fail()) && !fin.eof()) {
  929. MS_LOG(EXCEPTION) << "Read file error!";
  930. }
  931. idx = 0;
  932. }
  933. return buffer[idx++];
  934. }
  935. void UnGetChar(char ch) {
  936. if (ungot_char == UNGOT_CHAR) {
  937. ungot_char = ch;
  938. }
  939. }
  940. static bool IsTabOrSpace(char ch) { return ch == ' ' || ch == '\t'; }
  941. static bool IsDigit(char ch) { return ch >= '0' && ch <= '9'; }
  942. static bool IsAlpha(char ch) { return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z'); }
  943. static bool IsAlphaNumeric(char ch) { return IsDigit(ch) || IsAlpha(ch); }
  944. // skip whitespace(including comment) to read a valid character
  945. char SkipTabAndSpace() {
  946. char ch = GetChar();
  947. while (IsTabOrSpace(ch)) {
  948. ch = GetChar();
  949. }
  950. return ch;
  951. }
  952. std::ifstream fin;
  953. static const unsigned BUF_SIZE = 4096; // lexer buffer size
  954. char buffer[BUF_SIZE + 1] = {0}; // buffer for holding text read from text
  955. std::streamsize cnt = 0; // number of valid characters in the buffer
  956. unsigned idx = 0; // index of next char the lexer to read from
  957. char tok_buf[BUF_SIZE + 1] = {0}; // token buffer
  958. unsigned tok_idx = 0; // token buffer index
  959. char ungot_char = UNGOT_CHAR; // store ungot char
  960. static const int CODE_EOF = -1; // return code of GetChar
  961. static const int CODE_ERROR = -2; // read file error
  962. static const char UNGOT_CHAR = -3; // value of ungot char
  963. int line_ = 1; // current line number
  964. };
  965. const unsigned Lexer::BUF_SIZE;
  966. class IrParser {
  967. public:
  968. explicit IrParser(const char *filename) : lexer_(filename) {}
  969. ~IrParser() {}
  970. py::object LoadObject(const std::string &file_name) const {
  971. std::string pkl_path = GetMsIrPath();
  972. py::object default_obj = load_obj(pkl_path + "/" + file_name);
  973. return default_obj;
  974. }
  975. void ParseFile() {
  976. FuncGraphPtr func_graph = ParseFuncGraph();
  977. while (func_graph != nullptr) {
  978. func_graphs_.push_back(func_graph);
  979. func_graph = ParseFuncGraph();
  980. }
  981. if (error_flag_) {
  982. MS_LOG(EXCEPTION) << "Parse Error at line: " << lexer_.GetLineNo();
  983. }
  984. MS_LOG(INFO) << "Total graphs: " << func_graphs_.size();
  985. }
  986. Token ParseParent(FuncGraphPtr *const parent_ptr) {
  987. if (lexer_.GetNextToken() != TOK_IDENTIFIER) {
  988. return TOK_ERROR;
  989. }
  990. std::string parent_name = lexer_.GetTokenText();
  991. // NOTICE: require definition of parent graph must before child graph
  992. auto iter = func_graphs_map_.find(parent_name);
  993. if (iter == func_graphs_map_.end()) {
  994. MS_LOG(EXCEPTION) << "Can not find definition of parent func graph '" << parent_name << "' at line "
  995. << lexer_.GetLineNo();
  996. }
  997. if (parent_ptr != nullptr) {
  998. *parent_ptr = iter->second;
  999. }
  1000. if (lexer_.GetNextToken() != TOK_RBRACKET) {
  1001. return TOK_ERROR;
  1002. }
  1003. return lexer_.GetNextToken();
  1004. }
  1005. FuncGraphPtr ParseFuncGraph() {
  1006. cnodes_.clear();
  1007. Token tok = lexer_.SkipWhiteToken();
  1008. if (tok != TOK_FUNCGRAPH) {
  1009. error_flag_ = tok != TOK_EOF;
  1010. return nullptr;
  1011. }
  1012. if (lexer_.GetNextToken() != TOK_IDENTIFIER) {
  1013. error_flag_ = true;
  1014. return nullptr;
  1015. }
  1016. std::string func_graph_name = lexer_.GetTokenText();
  1017. if (func_graphs_map_.find(func_graph_name) == func_graphs_map_.end()) {
  1018. func_graphs_map_[func_graph_name] = std::make_shared<FuncGraph>();
  1019. }
  1020. FuncGraphPtr func_graph = func_graphs_map_[func_graph_name];
  1021. MS_EXCEPTION_IF_NULL(func_graph);
  1022. MS_EXCEPTION_IF_NULL(func_graph->debug_info());
  1023. func_graph->debug_info()->set_name(func_graph_name); // for debugging
  1024. FuncGraphPtr parent = nullptr;
  1025. tok = lexer_.GetNextToken();
  1026. if (tok == TOK_LBRACKET) {
  1027. tok = ParseParent(&parent);
  1028. if (parent != nullptr) {
  1029. parents_map_[func_graph] = parent;
  1030. }
  1031. }
  1032. if (tok != TOK_LPARENTHESIS) {
  1033. error_flag_ = true;
  1034. return nullptr;
  1035. }
  1036. if (ParseParameters(func_graph) == nullptr) {
  1037. error_flag_ = true;
  1038. return nullptr;
  1039. }
  1040. if (lexer_.SkipWhiteToken() != TOK_LBRACE) {
  1041. error_flag_ = true;
  1042. return nullptr;
  1043. }
  1044. // parse statements
  1045. if (ParseStatements(func_graph) == nullptr) {
  1046. error_flag_ = true;
  1047. return nullptr;
  1048. }
  1049. func_graphs_map_[func_graph_name] = func_graph;
  1050. return func_graph;
  1051. }
  1052. FuncGraphPtr ParseStatements(const FuncGraphPtr &func_graph) {
  1053. Token tok = lexer_.SkipWhiteToken();
  1054. while (tok == TOK_VARIABLE) {
  1055. if (ParseStatement(func_graph) == nullptr) {
  1056. return nullptr;
  1057. }
  1058. tok = lexer_.SkipWhiteToken();
  1059. }
  1060. if (tok == TOK_RETURN) {
  1061. return ParseReturn(func_graph);
  1062. }
  1063. return nullptr;
  1064. }
  1065. FuncGraphPtr ParseStatement(FuncGraphPtr func_graph) {
  1066. std::string var_name = lexer_.GetTokenText();
  1067. Token tok = lexer_.GetNextToken();
  1068. AbstractBasePtr type = nullptr;
  1069. if (tok == TOK_COLON) {
  1070. tok = ParseType(func_graph, &type);
  1071. }
  1072. if (tok != TOK_EQUALITY) {
  1073. return nullptr;
  1074. }
  1075. std::vector<AnfNodePtr> inputs;
  1076. AnfNodePtr node = nullptr;
  1077. ValuePtr val = nullptr;
  1078. tok = ParseItem(func_graph, &node, &val);
  1079. if (tok != TOK_LPARENTHESIS) {
  1080. return nullptr;
  1081. }
  1082. inputs.push_back(node);
  1083. int lineno = lexer_.GetLineNo();
  1084. if (ParseArguments(func_graph, &inputs) == nullptr) {
  1085. return nullptr;
  1086. }
  1087. tok = lexer_.GetNextToken();
  1088. if (tok == TOK_COMMENT) {
  1089. tok = lexer_.GetNextToken();
  1090. }
  1091. if (tok != TOK_EOL) {
  1092. return nullptr;
  1093. }
  1094. MS_EXCEPTION_IF_NULL(func_graph);
  1095. cnodes_[var_name] = func_graph->NewCNode(inputs);
  1096. MS_EXCEPTION_IF_NULL(cnodes_[var_name]);
  1097. cnodes_[var_name]->set_debug_info(std::make_shared<NodeDebugInfo>(var_name + "@" + std::to_string(lineno)));
  1098. return func_graph;
  1099. }
  1100. FuncGraphPtr ParseReturn(FuncGraphPtr func_graph) {
  1101. if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
  1102. return nullptr;
  1103. }
  1104. AnfNodePtr input1 = nullptr;
  1105. ValuePtr value = nullptr;
  1106. Token tok = ParseItem(func_graph, &input1, &value, lexer_.GetNextToken());
  1107. int lineno = lexer_.GetLineNo();
  1108. if (tok != TOK_RPARENTHESIS) {
  1109. return nullptr;
  1110. }
  1111. tok = lexer_.GetNextToken();
  1112. if (tok == TOK_COMMENT) {
  1113. tok = lexer_.GetNextToken();
  1114. }
  1115. if (tok != TOK_EOL) {
  1116. return nullptr;
  1117. }
  1118. if (lexer_.SkipWhiteToken() != TOK_RBRACE) {
  1119. return nullptr;
  1120. }
  1121. PrimitivePtr prim = std::make_shared<Primitive>("return");
  1122. ValueNodePtr input0 = std::make_shared<ValueNode>(prim);
  1123. std::vector<AnfNodePtr> inputs;
  1124. inputs.push_back(input0);
  1125. inputs.push_back(input1);
  1126. MS_EXCEPTION_IF_NULL(func_graph);
  1127. CNodePtr ret = func_graph->NewCNode(inputs);
  1128. MS_EXCEPTION_IF_NULL(ret);
  1129. ret->set_debug_info(std::make_shared<NodeDebugInfo>(std::string("ret@") + std::to_string(lineno)));
  1130. func_graph->set_return(ret);
  1131. return func_graph;
  1132. }
  1133. void SetBasicType(TypePtr *ptr, const TypePtr &dtype) const {
  1134. if (ptr == nullptr) {
  1135. return;
  1136. }
  1137. *ptr = dtype;
  1138. }
  1139. void SetTupleType(TypePtr *ptr) {
  1140. if (ptr == nullptr) {
  1141. return;
  1142. }
  1143. *ptr = std::make_shared<Tuple>();
  1144. }
  1145. void SetTupleType(TypePtr *ptr, const TypePtrList &elems) {
  1146. if (ptr == nullptr) {
  1147. return;
  1148. }
  1149. *ptr = std::make_shared<Tuple>(elems);
  1150. }
  1151. void SetArrayType(TypePtr *const ptr, const TypePtr &elem_type, const std::vector<int> &) {
  1152. if (ptr == nullptr) {
  1153. return;
  1154. }
  1155. *ptr = std::make_shared<TensorType>(elem_type);
  1156. }
  1157. void SetListType(TypePtr *ptr) {
  1158. if (ptr == nullptr) {
  1159. return;
  1160. }
  1161. *ptr = std::make_shared<List>();
  1162. }
  1163. void SetListType(TypePtr *ptr, const TypePtrList &elems) {
  1164. if (ptr == nullptr) {
  1165. return;
  1166. }
  1167. *ptr = std::make_shared<List>(elems);
  1168. }
  1169. void SetJTaggedType(TypePtr *ptr, const TypePtr &elem) {
  1170. if (ptr == nullptr) {
  1171. return;
  1172. }
  1173. *ptr = std::make_shared<JTagged>(elem);
  1174. }
  1175. void SetBasicType(AbstractBasePtr *ptr, const TypePtr &dtype) const {
  1176. if (ptr == nullptr) {
  1177. return;
  1178. }
  1179. *ptr = std::make_shared<abstract::AbstractScalar>(dtype);
  1180. }
  1181. // void SetBasicType(AbstractBasePtr *ptr, const SymbolicKeyTypePtr& dtype) {}
  1182. void SetBasicType(AbstractBasePtr *const ptr, const TypeNonePtr &) const {
  1183. if (ptr == nullptr) {
  1184. return;
  1185. }
  1186. *ptr = std::make_shared<abstract::AbstractNone>();
  1187. }
  1188. void SetBasicType(AbstractBasePtr *, const FunctionPtr &) const {}
  1189. void SetBasicType(AbstractBasePtr *, const TensorTypePtr &) const {}
  1190. void SetTupleType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) {
  1191. if (ptr == nullptr) {
  1192. return;
  1193. }
  1194. // if one of elems is nullptr, just return
  1195. if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) {
  1196. return;
  1197. }
  1198. *ptr = std::make_shared<abstract::AbstractTuple>(elems);
  1199. }
  1200. void SetArrayType(AbstractBasePtr *const ptr, const TypePtr &elem_type, const std::vector<int> &shape) {
  1201. if (ptr == nullptr) {
  1202. return;
  1203. }
  1204. *ptr = std::make_shared<abstract::AbstractTensor>(elem_type, shape);
  1205. }
  1206. void SetListType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) {
  1207. if (ptr == nullptr) {
  1208. return;
  1209. }
  1210. if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) {
  1211. return;
  1212. }
  1213. *ptr = std::make_shared<abstract::AbstractList>(elems);
  1214. }
  1215. void SetJTaggedType(AbstractBasePtr *const ptr, const AbstractBasePtr &elem) {
  1216. if (ptr == nullptr) {
  1217. return;
  1218. }
  1219. *ptr = std::make_shared<abstract::AbstractJTagged>(elem);
  1220. }
  1221. template <typename T>
  1222. Token ParseTypeVector(const FuncGraphPtr &func_graph, Token tok, const std::string &type, T *const ptr = nullptr) {
  1223. if (tok != TOK_LBRACKET) {
  1224. MS_LOG(EXCEPTION) << "Illegal case, , wrong token start symbol.";
  1225. return tok;
  1226. }
  1227. bool first_flag = true;
  1228. std::vector<T> elems;
  1229. do {
  1230. tok = lexer_.GetNextToken();
  1231. if (first_flag) {
  1232. if (tok == TOK_RBRACKET) {
  1233. return lexer_.GetNextToken();
  1234. }
  1235. first_flag = false;
  1236. }
  1237. T elem = nullptr;
  1238. tok = ParseOneType(func_graph, tok, &elem);
  1239. elems.push_back(elem);
  1240. if (tok == TOK_STAR) {
  1241. if (lexer_.GetNextToken() != TOK_NUMBER) {
  1242. return TOK_ERROR;
  1243. }
  1244. int num_elems = StringToScalar<int>(lexer_.GetTokenText());
  1245. if (num_elems < 1 || num_elems > NUM_MAX_SEQUENCE_ELEMS) {
  1246. MS_LOG(EXCEPTION) << "Number of elements " << num_elems << " is out of range [1, " << NUM_MAX_SEQUENCE_ELEMS
  1247. << "]";
  1248. }
  1249. for (int i = 0; i < num_elems - 1; ++i) {
  1250. elems.push_back(elem);
  1251. }
  1252. tok = lexer_.GetNextToken();
  1253. }
  1254. } while (tok == TOK_COMMA);
  1255. if (tok != TOK_RBRACKET) {
  1256. return TOK_ERROR;
  1257. }
  1258. if (type == "Tuple") {
  1259. SetTupleType(ptr, elems);
  1260. } else if (type == "List") {
  1261. SetListType(ptr, elems);
  1262. } else {
  1263. MS_LOG(EXCEPTION) << "This method does not support " << type << " parse.";
  1264. }
  1265. return lexer_.GetNextToken();
  1266. }
  1267. template <typename T>
  1268. Token ParseTypeArray(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) {
  1269. if (tok != TOK_LPARENTHESIS) {
  1270. if (ptr != nullptr) {
  1271. SetBasicType(ptr, std::make_shared<TensorType>());
  1272. }
  1273. return tok;
  1274. }
  1275. // process Array element type
  1276. TypePtr elem_type = nullptr;
  1277. std::vector<int> shape;
  1278. tok = ParseOneType(func_graph, lexer_.GetNextToken(), &elem_type);
  1279. if (tok != TOK_RPARENTHESIS) {
  1280. return TOK_ERROR;
  1281. }
  1282. tok = lexer_.GetNextToken();
  1283. if (tok != TOK_LBRACKET) {
  1284. // NOTICE: if shape.size == 0, is this ok?
  1285. SetArrayType(ptr, elem_type, shape);
  1286. return tok;
  1287. }
  1288. // process Array shape
  1289. do {
  1290. tok = lexer_.GetNextToken();
  1291. // case: Array(I32)[]
  1292. if (tok != TOK_NUMBER) {
  1293. break;
  1294. }
  1295. shape.push_back(StringToScalar<int>(lexer_.GetTokenText()));
  1296. tok = lexer_.GetNextToken();
  1297. } while (tok == TOK_COMMA);
  1298. if (tok != TOK_RBRACKET) {
  1299. return TOK_ERROR;
  1300. }
  1301. SetArrayType(ptr, elem_type, shape);
  1302. return lexer_.GetNextToken();
  1303. }
  1304. bool IsNumberType(const std::string &type, TypeId *typeid_ptr) {
  1305. // clang-format off
  1306. static std::unordered_map<std::string, TypeId> basic_types = {
  1307. {"Bool", kNumberTypeBool},
  1308. {"I8", kNumberTypeInt8},
  1309. {"I16", kNumberTypeInt16},
  1310. {"I32", kNumberTypeInt32},
  1311. {"I64", kNumberTypeInt64},
  1312. {"U8", kNumberTypeUInt8},
  1313. {"U16", kNumberTypeUInt16},
  1314. {"U32", kNumberTypeUInt32},
  1315. {"U64", kNumberTypeUInt64},
  1316. {"F16", kNumberTypeFloat16},
  1317. {"F32", kNumberTypeFloat32},
  1318. {"F64", kNumberTypeFloat64},
  1319. {"Int", kNumberTypeInt},
  1320. {"UInt", kNumberTypeUInt},
  1321. {"Float", kNumberTypeFloat},
  1322. {"Number", kObjectTypeNumber}};
  1323. // clang-format on
  1324. auto iter = basic_types.find(type);
  1325. if (iter == basic_types.end()) {
  1326. return false;
  1327. }
  1328. if (typeid_ptr != nullptr) {
  1329. *typeid_ptr = iter->second;
  1330. }
  1331. return true;
  1332. }
  1333. template <typename T>
  1334. void ParseNumberType(const std::string &type, TypeId typeId, T *const ptr = nullptr) {
  1335. TypePtr dtype = nullptr;
  1336. std::unordered_map<int, TypePtr> type_map = {
  1337. {static_cast<int>(kNumberTypeBool), std::make_shared<Bool>()}, // Bool
  1338. {static_cast<int>(kNumberTypeInt8), std::make_shared<Int>(8)}, // Int8
  1339. {static_cast<int>(kNumberTypeInt16), std::make_shared<Int>(16)}, // Int16
  1340. {static_cast<int>(kNumberTypeInt32), std::make_shared<Int>(32)}, // Int32
  1341. {static_cast<int>(kNumberTypeInt64), std::make_shared<Int>(64)}, // Int64
  1342. {static_cast<int>(kNumberTypeUInt8), std::make_shared<UInt>(8)}, // UInt8
  1343. {static_cast<int>(kNumberTypeUInt16), std::make_shared<UInt>(16)}, // UInt16
  1344. {static_cast<int>(kNumberTypeUInt32), std::make_shared<UInt>(32)}, // UInt32
  1345. {static_cast<int>(kNumberTypeUInt64), std::make_shared<UInt>(64)}, // UInt64
  1346. {static_cast<int>(kNumberTypeFloat16), std::make_shared<Float>(16)}, // Float16
  1347. {static_cast<int>(kNumberTypeFloat32), std::make_shared<Float>(32)}, // Float32
  1348. {static_cast<int>(kNumberTypeFloat64), std::make_shared<Float>(64)}, // Float64
  1349. {static_cast<int>(kNumberTypeInt), std::make_shared<Int>()}, // Int
  1350. {static_cast<int>(kNumberTypeUInt), std::make_shared<UInt>()}, // UInt
  1351. {static_cast<int>(kNumberTypeFloat), std::make_shared<Float>()}, // Float
  1352. {static_cast<int>(kObjectTypeNumber), std::make_shared<Number>()}, // Number
  1353. };
  1354. auto iter = type_map.find(static_cast<int>(typeId));
  1355. if (iter != type_map.end()) {
  1356. dtype = iter->second;
  1357. } else {
  1358. MS_LOG(EXCEPTION) << "Unknown number type " << type;
  1359. }
  1360. SetBasicType(ptr, dtype);
  1361. }
  1362. template <typename T>
  1363. Token ParseTrivalType(const std::string &type, T *const ptr = nullptr) {
  1364. if (type == "NoneType") {
  1365. SetBasicType(ptr, std::make_shared<TypeNone>());
  1366. return lexer_.GetNextToken();
  1367. } else if (type == "ProblemType") {
  1368. SetBasicType(ptr, std::make_shared<Problem>());
  1369. return lexer_.GetNextToken();
  1370. } else if (type == "ExternalType") {
  1371. SetBasicType(ptr, std::make_shared<External>());
  1372. return lexer_.GetNextToken();
  1373. } else if (type == "AnythingType") {
  1374. SetBasicType(ptr, kAnyType);
  1375. return lexer_.GetNextToken();
  1376. } else if (type == "TypeType") {
  1377. SetBasicType(ptr, std::make_shared<TypeType>());
  1378. return lexer_.GetNextToken();
  1379. } else {
  1380. MS_LOG(EXCEPTION) << "Unknown type error at line " << lexer_.GetLineNo();
  1381. }
  1382. }
  1383. template <typename T>
  1384. Token ParseOneType(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) {
  1385. if (tok != TOK_IDENTIFIER) {
  1386. return TOK_ERROR;
  1387. }
  1388. std::string type = lexer_.GetTokenText();
  1389. TypeId typeId = kTypeUnknown;
  1390. if (IsNumberType(type, &typeId)) {
  1391. ParseNumberType(type, typeId, ptr);
  1392. return lexer_.GetNextToken();
  1393. } else if (type == "Tuple") {
  1394. return ParseTypeVector(func_graph, lexer_.GetNextToken(), type, ptr);
  1395. } else if (type == "Tensor") {
  1396. return ParseTypeArray(func_graph, lexer_.GetNextToken(), ptr);
  1397. } else if (type == "List") {
  1398. return ParseTypeVector(func_graph, lexer_.GetNextToken(), type, ptr);
  1399. } else if (type == "Func") {
  1400. tok = lexer_.GetNextToken();
  1401. if (tok != TOK_LBRACKET) {
  1402. SetBasicType(ptr, std::make_shared<Function>());
  1403. return tok;
  1404. }
  1405. MS_LOG(EXCEPTION) << "Need to process function parameter types at line " << lexer_.GetLineNo();
  1406. } else if (type == "JT") {
  1407. tok = lexer_.GetNextToken();
  1408. if (tok != TOK_LBRACKET) {
  1409. return tok;
  1410. }
  1411. T elem = nullptr;
  1412. tok = ParseOneType(func_graph, lexer_.GetNextToken(), &elem);
  1413. SetJTaggedType(ptr, elem);
  1414. if (tok != TOK_RBRACKET) {
  1415. return TOK_ERROR;
  1416. }
  1417. return lexer_.GetNextToken();
  1418. } else if (type == "SymType") {
  1419. SetBasicType(ptr, std::make_shared<SymbolicKeyType>());
  1420. return lexer_.GetNextToken();
  1421. } else if (type == "EnvType") {
  1422. SetBasicType(ptr, std::make_shared<EnvType>());
  1423. return lexer_.GetNextToken();
  1424. } else if (Match(type, "Cls.")) {
  1425. MS_LOG(EXCEPTION) << "Need to do class type at line " << lexer_.GetLineNo();
  1426. } else {
  1427. return ParseTrivalType(type, ptr);
  1428. }
  1429. }
  1430. Token ParseType(const FuncGraphPtr &func_graph, AbstractBasePtr *const abstract = nullptr) {
  1431. return ParseOneType(func_graph, lexer_.GetNextToken(), abstract);
  1432. }
  1433. Token ParseAttributes(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) {
  1434. Token tok = ParseAttribute(func_graph, prim);
  1435. while (tok == TOK_COMMA) {
  1436. tok = ParseAttribute(func_graph, prim);
  1437. }
  1438. if (tok != TOK_RBRACKET) {
  1439. return TOK_ERROR;
  1440. }
  1441. return lexer_.GetNextToken();
  1442. }
  1443. Token ParseAttribute(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) {
  1444. Token tok = lexer_.GetNextToken();
  1445. if (tok != TOK_IDENTIFIER) {
  1446. return TOK_ERROR;
  1447. }
  1448. std::string attr_name = lexer_.GetTokenText();
  1449. if (lexer_.GetNextToken() != TOK_EQUALITY) {
  1450. return TOK_ERROR;
  1451. }
  1452. ValuePtr value = nullptr;
  1453. tok = ParseValue(func_graph, lexer_.GetNextToken(), &value);
  1454. if (prim != nullptr) {
  1455. prim->set_attr(attr_name, value);
  1456. } else {
  1457. MS_LOG(EXCEPTION) << "Non primitive obj has attributes";
  1458. }
  1459. return tok;
  1460. }
  1461. FuncGraphPtr ParseParameters(FuncGraphPtr func_graph) {
  1462. Token tok = lexer_.SkipWhiteToken();
  1463. while (tok == TOK_PARAMETER) {
  1464. ParameterPtr param = std::make_shared<Parameter>(func_graph);
  1465. param->set_name(lexer_.GetTokenText());
  1466. param_nodes_[lexer_.GetTokenText()] = param;
  1467. int lineno = lexer_.GetLineNo();
  1468. param->set_debug_info(std::make_shared<NodeDebugInfo>(lexer_.GetTokenText() + "@" + std::to_string(lineno)));
  1469. func_graph->add_parameter(param);
  1470. tok = lexer_.GetNextToken();
  1471. // parse type
  1472. if (tok == TOK_COLON) {
  1473. AbstractBasePtr type = nullptr;
  1474. tok = ParseType(func_graph, &type);
  1475. }
  1476. // parse default value
  1477. if (tok == TOK_EQUALITY) {
  1478. if (lexer_.GetNextToken() != TOK_AT_FILE) {
  1479. MS_LOG(EXCEPTION) << "Expect @file at line " << lexer_.GetLineNo();
  1480. }
  1481. // load parameter default value from serialized file
  1482. py::object default_obj = LoadObject(lexer_.GetTokenText());
  1483. auto param_value_new = py::cast<tensor::TensorPtr>(default_obj);
  1484. param->set_default_param(param_value_new);
  1485. tok = lexer_.GetNextToken();
  1486. }
  1487. if (tok == TOK_COMMENT || tok == TOK_EOL) {
  1488. tok = lexer_.SkipWhiteToken();
  1489. }
  1490. Token next = tok;
  1491. if (next == TOK_RPARENTHESIS) {
  1492. return func_graph;
  1493. } else if (next == TOK_COMMA) {
  1494. tok = lexer_.SkipWhiteToken();
  1495. } else {
  1496. return nullptr;
  1497. }
  1498. }
  1499. return tok == TOK_RPARENTHESIS ? func_graph : nullptr;
  1500. }
  1501. FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector<AnfNodePtr> *const inputs_ptr) {
  1502. Token tok = ParseArgument(func_graph, inputs_ptr);
  1503. while (tok == TOK_COMMA) {
  1504. tok = ParseArgument(func_graph, inputs_ptr);
  1505. }
  1506. if (tok != TOK_RPARENTHESIS) {
  1507. return nullptr;
  1508. }
  1509. return func_graph;
  1510. }
  1511. AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string &param_name) {
  1512. while (func_graph != nullptr) {
  1513. for (auto &ptr : func_graph->parameters()) {
  1514. MS_EXCEPTION_IF_NULL(ptr);
  1515. ParameterPtr param = ptr->cast<ParameterPtr>();
  1516. MS_EXCEPTION_IF_NULL(param);
  1517. if (param->name() == param_name) {
  1518. return ptr;
  1519. }
  1520. }
  1521. auto iter = parents_map_.find(func_graph);
  1522. if (iter == parents_map_.end()) {
  1523. break;
  1524. }
  1525. func_graph = iter->second;
  1526. }
  1527. return nullptr;
  1528. }
  1529. bool Match(const std::string &str, const std::string &pattern) const {
  1530. return strncmp(str.c_str(), pattern.c_str(), pattern.length()) == 0;
  1531. }
  1532. template <typename T, typename V>
  1533. Token ParseScalar(ValuePtr *const val_ptr) {
  1534. if (lexer_.GetNextToken() != TOK_NUMBER) {
  1535. return TOK_ERROR;
  1536. }
  1537. std::stringstream ss;
  1538. ss << lexer_.GetTokenText();
  1539. if (lexer_.GetNextToken() != TOK_RPARENTHESIS) {
  1540. return TOK_ERROR;
  1541. }
  1542. V val;
  1543. ss >> val;
  1544. *val_ptr = std::make_shared<T>(val);
  1545. return lexer_.GetNextToken();
  1546. }
  1547. template <typename VT, typename V, typename T>
  1548. Token ParseScalar(ValuePtr *const val_ptr, Token tok) {
  1549. if (tok != TOK_LPARENTHESIS) {
  1550. *val_ptr = std::make_shared<T>();
  1551. return tok;
  1552. }
  1553. return ParseScalar<VT, V>(val_ptr);
  1554. }
  1555. template <typename VT, typename V, typename T, const unsigned nbits>
  1556. Token ParseScalar(ValuePtr *const val_ptr, Token tok) {
  1557. if (tok != TOK_LPARENTHESIS) {
  1558. *val_ptr = std::make_shared<T>(nbits);
  1559. return tok;
  1560. }
  1561. return ParseScalar<VT, V>(val_ptr);
  1562. }
  // Convert `text` to a scalar of type T using stream extraction.
  // Returns a value-initialized T (zero for arithmetic types) when nothing can be extracted.
  template <typename T>
  T StringToScalar(const std::string &text) {
    std::stringstream ss;
    // Value-initialize: when the stream is empty the extraction leaves its target untouched, so an
    // uninitialized local would be returned with an indeterminate value.
    T value{};
    ss << text;
    ss >> value;
    return value;
  }
  1571. Token ParseTensor(ValuePtr *const val_ptr) {
  1572. // parse type
  1573. TypeId type;
  1574. if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
  1575. return TOK_ERROR;
  1576. }
  1577. if (lexer_.GetNextToken() != TOK_NUMBER) {
  1578. return TOK_ERROR;
  1579. }
  1580. type = static_cast<TypeId>(StringToScalar<int>(lexer_.GetTokenText()));
  1581. if (lexer_.GetNextToken() != TOK_RPARENTHESIS) {
  1582. return TOK_ERROR;
  1583. }
  1584. // parse shape
  1585. std::vector<int> shape;
  1586. Token tok = lexer_.GetNextToken();
  1587. if (tok != TOK_LBRACKET) {
  1588. return TOK_ERROR;
  1589. }
  1590. do {
  1591. tok = lexer_.GetNextToken();
  1592. // consider case: Tensor(23)[]
  1593. if (tok != TOK_NUMBER) {
  1594. break;
  1595. }
  1596. shape.push_back(StringToScalar<int>(lexer_.GetTokenText()));
  1597. tok = lexer_.GetNextToken();
  1598. } while (tok == TOK_COMMA);
  1599. if (tok != TOK_RBRACKET) {
  1600. return TOK_ERROR;
  1601. }
  1602. if (lexer_.GetNextToken() != TOK_AT_FILE) {
  1603. return TOK_ERROR;
  1604. }
  1605. py::object tensor_obj = LoadObject(lexer_.GetTokenText());
  1606. py::array tensor_data = py::cast<py::array>(tensor_obj);
  1607. if (!tensor_data) {
  1608. return TOK_ERROR;
  1609. }
  1610. *val_ptr = TensorPy::MakeTensor(tensor_data, TypeIdToType(type));
  1611. return lexer_.GetNextToken();
  1612. }
  1613. Token ParsePrimType(Token tok, PrimType *prim_type_ptr) {
  1614. if (tok != TOK_LBRACE) {
  1615. return tok;
  1616. }
  1617. if (lexer_.GetNextToken() != TOK_IDENTIFIER) {
  1618. return TOK_ERROR;
  1619. }
  1620. if (lexer_.GetTokenText() != "prim_type") {
  1621. return TOK_ERROR;
  1622. }
  1623. if (lexer_.GetNextToken() != TOK_EQUALITY) {
  1624. return TOK_ERROR;
  1625. }
  1626. if (lexer_.GetNextToken() != TOK_NUMBER) {
  1627. return TOK_ERROR;
  1628. }
  1629. int val = 0;
  1630. std::stringstream ss;
  1631. ss << lexer_.GetTokenText();
  1632. ss >> val;
  1633. *prim_type_ptr = PrimType(val);
  1634. if (lexer_.GetNextToken() != TOK_RBRACE) {
  1635. return TOK_ERROR;
  1636. }
  1637. return lexer_.GetNextToken();
  1638. }
  1639. Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) {
  1640. if (tok != TOK_LPARENTHESIS) {
  1641. return TOK_ERROR;
  1642. }
  1643. TypePtrList type_list;
  1644. do {
  1645. TypePtr type = nullptr;
  1646. tok = ParseOneType(nullptr, lexer_.GetNextToken(), &type);
  1647. type_list.push_back(type);
  1648. } while (tok == TOK_COMMA);
  1649. if (tok != TOK_RPARENTHESIS) {
  1650. return TOK_ERROR;
  1651. }
  1652. if (lexer_.GetNextToken() != TOK_AT_FILE) {
  1653. return TOK_ERROR;
  1654. }
  1655. // load Python function from serialized file
  1656. py::object py_func = LoadObject(lexer_.GetTokenText());
  1657. MS_EXCEPTION_IF_NULL(mt_func_graph);
  1658. mt_func_graph->Register(type_list, py::function(py_func));
  1659. return lexer_.GetNextToken();
  1660. }
  1661. Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) {
  1662. if (tok != TOK_LBRACE) {
  1663. return tok;
  1664. }
  1665. do {
  1666. tok = ParseMultitypeFuncGraphItem(mt_func_graph, lexer_.GetNextToken());
  1667. } while (tok == TOK_COMMA);
  1668. if (tok != TOK_RBRACE) {
  1669. return TOK_ERROR;
  1670. }
  1671. return lexer_.GetNextToken();
  1672. }
  1673. Token ParseBoolValue(const std::string &key, bool *val_ptr) {
  1674. if (lexer_.GetNextToken() != TOK_IDENTIFIER || lexer_.GetTokenText() != key) {
  1675. return TOK_ERROR;
  1676. }
  1677. if (lexer_.GetNextToken() != TOK_EQUALITY) {
  1678. return TOK_ERROR;
  1679. }
  1680. if (lexer_.GetNextToken() != TOK_NUMBER) {
  1681. return TOK_ERROR;
  1682. }
  1683. bool value = false;
  1684. {
  1685. std::stringstream ss;
  1686. ss << lexer_.GetTokenText();
  1687. ss >> value;
  1688. }
  1689. if (val_ptr != nullptr) {
  1690. *val_ptr = value;
  1691. }
  1692. return lexer_.GetNextToken();
  1693. }
  1694. Token ParseValueGradOperation(const std::string &name, ValuePtr *const val_ptr) {
  1695. if (lexer_.GetNextToken() != TOK_LBRACE) {
  1696. return TOK_ERROR;
  1697. }
  1698. // get_all=0, get_by_list=1, sens_param=1
  1699. bool get_all = false;
  1700. Token tok = ParseBoolValue("get_all", &get_all);
  1701. if (tok != TOK_COMMA) {
  1702. return TOK_ERROR;
  1703. }
  1704. bool get_by_list = false;
  1705. tok = ParseBoolValue("get_by_list", &get_by_list);
  1706. if (tok != TOK_COMMA) {
  1707. return TOK_ERROR;
  1708. }
  1709. bool sens_param = false;
  1710. tok = ParseBoolValue("sens_param", &sens_param);
  1711. if (tok != TOK_RBRACE) {
  1712. return TOK_ERROR;
  1713. }
  1714. *val_ptr = std::make_shared<prim::GradOperation>(name, get_all, get_by_list, sens_param);
  1715. return lexer_.GetNextToken();
  1716. }
  1717. Token ParseSymbolicKeyInstance(const FuncGraphPtr &func_graph, AnfNodePtr *const node_ptr = nullptr) {
  1718. if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
  1719. return TOK_ERROR;
  1720. }
  1721. if (lexer_.GetNextToken() != TOK_PARAMETER) {
  1722. return TOK_ERROR;
  1723. }
  1724. std::string param_name = lexer_.GetTokenText();
  1725. if (lexer_.GetNextToken() != TOK_RPARENTHESIS) {
  1726. return TOK_ERROR;
  1727. }
  1728. auto iter = param_nodes_.find(param_name);
  1729. if (iter == param_nodes_.end()) {
  1730. MS_LOG(EXCEPTION) << "Can not find param '" << param_name << "' for SymbolicKeyInstance at line "
  1731. << lexer_.GetLineNo();
  1732. }
  1733. PrimitivePtr embed = std::make_shared<Primitive>("embed");
  1734. std::vector<AnfNodePtr> inputs;
  1735. inputs.push_back(std::make_shared<ValueNode>(embed));
  1736. inputs.push_back(iter->second);
  1737. if (node_ptr != nullptr) {
  1738. MS_EXCEPTION_IF_NULL(func_graph);
  1739. *node_ptr = func_graph->NewCNode(inputs);
  1740. } else {
  1741. MS_LOG(EXCEPTION) << "Not processed SymbolicKeyInstance '" << param_name << "' at line " << lexer_.GetLineNo()
  1742. << ".";
  1743. }
  1744. return lexer_.GetNextToken();
  1745. }
  1746. Token ParsePrimitivePy(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *const val_ptr) {
  1747. if (lexer_.GetNextToken() != TOK_AT_FILE) {
  1748. return TOK_ERROR;
  1749. }
  1750. // restore python function of PrimitivePy from serialized file
  1751. py::object py_obj = LoadObject(lexer_.GetTokenText());
  1752. PrimitivePyPtr ptr = nullptr;
  1753. if (py::hasattr(py_obj, "__setattr_flag__") && py::hasattr(py_obj, "_clone")) {
  1754. auto clone_fn = py_obj.attr("_clone");
  1755. py::object new_obj = clone_fn();
  1756. ptr = new_obj.cast<PrimitivePyPtr>();
  1757. if (ptr == nullptr) {
  1758. MS_LOG(EXCEPTION) << "Cast to type 'PrimitivePyPtr' error";
  1759. }
  1760. } else {
  1761. auto len = strlen("PrimitivePy::");
  1762. if (id.size() < len) {
  1763. return TOK_ERROR;
  1764. }
  1765. ptr = std::make_shared<PrimitivePy>(id.substr(len), py_obj);
  1766. }
  1767. *val_ptr = ptr;
  1768. PrimType prim_type = kPrimTypeUnknown;
  1769. Token next = ParsePrimType(lexer_.GetNextToken(), &prim_type);
  1770. if (prim_type != kPrimTypeUnknown) {
  1771. ptr->set_prim_type(prim_type);
  1772. }
  1773. if (next != TOK_LBRACKET) {
  1774. return next;
  1775. }
  1776. // parse attributes
  1777. next = ParseAttributes(func_graph, ptr);
  1778. return next;
  1779. }
  1780. Token ParseValueGraphAndNamespace(const std::string &id, ValuePtr *const val_ptr) {
  1781. if (Match(id, "MultitypeFuncGraph::")) {
  1782. std::string name = id.substr(strlen("MultitypeFuncGraph::"));
  1783. auto mt_func_graph = std::make_shared<prim::MultitypeFuncGraph>(name);
  1784. *val_ptr = mt_func_graph;
  1785. Token next = ParseMultitypeFuncGraph(mt_func_graph, lexer_.GetNextToken());
  1786. return next;
  1787. } else if (Match(id, "HyperMapPy::")) {
  1788. *val_ptr = std::make_shared<prim::HyperMapPy>();
  1789. Token next = lexer_.GetNextToken();
  1790. // process case: fn_leaf is not null
  1791. if (next == TOK_LBRACE) {
  1792. MS_LOG(EXCEPTION) << "Need to process fn_leaf at line " << lexer_.GetLineNo();
  1793. }
  1794. return next;
  1795. } else if (Match(id, "FuncGraph::")) {
  1796. std::string func_graph_name = id.substr(strlen("FuncGraph::"));
  1797. // if the graph does not exist, create a null graph, then fill the graph when encounter the definition
  1798. // of the graph
  1799. if (func_graphs_map_.find(func_graph_name) == func_graphs_map_.end()) {
  1800. func_graphs_map_[func_graph_name] = std::make_shared<FuncGraph>();
  1801. }
  1802. *val_ptr = func_graphs_map_[func_graph_name];
  1803. return lexer_.GetNextToken();
  1804. } else if (Match(id, "NameSpace::")) {
  1805. std::string module_name = id.substr(strlen("NameSpace::"));
  1806. if (lexer_.GetNextToken() != TOK_AT_FILE) {
  1807. MS_LOG(ERROR) << "Expect TOK_AT_FILE at line " << lexer_.GetLineNo();
  1808. return TOK_ERROR;
  1809. }
  1810. // load Python module information from serialized file
  1811. py::object py_obj = LoadObject(lexer_.GetTokenText());
  1812. *val_ptr = std::make_shared<parse::NameSpace>(module_name, py_obj);
  1813. return lexer_.GetNextToken();
  1814. } else {
  1815. MS_LOG(EXCEPTION) << "Unknown id " << id << " at line " << lexer_.GetLineNo();
  1816. }
  1817. }
  1818. Token ParseValueBasic(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *const val_ptr,
  1819. AnfNodePtr *const node_ptr = nullptr) {
  1820. if (id == "None") {
  1821. *val_ptr = std::make_shared<None>();
  1822. return lexer_.GetNextToken();
  1823. } else if (id == "Bool") {
  1824. return ParseScalar<BoolImm, bool, Bool>(val_ptr, lexer_.GetNextToken());
  1825. } else if (id == "I8") {
  1826. return ParseScalar<Int8Imm, int8_t, Int, 8>(val_ptr, lexer_.GetNextToken());
  1827. } else if (id == "I16") {
  1828. return ParseScalar<Int16Imm, int16_t, Int, 16>(val_ptr, lexer_.GetNextToken());
  1829. } else if (id == "I32") {
  1830. return ParseScalar<Int32Imm, int32_t, Int, 32>(val_ptr, lexer_.GetNextToken());
  1831. } else if (id == "I64") {
  1832. return ParseScalar<Int64Imm, int64_t, Int, 64>(val_ptr, lexer_.GetNextToken());
  1833. } else if (id == "U8") {
  1834. return ParseScalar<UInt8Imm, uint8_t, UInt, 8>(val_ptr, lexer_.GetNextToken());
  1835. } else if (id == "U16") {
  1836. return ParseScalar<UInt16Imm, uint16_t, UInt, 16>(val_ptr, lexer_.GetNextToken());
  1837. } else if (id == "U32") {
  1838. return ParseScalar<UInt32Imm, uint32_t, UInt, 32>(val_ptr, lexer_.GetNextToken());
  1839. } else if (id == "U64") {
  1840. return ParseScalar<UInt64Imm, uint64_t, UInt, 64>(val_ptr, lexer_.GetNextToken());
  1841. } else if (id == "F16") {
  1842. // Notice: Since there is no basic data type for storing fp16, just use float instead
  1843. return ParseScalar<FP32Imm, float, Float, 16>(val_ptr, lexer_.GetNextToken());
  1844. } else if (id == "F32") {
  1845. return ParseScalar<FP32Imm, float, Float, 32>(val_ptr, lexer_.GetNextToken());
  1846. } else if (id == "F64") {
  1847. return ParseScalar<FP64Imm, double, Float, 64>(val_ptr, lexer_.GetNextToken());
  1848. } else if (id == "Tensor") {
  1849. return ParseTensor(val_ptr);
  1850. } else if (id == "SymInst") {
  1851. return ParseSymbolicKeyInstance(func_graph, node_ptr);
  1852. } else if (id == "Array") {
  1853. TypePtr type = nullptr;
  1854. Token ret = ParseTypeArray(func_graph, lexer_.GetNextToken(), &type);
  1855. *val_ptr = type;
  1856. return ret;
  1857. } else if (Match(id, "PrimitivePy::")) {
  1858. return ParsePrimitivePy(func_graph, id, val_ptr);
  1859. } else if (Match(id, "Primitive::")) {
  1860. *val_ptr = std::make_shared<Primitive>(id.substr(strlen("Primitive::")));
  1861. return lexer_.GetNextToken();
  1862. } else if (Match(id, "GradOperation::")) {
  1863. return ParseValueGradOperation(id.substr(strlen("GradOperation::")), val_ptr);
  1864. } else {
  1865. return ParseValueGraphAndNamespace(id, val_ptr);
  1866. }
  1867. }
  1868. Token SetListOrTupleValue(const FuncGraphPtr &func_graph, Token left_tok, Token next, bool node_is_valid,
  1869. const std::vector<ValuePtr> &elems, const std::vector<AnfNodePtr> &nodes,
  1870. ValuePtr *const val_ptr, AnfNodePtr *node_ptr) {
  1871. if (left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) {
  1872. if (node_is_valid && node_ptr != nullptr) {
  1873. MS_EXCEPTION_IF_NULL(func_graph);
  1874. *node_ptr = func_graph->NewCNode(nodes);
  1875. } else {
  1876. *val_ptr = std::make_shared<ValueTuple>(elems);
  1877. }
  1878. return lexer_.GetNextToken();
  1879. } else if (left_tok == TOK_LBRACKET && next == TOK_RBRACKET) {
  1880. if (node_is_valid && node_ptr != nullptr) {
  1881. MS_LOG(EXCEPTION) << "Encounter valid node in value list";
  1882. }
  1883. *val_ptr = std::make_shared<ValueList>(elems);
  1884. return lexer_.GetNextToken();
  1885. } else {
  1886. return TOK_ERROR;
  1887. }
  1888. }
  1889. Token ParseListOrTupleValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr,
  1890. AnfNodePtr *node_ptr = nullptr) {
  1891. Token left_tok = tok;
  1892. std::vector<ValuePtr> elems;
  1893. std::vector<AnfNodePtr> nodes;
  1894. nodes.push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>("make_tuple")));
  1895. ValuePtr elem = nullptr;
  1896. AnfNodePtr node = nullptr;
  1897. bool node_is_valid = false;
  1898. bool first_flag = true;
  1899. Token next = TOK_ERROR;
  1900. do {
  1901. next = lexer_.GetNextToken();
  1902. if (first_flag) {
  1903. first_flag = false;
  1904. // case (), zero elements
  1905. if ((left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) ||
  1906. (left_tok == TOK_LBRACKET && next == TOK_RBRACKET)) {
  1907. if (left_tok == TOK_LPARENTHESIS) {
  1908. *val_ptr = std::make_shared<ValueTuple>(elems);
  1909. } else {
  1910. *val_ptr = std::make_shared<ValueList>(elems);
  1911. }
  1912. return lexer_.GetNextToken();
  1913. }
  1914. }
  1915. node = nullptr;
  1916. next = ParseValue(func_graph, next, &elem, &node);
  1917. elems.push_back(elem);
  1918. if (node != nullptr) {
  1919. nodes.push_back(node);
  1920. node_is_valid = true;
  1921. } else {
  1922. nodes.push_back(std::make_shared<ValueNode>(elem));
  1923. }
  1924. } while (next == TOK_COMMA);
  1925. return SetListOrTupleValue(func_graph, left_tok, next, node_is_valid, elems, nodes, val_ptr, node_ptr);
  1926. }
  1927. Token ParseValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, AnfNodePtr *node_ptr = nullptr) {
  1928. // tuple or list
  1929. if (tok == TOK_LPARENTHESIS || tok == TOK_LBRACKET) {
  1930. return ParseListOrTupleValue(func_graph, tok, val_ptr, node_ptr);
  1931. } else if (tok == TOK_IDENTIFIER) {
  1932. return ParseValueBasic(func_graph, lexer_.GetTokenText(), val_ptr, node_ptr);
  1933. } else if (tok == TOK_STRING) {
  1934. *val_ptr = std::make_shared<StringImm>(lexer_.GetTokenText());
  1935. return lexer_.GetNextToken();
  1936. }
  1937. MS_LOG(ERROR) << "Parse error!";
  1938. return TOK_ERROR;
  1939. }
  1940. Token ParseItem(const FuncGraphPtr &func_graph, AnfNodePtr *node_ptr, ValuePtr *const val_ptr,
  1941. Token tok = TOK_INVALID) {
  1942. if (tok == TOK_INVALID) {
  1943. tok = lexer_.GetNextToken();
  1944. }
  1945. if (tok == TOK_VARIABLE) {
  1946. auto iter = cnodes_.find(lexer_.GetTokenText());
  1947. if (iter == cnodes_.end()) {
  1948. MS_LOG(EXCEPTION) << "Can not find definition of '" << lexer_.GetTokenText() << "'";
  1949. }
  1950. *node_ptr = iter->second;
  1951. } else if (tok == TOK_PARAMETER) {
  1952. AnfNodePtr param = FindParameter(func_graph, lexer_.GetTokenText());
  1953. if (param == nullptr) {
  1954. MS_LOG(EXCEPTION) << "Can not find definition of '" << lexer_.GetTokenText() << "' at line "
  1955. << lexer_.GetLineNo();
  1956. }
  1957. *node_ptr = param;
  1958. } else if (tok == TOK_IDENTIFIER || tok == TOK_LPARENTHESIS || tok == TOK_STRING) {
  1959. ValuePtr value;
  1960. AnfNodePtr node;
  1961. tok = ParseValue(func_graph, tok, &value, &node);
  1962. if (tok == TOK_ERROR) {
  1963. MS_LOG(ERROR) << "Parse value error!";
  1964. return tok;
  1965. }
  1966. if (node == nullptr) {
  1967. *val_ptr = value;
  1968. *node_ptr = std::make_shared<ValueNode>(value);
  1969. } else {
  1970. *node_ptr = node;
  1971. }
  1972. return tok;
  1973. } else {
  1974. MS_LOG(EXCEPTION) << "tok_type = " << tok;
  1975. }
  1976. return lexer_.GetNextToken();
  1977. }
  1978. Token ParseArgument(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *const inputs_ptr) {
  1979. Token tok = lexer_.GetNextToken();
  1980. if (tok == TOK_RPARENTHESIS) {
  1981. return tok;
  1982. }
  1983. AnfNodePtr node = nullptr;
  1984. ValuePtr value = nullptr;
  1985. tok = ParseItem(func_graph, &node, &value, tok);
  1986. if (tok != TOK_ERROR) {
  1987. MS_EXCEPTION_IF_NULL(inputs_ptr);
  1988. inputs_ptr->push_back(node);
  1989. }
  1990. return tok;
  1991. }
  1992. const std::vector<FuncGraphPtr> &GetFuncGraphs() const { return func_graphs_; }
  1993. private:
  1994. Lexer lexer_;
  1995. std::vector<FuncGraphPtr> func_graphs_;
  1996. bool error_flag_ = false;
  1997. // store all parsed graphs
  1998. std::map<std::string, FuncGraphPtr> func_graphs_map_;
  1999. // map from child to parent, consider adding a 'parent' field in class Graph
  2000. std::map<FuncGraphPtr, FuncGraphPtr> parents_map_;
  2001. // map for buffering cnodes when parsing a graph
  2002. std::map<std::string, CNodePtr> cnodes_;
  2003. std::map<std::string, ParameterPtr> param_nodes_; // map parameter name to parameter
  2004. };
  2005. std::vector<FuncGraphPtr> ImportIR(const std::string &filename) {
  2006. IrParser parser(filename.c_str());
  2007. parser.ParseFile();
  2008. return parser.GetFuncGraphs();
  2009. }
  2010. #ifdef ENABLE_DUMP_IR
  2011. void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) {
  2012. if (func_graph == nullptr) {
  2013. MS_LOG(ERROR) << "Func graph is nullptr";
  2014. return;
  2015. }
  2016. auto ms_context = MsContext::GetInstance();
  2017. if (ms_context == nullptr) {
  2018. MS_LOG(ERROR) << "ms_context is nullptr";
  2019. return;
  2020. }
  2021. auto save_graphs_path = ms_context->save_graphs_path();
  2022. if (save_graphs_path.empty()) {
  2023. save_graphs_path = ".";
  2024. }
  2025. std::string file_path = save_graphs_path + "/" + "ms_output_" + suffix + ".pb";
  2026. if (file_path.size() > PATH_MAX) {
  2027. MS_LOG(ERROR) << "File path " << file_path << " is too long.";
  2028. return;
  2029. }
  2030. char real_path[PATH_MAX] = {0};
  2031. char *real_path_ret = nullptr;
  2032. #if defined(_WIN32) || defined(_WIN64)
  2033. real_path_ret = _fullpath(real_path, file_path.c_str(), PATH_MAX);
  2034. #else
  2035. real_path_ret = realpath(file_path.c_str(), real_path);
  2036. #endif
  2037. if (nullptr == real_path_ret) {
  2038. MS_LOG(DEBUG) << "dir " << file_path << " does not exit.";
  2039. } else {
  2040. std::string path_string = real_path;
  2041. if (chmod(common::SafeCStr(path_string), S_IRUSR | S_IWUSR) == -1) {
  2042. MS_LOG(ERROR) << "Modify file:" << real_path << " to rw fail.";
  2043. return;
  2044. }
  2045. }
  2046. // write to pb file
  2047. std::ofstream ofs(real_path);
  2048. if (!ofs.is_open()) {
  2049. MS_LOG(ERROR) << "Open file '" << real_path << "' failed!";
  2050. return;
  2051. }
  2052. ofs << GetFuncGraphProtoString(func_graph);
  2053. ofs.close();
  2054. // set file mode to read only by user
  2055. ChangeFileMode(file_path, S_IRUSR);
  2056. }
  2057. #else
  2058. void DumpIRProto(const FuncGraphPtr &, const std::string &) {
  2059. static bool already_printed = false;
  2060. if (already_printed) {
  2061. return;
  2062. }
  2063. already_printed = true;
  2064. MS_LOG(WARNING) << "The functionality of dumping function graph IR in protobuf format is disabled, "
  2065. << "please recompile source to enable it. See help of building script.";
  2066. }
  2067. #endif
  2068. } // namespace mindspore