You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parse.h 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_
  19. #define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_
  20. #include <limits>
  21. #include <vector>
  22. #include <string>
  23. #include <map>
  24. #include <set>
  25. #include <stack>
  26. #include <memory>
  27. #include "utils/misc.h"
  28. #include "ir/anf.h"
  29. #include "pipeline/jit/parse/parse_base.h"
  30. #include "pipeline/jit/parse/python_adapter.h"
  31. #include "pipeline/jit/parse/function_block.h"
  32. namespace mindspore {
  33. namespace parse {
  // Status codes produced while parsing (exposed through Parser::errcode()).
  // Every value is spelled out explicitly so the numeric codes are obvious; 0 means success.
  enum ParseStatusCode : int {
    PARSE_SUCCESS = 0,
    PARSE_FUNCTION_IS_NULL = 1,            // the Python function is null
    PARSE_PARAMETER_INVALID = 2,           // a parameter is invalid
    PARSE_NO_RETURN = 3,                   // the function has no return node
    PARSE_NODE_TYPE_NO_MATCH = 4,          // the AST node type does not match what was expected
    PARSE_NODE_TYPE_UNKOWN = 5,            // the node type is unknown (identifier spelling kept as-is)
    PARSE_NODE_METHOD_UNSUPPORTED = 6,     // there is no method to parse the node
    PARSE_DONT_RESOLVE_SYMBOL = 7,         // the symbol string cannot be resolved
    PARSE_NOT_SUPPORTED_COMPARE_EXPR = 8,  // the comparison is not supported
    PARSE_FAILURE = 0xFF                   // generic failure
  };
  // Maximum iteration count of a `for` statement: loops whose count is below this value are
  // unrolled; otherwise they are sunk (i.e. not unrolled).
  // NOTE: the for-loop lowering depends on the backend operators `tuple_getitem` and `scalar_add`,
  // which were not implemented. This constant is therefore set to the int maximum, overriding the
  // former default of `600`, so for loops are always unrolled. Memory exhaustion is not a concern:
  // an exception is raised once the function call depth exceeds
  // `context.get_context('max_call_depth')`.
  constexpr int MAX_FOR_LOOP_COUNT = std::numeric_limits<int>::max();
  // Forward declarations: Parser (declared below) refers to these before their definitions.
  class ParseAst;
  class AstNodeType;
  56. // Save loop info for 'continue' and 'break' statements.
  57. struct Loop {
  58. // Loop header block.
  59. FunctionBlockPtr header;
  60. // Loop iterator node, used in 'for loop'.
  61. AnfNodePtr iterator;
  62. // Loop end block.
  63. FunctionBlockPtr end;
  64. Loop(const FunctionBlockPtr &header, const AnfNodePtr &iterator, const FunctionBlockPtr &end)
  65. : header(header), iterator(iterator), end(end) {}
  66. ~Loop() = default;
  67. };
  68. // Loop context for loop stack management.
  69. class LoopContext {
  70. public:
  71. LoopContext(std::stack<Loop> *loops, const FunctionBlockPtr &header, const AnfNodePtr &iterator) : loops_(loops) {
  72. loops_->emplace(header, iterator, nullptr);
  73. }
  74. ~LoopContext() { loops_->pop(); }
  75. const FunctionBlockPtr &EndBlock() const { return loops_->top().end; }
  76. private:
  77. std::stack<Loop> *loops_;
  78. };
  79. // Parser to parse python function
  80. class Parser {
  81. public:
  82. explicit Parser(const std::shared_ptr<ParseAst> &ast);
  83. ~Parser() {}
  84. FuncGraphPtr ParseFuncGraph();
  85. FuncGraphPtr func_graph() const { return func_graph_; }
  86. ParseStatusCode errcode() const { return errcode_; }
  87. std::shared_ptr<ParseAst> ast() const { return ast_; }
  88. // get location info from the ast node
  89. LocationPtr GetLocation(const py::object &node) const;
  90. static void InitParserEnvironment(const py::object &obj);
  91. static void CleanParserResource();
  92. static FuncGraphPtr GetTopFuncGraph() { return top_func_graph_.lock(); }
  93. static void UpdateTopFuncGraph(const FuncGraphPtr &func_graph);
  94. private:
  95. // process the stmt node method list
  96. FunctionBlockPtr ParseReturn(const FunctionBlockPtr &block, const py::object &node);
  97. // parse expression
  98. FunctionBlockPtr ParseExpr(const FunctionBlockPtr &block, const py::object &node);
  99. // process a if statement
  100. FunctionBlockPtr ParseIf(const FunctionBlockPtr &block, const py::object &node);
  101. // process a while statement
  102. FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node);
  103. // process a for statement
  104. FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node);
  105. FunctionBlockPtr ParseForIter(const FunctionBlockPtr &block, const py::object &node);
  106. FunctionBlockPtr ParseForLoop(const FunctionBlockPtr &block, const py::object &node);
  107. // process a function def statement
  108. FunctionBlockPtr ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node);
  109. // process a augment assign
  110. FunctionBlockPtr ParseAugAssign(const FunctionBlockPtr &block, const py::object &node);
  111. // process a global declaration
  112. FunctionBlockPtr ParseGlobal(const FunctionBlockPtr &block, const py::object &node);
  113. // process assign statement
  114. FunctionBlockPtr ParseAssign(const FunctionBlockPtr &block, const py::object &node);
  115. // process break statement
  116. FunctionBlockPtr ParseBreak(const FunctionBlockPtr &block, const py::object &node);
  117. // process continue statement
  118. FunctionBlockPtr ParseContinue(const FunctionBlockPtr &block, const py::object &node);
  119. // process pass statement
  120. FunctionBlockPtr ParsePass(const FunctionBlockPtr &block, const py::object &node);
  121. // process the expr and slice node method list
  122. AnfNodePtr ParseBinOp(const FunctionBlockPtr &block, const py::object &node);
  123. // process a variable name
  124. AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node);
  125. // process NoneType
  126. AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node);
  127. // process Ellipsis
  128. AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node);
  129. // process a integer or float number
  130. AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node);
  131. // process a string variable
  132. AnfNodePtr ParseStr(const FunctionBlockPtr &block, const py::object &node);
  133. // process a name
  134. AnfNodePtr ParseNameConstant(const FunctionBlockPtr &block, const py::object &node);
  135. // process a function call
  136. AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node);
  137. // process function 'super'
  138. AnfNodePtr ParseSuper(const FunctionBlockPtr &block, const py::list &args);
  139. // process the if expression
  140. AnfNodePtr ParseIfExp(const FunctionBlockPtr &block, const py::object &node);
  141. // process class type define
  142. AnfNodePtr ParseAttribute(const FunctionBlockPtr &block, const py::object &node);
  143. // process a compare expression
  144. AnfNodePtr ParseCompare(const FunctionBlockPtr &block, const py::object &node);
  145. // process a bool operation
  146. AnfNodePtr ParseBoolOp(const FunctionBlockPtr &block, const py::object &node);
  147. // process a lambda operation
  148. AnfNodePtr ParseLambda(const FunctionBlockPtr &block, const py::object &node);
  149. // process a tuple
  150. AnfNodePtr ParseTuple(const FunctionBlockPtr &block, const py::object &node);
  151. // process a tuple
  152. AnfNodePtr ParseList(const FunctionBlockPtr &block, const py::object &node);
  153. // process a tuple
  154. AnfNodePtr ParseSubscript(const FunctionBlockPtr &block, const py::object &node);
  155. // process a slice
  156. AnfNodePtr ParseSlice(const FunctionBlockPtr &block, const py::object &node);
  157. // process a extslice
  158. AnfNodePtr ParseExtSlice(const FunctionBlockPtr &block, const py::object &node);
  159. // process a tuple
  160. AnfNodePtr ParseIndex(const FunctionBlockPtr &block, const py::object &node);
  161. // process a unaryop
  162. AnfNodePtr ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node);
  163. // process a dict ast node expression
  164. AnfNodePtr ParseDict(const FunctionBlockPtr &block, const py::object &node);
  165. // generate argument nodes for ast function node
  166. void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &function_node);
  167. // generate argument default value for ast function node
  168. void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node);
  169. // parse ast function node
  170. FunctionBlockPtr ParseFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
  171. // parse ast statements
  172. FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node);
  173. // parse one ast statement node
  174. FunctionBlockPtr ParseStatement(const FunctionBlockPtr &block, const py::object &node);
  175. // parse an ast expresion node
  176. AnfNodePtr ParseExprNode(const FunctionBlockPtr &block, const py::object &node);
  177. void MakeConditionBlocks(const FunctionBlockPtr &block, const FunctionBlockPtr &trueBlock,
  178. const FunctionBlockPtr &falseBlock);
  179. void RemoveUnnecessaryPhis();
  180. // write a new var
  181. void WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node);
  182. // assign value to single variable name
  183. void HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
  184. // assign value to tuple
  185. void HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
  186. // assign value to class member
  187. void HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
  188. // assign value to subscript
  189. void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
  190. // process a bool operation value list
  191. AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode);
  192. CNodePtr GenerateIteratorInFor(const FunctionBlockPtr &block, const pybind11::object &node,
  193. const AnfNodePtr &op_iter);
  194. CNodePtr GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block,
  195. const AnfNodePtr &op_hasnext);
  196. FunctionBlockPtr GenerateBlockInFor(const TraceInfoPtr &trace_info);
  197. bool ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node,
  198. std::vector<AnfNodePtr> *packed_arguments);
  199. bool ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, std::vector<AnfNodePtr> *packed_arguments,
  200. std::vector<AnfNodePtr> *group_arguments);
  201. AnfNodePtr GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node,
  202. const std::vector<AnfNodePtr> &packed_arguments,
  203. const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const;
  204. ScopePtr GetScopeForParseFunction();
  205. void BuildMethodMap();
  206. FunctionBlockPtr MakeFunctionBlock(const Parser &parse) {
  207. FunctionBlockPtr block = std::make_shared<FunctionBlock>(parse);
  208. // In order to keep effect order in the sub-graphs which generated by control flow.
  209. // We copy the flags from the top graph to the sub-graphs.
  210. if (func_graph_ && !func_graph_->attrs().empty()) {
  211. block->func_graph()->set_attrs(func_graph_->attrs());
  212. }
  213. func_block_list_.push_back(block);
  214. return block;
  215. }
  216. // return a make tuple for input elements list
  217. AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes);
  218. // shared_ptr will be hold by GraphManager, so just hold a weak ref here.
  219. static FuncGraphWeakPtr top_func_graph_;
  220. // Python function id, used to indicate whether two CNodes come from the same Python function
  221. const std::shared_ptr<ParseAst> &ast_;
  222. FuncGraphPtr func_graph_;
  223. // error code setwhen parsing ast tree
  224. ParseStatusCode errcode_;
  225. // hold all reference for FunctionBlock in this round of parsing,
  226. // so in FunctionBlock class we can use FunctionBlock* in member
  227. // pre_blocks_ and jumps_ to break reference cycle.
  228. std::vector<FunctionBlockPtr> func_block_list_;
  229. using pStmtFunc = FunctionBlockPtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
  230. using pExprFunc = AnfNodePtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
  231. // define the function map to parse ast Statement
  232. std::map<std::string, pStmtFunc> stmt_method_map_;
  233. // define the function map to parse ast expression
  234. std::map<std::string, pExprFunc> expr_method_map_;
  235. // Save current loops to support 'continue', 'break' statement.
  236. std::stack<Loop> loops_;
  237. };
  238. // AST node type define code to ast
  239. class AstNodeType {
  240. public:
  241. AstNodeType(const py::object &node, const std::string &name, AstMainType type)
  242. : node_(node), node_name_(name), main_type_(type) {}
  243. ~AstNodeType() {}
  244. std::string node_name() const { return node_name_; }
  245. py::object node() const { return node_; }
  246. AstMainType main_type() const { return main_type_; }
  247. private:
  248. const py::object &node_;
  249. const std::string node_name_;
  250. AstMainType main_type_;
  251. };
  252. using AstNodeTypePtr = std::shared_ptr<AstNodeType>;
  253. // A helper class to parse python function
  254. class ParseAst {
  255. public:
  256. explicit ParseAst(const py::object &obj) : obj_(obj), target_type_(PARSE_TARGET_UNKNOW), function_line_offset_(-1) {}
  257. ~ParseAst() = default;
  258. bool InitParseAstInfo(const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD);
  259. py::object GetAstNode();
  260. py::list GetArgs(const py::object &func_node);
  261. py::list GetArgsDefaultValues(const py::object &func_node);
  262. AstNodeTypePtr GetNodeType(const py::object &node);
  263. AstSubType GetOpType(const py::object &node);
  264. template <class... T>
  265. py::object CallParserObjMethod(const std::string &method, const T &... args) {
  266. return python_adapter::CallPyObjMethod(parser_, method, args...);
  267. }
  268. template <class... T>
  269. py::object CallParseModFunction(const std::string &function, const T &... args) {
  270. return python_adapter::CallPyModFn(module_, function, args...);
  271. }
  272. const std::string &function_name() const { return function_name_; }
  273. const std::string &function_module() const { return function_module_; }
  274. const std::string &function_filename() const { return function_filename_; }
  275. int function_line_offset() const { return function_line_offset_; }
  276. py::function function() { return function_; }
  277. ParseTargetTypeDef target_type() const { return target_type_; }
  278. py::object obj() { return obj_; }
  279. py::object parser() { return parser_; }
  280. py::object module() { return module_; }
  281. py::object ast_tree() { return ast_tree_; }
  282. bool IsClassMember(const py::object &node);
  283. private:
  284. // save obj,eg: class instance or function
  285. py::object obj_;
  286. // function or class method.
  287. py::function function_;
  288. py::object ast_tree_;
  289. py::object parser_;
  290. py::module module_;
  291. // Is function or method
  292. ParseTargetTypeDef target_type_;
  293. std::string function_name_;
  294. std::string function_module_;
  295. std::string function_filename_;
  296. int function_line_offset_;
  297. };
  298. // update the graph flags
  299. bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph);
  300. AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
  301. TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph);
  302. } // namespace parse
  303. } // namespace mindspore
  304. #endif // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_