You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

parse.h 17 kB

optimize the comment and log description 修改: ops/operations/_inner_ops.py 修改: ops/operations/_quant_ops.py 修改: ops/operations/array_ops.py 修改: ops/operations/comm_ops.py 修改: ops/operations/math_ops.py 修改: ops/operations/quantum_ops.py 修改: ops/operations/rl_ops.py 修改: ops/operations/sponge_ops.py 修改: ops/operations/sponge_update_ops.py 修改: train/__init__.py 修改: common/tensor.py 修改: train/serialization.py 修改: ccsrc/pipeline/jit/parse/parse.h 修改: explainer/benchmark/_attribution/metric.py 修改: ops/composite/multitype_ops/_constexpr_utils.py 修改: ops/operations/comm_ops.py 修改: RELEASE.md 修改: mindspore/_extends/parse/standard_method.py 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/concat_offset_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/dynamic_shape_cpu_kernel.cc 修改: mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc 修改: mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc 修改: mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc 修改: mindspore/ccsrc/frontend/parallel/strategy.h 修改: mindspore/common/tensor.py 修改: mindspore/core/abstract/prim_arrays.cc 修改: mindspore/core/abstract/prim_nn.cc 修改: mindspore/core/ops/conv2d.cc 修改: mindspore/core/ops/logical_and.h 修改: mindspore/core/ops/logical_not.h 修改: mindspore/core/ops/logical_or.h 修改: mindspore/core/ops/reduce_all.h 修改: mindspore/core/ops/reduce_any.h 修改: mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.cc 修改: mindspore/nn/layer/quant.py 修改: mindspore/nn/optim/sgd.py 修改: mindspore/nn/sparse/sparse.py 修改: mindspore/numpy/array_creations.py 修改: mindspore/numpy/array_ops.py 修改: mindspore/numpy/logic_ops.py 修改: mindspore/numpy/math_ops.py 修改: mindspore/ops/operations/_inner_ops.py 修改: mindspore/ops/operations/array_ops.py 修改: mindspore/ops/operations/rl_ops.py 修改: mindspore/train/_utils.py 修改: tests/ut/python/model/test_lenet_core_after_exception.py 修改: mindspore/_extends/parse/standard_method.py 修改: mindspore/ops/operations/rl_ops.py 修改: mindspore/core/abstract/prim_nn.cc 修改: 
mindspore/core/ops/conv2d.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/ctcloss_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_pull_weight_kernel.h 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_push_weight_kernel.h 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_arithmetic_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h 修改: mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h 修改: mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h 修改: mindspore/ccsrc/fl/server/server.cc 修改: mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc 修改: mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h 修改: mindspore/ccsrc/frontend/optimizer/irpass/inline.h 修改: mindspore/ccsrc/minddata/dataset/core/device_tensor.cc 修改: mindspore/ccsrc/minddata/dataset/core/tensor.cc 修改: mindspore/ccsrc/minddata/dataset/engine/datasetops/source/emnist_op.cc 修改: mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc 修改: mindspore/ccsrc/minddata/dataset/engine/datasetops/source/qmnist_op.cc 修改: mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc 修改: mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_ctrl_pass.cc 修改: mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc 修改: mindspore/ccsrc/pipeline/jit/action.cc 修改: mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc 
修改: mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_adapter.cc 修改: mindspore/compression/quant/quant_utils.py 修改: mindspore/core/abstract/prim_nn.cc 修改: mindspore/dataset/engine/validators.py 修改: mindspore/lite/micro/coder/opcoders/nnacl/fp32/affine_fp32_coder.cc 修改: mindspore/lite/micro/coder/opcoders/nnacl/int8/affine_int8_coder.cc 修改: mindspore/lite/src/runtime/kernel/ascend310/src/custom_kernel.cc 修改: mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc 修改: mindspore/lite/src/runtime/kernel/opencl/kernel/strassen.cc 修改: mindspore/lite/tools/common/graph_util.h 修改: mindspore/lite/tools/optimizer/fisson/fisson_util.cc 修改: mindspore/ops/composite/math_ops.py 修改: mindspore/ops/operations/_inner_ops.py 修改: mindspore/ops/operations/array_ops.py 修改: mindspore/ops/operations/math_ops.py 修改: mindspore/ops/operations/other_ops.py 修改: mindspore/boost/boost_cell_wrapper.py 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.cc 修改: mindspore/ccsrc/common/trans.cc 修改: mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc 修改: mindspore/ccsrc/frontend/parallel/ops_info/gather_info.cc 修改: mindspore/lite/src/common/log_util.h 修改: mindspore/nn/wrap/loss_scale.py 修改: mindspore/parallel/nn/moe.py 修改: tests/mindspore_test_framework/mindspore_test.py 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.cc 修改: mindspore/lite/tools/common/graph_util.h 修改: mindspore/ccsrc/frontend/parallel/ops_info/gather_info.cc 修改: mindspore/core/ops/conv2d.cc 修改: tests/ut/python/model/test_lenet_core_after_exception.py
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_
  19. #define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_
  20. #include <limits>
  21. #include <vector>
  22. #include <string>
  23. #include <map>
  24. #include <set>
  25. #include <stack>
  26. #include <memory>
  27. #include "utils/misc.h"
  28. #include "ir/anf.h"
  29. #include "pipeline/jit/parse/parse_base.h"
  30. #include "pipeline/jit/parse/python_adapter.h"
  31. #include "pipeline/jit/parse/function_block.h"
  32. namespace mindspore {
  33. namespace parse {
  34. // Parse status define
  35. enum ParseStatusCode : int64_t {
  36. PARSE_SUCCESS = 0,
  37. PARSE_FUNCTION_IS_NULL, // Python function is null
  38. PARSE_PARAMETER_INVALID, // Parameter is invalid
  39. PARSE_NO_RETURN, // Function no return node
  40. PARSE_NODE_TYPE_NO_MATCH, // Ast node type is error
  41. PARSE_NODE_TYPE_UNKNOWN, // Node type is unknown
  42. PARSE_NODE_METHOD_UNSUPPORTED, // No method to parse the node
  43. PARSE_DONT_RESOLVE_SYMBOL, // Can't resolve the string
  44. PARSE_NOT_SUPPORTED_COMPARE_EXPR, // The comparison is not supported
  45. PARSE_FAILURE = 0xFF
  46. };
  47. // Max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it
  48. // will be sunk(i.e. not unrolled)
  49. // NOTE: Since when the for loop was unrolled, it depends backend operators `tuple_getitem` and `scalar_add` which were
  50. // not implemented, so here set MAX_FOR_LOOP_COUNT to int64_t max limit to override default value `600`. This will make
  51. // the for loop will always be unrolled, but don't worry about the memory were exhausted, an exception will be raised
  52. // when function call depth exceeds the limit `context.get_context('max_call_depth')`.
  53. const int64_t MAX_FOR_LOOP_COUNT = std::numeric_limits<int64_t>::max();
  54. class AstNodeType;
  55. class ParseFunctionAst;
  56. // Save loop info for 'continue' and 'break' statements.
  57. struct Loop {
  58. // Loop header block.
  59. FunctionBlockPtr header;
  60. // Loop iterator node, used in 'for loop'.
  61. AnfNodePtr iterator;
  62. // Loop end block.
  63. FunctionBlockPtr end;
  64. Loop(const FunctionBlockPtr &header, const AnfNodePtr &iterator, const FunctionBlockPtr &end)
  65. : header(header), iterator(iterator), end(end) {}
  66. ~Loop() = default;
  67. };
  68. // Loop context for loop stack management.
  69. class LoopContext {
  70. public:
  71. LoopContext(std::stack<Loop> *loops, const FunctionBlockPtr &header, const AnfNodePtr &iterator) : loops_(loops) {
  72. loops_->emplace(header, iterator, nullptr);
  73. }
  74. ~LoopContext() { loops_->pop(); }
  75. const FunctionBlockPtr &EndBlock() const { return loops_->top().end; }
  76. private:
  77. std::stack<Loop> *loops_;
  78. };
  79. // Parser to parse python function
  80. class Parser {
  81. public:
  82. explicit Parser(const std::shared_ptr<ParseFunctionAst> &ast);
  83. ~Parser() {}
  84. FuncGraphPtr ParseFuncGraph();
  85. FuncGraphPtr func_graph() const { return func_graph_; }
  86. ParseStatusCode errcode() const { return errcode_; }
  87. std::shared_ptr<ParseFunctionAst> ast() const { return ast_; }
  88. const std::string &support_fallback() const { return support_fallback_; }
  89. // Get location info from the ast node
  90. LocationPtr GetLocation(const py::object &node) const;
  91. static void InitParserEnvironment(const py::object &obj);
  92. static void CleanParserResource();
  93. static FuncGraphPtr GetTopFuncGraph() { return top_func_graph_.lock(); }
  94. static void UpdateTopFuncGraph(const FuncGraphPtr &func_graph);
  95. private:
  96. // Process the stmt node method list
  97. FunctionBlockPtr ParseReturn(const FunctionBlockPtr &block, const py::object &node);
  98. // Parse expression
  99. FunctionBlockPtr ParseExpr(const FunctionBlockPtr &block, const py::object &node);
  100. // Process a if statement
  101. FunctionBlockPtr ParseIf(const FunctionBlockPtr &block, const py::object &node);
  102. // Process a while statement
  103. FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node);
  104. // Process a for statement
  105. FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node);
  106. FunctionBlockPtr ParseForIter(const FunctionBlockPtr &block, const py::object &node);
  107. FunctionBlockPtr ParseForLoop(const FunctionBlockPtr &block, const py::object &node);
  108. // Process a function def statement
  109. FunctionBlockPtr ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node);
  110. // Process a augment assign
  111. FunctionBlockPtr ParseAugAssign(const FunctionBlockPtr &block, const py::object &node);
  112. // Process a global declaration
  113. FunctionBlockPtr ParseGlobal(const FunctionBlockPtr &block, const py::object &node);
  114. // Process assign statement
  115. FunctionBlockPtr ParseAssign(const FunctionBlockPtr &block, const py::object &node);
  116. // Process break statement
  117. FunctionBlockPtr ParseBreak(const FunctionBlockPtr &block, const py::object &node);
  118. // Process continue statement
  119. FunctionBlockPtr ParseContinue(const FunctionBlockPtr &block, const py::object &node);
  120. // Process pass statement
  121. FunctionBlockPtr ParsePass(const FunctionBlockPtr &block, const py::object &node);
  122. // Process the expr and slice node method list
  123. AnfNodePtr ParseBinOp(const FunctionBlockPtr &block, const py::object &node);
  124. // Process a variable name
  125. AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node);
  126. // Process NoneType
  127. AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node);
  128. // Process Ellipsis
  129. AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node);
  130. // Process an integer or float number
  131. AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node);
  132. // Process a string variable
  133. AnfNodePtr ParseStr(const FunctionBlockPtr &block, const py::object &node);
  134. // Process a Constant
  135. AnfNodePtr ParseConstant(const FunctionBlockPtr &block, const py::object &node);
  136. // Process a name
  137. AnfNodePtr ParseNameConstant(const FunctionBlockPtr &block, const py::object &node);
  138. // Process a function call
  139. AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node);
  140. // Process function 'super'
  141. AnfNodePtr ParseSuper(const FunctionBlockPtr &block, const py::list &args);
  142. // Process the if expression
  143. AnfNodePtr ParseIfExp(const FunctionBlockPtr &block, const py::object &node);
  144. // Process class type define
  145. AnfNodePtr ParseAttribute(const FunctionBlockPtr &block, const py::object &node);
  146. // Process a compare expression
  147. AnfNodePtr ParseCompare(const FunctionBlockPtr &block, const py::object &node);
  148. // Process a bool operation
  149. AnfNodePtr ParseBoolOp(const FunctionBlockPtr &block, const py::object &node);
  150. // Process a lambda operation
  151. AnfNodePtr ParseLambda(const FunctionBlockPtr &block, const py::object &node);
  152. // Process a tuple
  153. AnfNodePtr ParseTuple(const FunctionBlockPtr &block, const py::object &node);
  154. // Process a tuple
  155. AnfNodePtr ParseList(const FunctionBlockPtr &block, const py::object &node);
  156. // Process a tuple
  157. AnfNodePtr ParseSubscript(const FunctionBlockPtr &block, const py::object &node);
  158. // Process a slice
  159. AnfNodePtr ParseSlice(const FunctionBlockPtr &block, const py::object &node);
  160. // Process a extslice
  161. AnfNodePtr ParseExtSlice(const FunctionBlockPtr &block, const py::object &node);
  162. // Process a tuple
  163. AnfNodePtr ParseIndex(const FunctionBlockPtr &block, const py::object &node);
  164. // Process a unaryop
  165. AnfNodePtr ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node);
  166. // Process a dict ast node expression
  167. AnfNodePtr ParseDictByKeysAndValues(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &key_nodes,
  168. const std::vector<AnfNodePtr> &value_nodes);
  169. AnfNodePtr ParseDict(const FunctionBlockPtr &block, const py::object &node);
  170. // Process ListComp expression
  171. AnfNodePtr ParseListComp(const FunctionBlockPtr &block, const py::object &node);
  172. FunctionBlockPtr ParseListCompIter(const FunctionBlockPtr &block, const py::object &node,
  173. const py::object &generator_node);
  174. AnfNodePtr ParseListCompIfs(const FunctionBlockPtr &list_body_block, const ParameterPtr &list_param,
  175. const py::object &node, const py::object &generator_node);
  176. // Check if the node need interpreting.
  177. AnfNodePtr HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
  178. const py::object &value_object);
  179. // Generate argument nodes for ast function node
  180. void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &function_node);
  181. // Generate argument default value for ast function node
  182. void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node);
  183. // Parse ast function node
  184. FunctionBlockPtr ParseDefFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
  185. // Parse lambda function node
  186. FunctionBlockPtr ParseLambdaFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
  187. // Parse ast statements
  188. FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node);
  189. // Parse one ast statement node
  190. FunctionBlockPtr ParseStatement(const FunctionBlockPtr &block, const py::object &node);
  191. // Parse an ast expression node
  192. AnfNodePtr ParseExprNode(const FunctionBlockPtr &block, const py::object &node);
  193. void MakeConditionBlocks(const FunctionBlockPtr &block, const FunctionBlockPtr &trueBlock,
  194. const FunctionBlockPtr &falseBlock);
  195. void RemoveUnnecessaryPhis();
  196. // Write a new var
  197. void WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node);
  198. // Assign value to single variable name
  199. void HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
  200. // Assign value to tuple
  201. void HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
  202. // Assign value to class member
  203. void HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
  204. // Assign value to subscript
  205. void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
  206. // Process a bool operation value list
  207. AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode);
  208. CNodePtr GenerateIteratorInFor(const FunctionBlockPtr &block, const pybind11::object &node,
  209. const AnfNodePtr &op_iter);
  210. CNodePtr GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block,
  211. const AnfNodePtr &op_hasnext);
  212. FunctionBlockPtr GenerateBlock(const TraceInfoPtr &trace_info);
  213. bool ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node,
  214. std::vector<AnfNodePtr> *packed_arguments);
  215. bool ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, std::vector<AnfNodePtr> *packed_arguments,
  216. std::vector<AnfNodePtr> *group_arguments);
  217. AnfNodePtr GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node,
  218. const std::vector<AnfNodePtr> &packed_arguments,
  219. const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const;
  220. ScopePtr GetScopeForParseFunction();
  221. void BuildMethodMap();
  222. FunctionBlockPtr MakeFunctionBlock(const Parser &parse) {
  223. FunctionBlockPtr block = std::make_shared<FunctionBlock>(parse);
  224. // In order to keep effect order in the sub-graphs which generated by control flow.
  225. // We copy the flags from the top graph to the sub-graphs.
  226. if (func_graph_ && !func_graph_->attrs().empty()) {
  227. for (const auto &attr : func_graph_->attrs()) {
  228. // The flag FUNC_GRAPH_OUTPUT_NO_RECOMPUTE should be only set in the top graph.
  229. if (attr.first != FUNC_GRAPH_OUTPUT_NO_RECOMPUTE) {
  230. block->func_graph()->set_attr(attr.first, attr.second);
  231. }
  232. }
  233. }
  234. func_block_list_.push_back(block);
  235. return block;
  236. }
  237. // Return a make tuple for input elements list
  238. AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes);
  239. int64_t GetForTransToWhileLoop();
  240. // The shared_ptr will be hold by GraphManager, so just hold a weak ref here.
  241. static FuncGraphWeakPtr top_func_graph_;
  242. // Python function id, used to indicate whether two CNodes come from the same Python function
  243. const std::shared_ptr<ParseFunctionAst> &ast_;
  244. FuncGraphPtr func_graph_;
  245. // Error code setwhen parsing ast tree
  246. ParseStatusCode errcode_;
  247. // Hold all reference for FunctionBlock in this round of parsing,
  248. // so in FunctionBlock class we can use FunctionBlock* in member
  249. // pre_blocks_ and jumps_ to break reference cycle.
  250. std::vector<FunctionBlockPtr> func_block_list_;
  251. using StmtFunc = FunctionBlockPtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
  252. using ExprFunc = AnfNodePtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
  253. // Define the function map to parse ast Statement
  254. std::map<std::string, StmtFunc> stmt_method_map_;
  255. // Define the function map to parse ast expression
  256. std::map<std::string, ExprFunc> expr_method_map_;
  257. // Save current loops to support 'continue', 'break' statement.
  258. std::stack<Loop> loops_;
  259. string max_for_loop_count_str_;
  260. string support_fallback_;
  261. };
  262. // AST node type define code to ast
  263. class AstNodeType {
  264. public:
  265. AstNodeType(const py::object &node, const std::string &name, AstMainType type)
  266. : node_(node), node_name_(name), main_type_(type) {}
  267. ~AstNodeType() {}
  268. std::string node_name() const { return node_name_; }
  269. py::object node() const { return node_; }
  270. AstMainType main_type() const { return main_type_; }
  271. private:
  272. const py::object &node_;
  273. const std::string node_name_;
  274. AstMainType main_type_;
  275. };
  276. using AstNodeTypePtr = std::shared_ptr<AstNodeType>;
  277. // A helper class to parse python function
  278. class ParseFunctionAst {
  279. public:
  280. explicit ParseFunctionAst(const py::object &obj)
  281. : obj_(obj), target_type_(PARSE_TARGET_UNKNOW), function_line_offset_(-1) {}
  282. ~ParseFunctionAst() = default;
  283. bool InitParseAstInfo(const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD);
  284. py::object GetAstNode();
  285. py::str GetAstNodeText(const py::object &node);
  286. py::list GetArgs(const py::object &func_node);
  287. py::list GetArgsDefaultValues(const py::object &func_node);
  288. AstNodeTypePtr GetNodeType(const py::object &node);
  289. AstSubType GetOpType(const py::object &node);
  290. template <class... T>
  291. py::object CallParserObjMethod(const std::string &method, const T &... args) {
  292. return python_adapter::CallPyObjMethod(parser_, method, args...);
  293. }
  294. template <class... T>
  295. py::object CallParseModFunction(const std::string &function, const T &... args) {
  296. return python_adapter::CallPyModFn(module_, function, args...);
  297. }
  298. const std::string &function_name() const { return function_name_; }
  299. const std::string &function_module() const { return function_module_; }
  300. const std::string &function_filename() const { return function_filename_; }
  301. int64_t function_line_offset() const { return function_line_offset_; }
  302. py::function function() { return function_; }
  303. ParseTargetTypeDef target_type() const { return target_type_; }
  304. py::object obj() { return obj_; }
  305. py::object parser() { return parser_; }
  306. py::object module() { return module_; }
  307. py::object ast_tree() { return ast_tree_; }
  308. bool IsClassMember(const py::object &node);
  309. private:
  310. // Save obj,eg: class instance or function
  311. py::object obj_;
  312. // Function or class method.
  313. py::function function_;
  314. py::object ast_tokens_;
  315. py::object ast_tree_;
  316. py::object parser_;
  317. py::module module_;
  318. // Is function or method
  319. ParseTargetTypeDef target_type_;
  320. std::string function_name_;
  321. std::string function_module_;
  322. std::string function_filename_;
  323. int64_t function_line_offset_;
  324. };
  325. // Update the graph flags
  326. bool UpdateFuncGraphFlags(const py::object &obj, const FuncGraphPtr &func_graph);
  327. AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
  328. TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph);
  329. } // namespace parse
  330. } // namespace mindspore
  331. #endif // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_