You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parse_dynamic.cc 11 kB

5 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2021 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  #include <algorithm>
  #include <memory>
  #include <set>
  #include <string>
  #include <vector>
  #include "utils/hash_set.h"
  #include "pipeline/jit/parse/parse_dynamic.h"
  #include "mindspore/core/ir/cell.h"
  25. namespace mindspore::parse {
  26. static mindspore::HashSet<std::string> cell_input_args_ = {};
  27. static const std::set<std::string> ignore_judge_dynamic_cell = {
  28. "Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal",
  29. "Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask", "Cell mindspore.nn.layer.math.MatMul"};
  30. static const std::set<std::string> unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE,
  31. parse::NAMED_PRIMITIVE_NAMECONSTANT,
  32. parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR};
  33. std::string DynamicParser::ParseNodeName(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node,
  34. parse::AstMainType type) {
  35. MS_EXCEPTION_IF_NULL(ast);
  36. if (py::isinstance<py::none>(node)) {
  37. MS_LOG(DEBUG) << "Get none type node!";
  38. return "";
  39. }
  40. auto node_type = ast->GetNodeType(node);
  41. MS_EXCEPTION_IF_NULL(node_type);
  42. // Check node type
  43. parse::AstMainType node_main_type = node_type->main_type();
  44. if (node_main_type != type) {
  45. MS_LOG(ERROR) << "Node type is wrong: " << node_main_type << ", it should be " << type;
  46. return "";
  47. }
  48. std::string node_name = node_type->node_name();
  49. MS_LOG(DEBUG) << "Ast node is " << node_name;
  50. return node_name;
  51. }
  52. void DynamicParser::ParseInputArgs(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &fn_node) {
  53. MS_EXCEPTION_IF_NULL(ast);
  54. py::list args = ast->GetArgs(fn_node);
  55. for (size_t i = 1; i < args.size(); i++) {
  56. std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
  57. MS_LOG(DEBUG) << "Input arg name: " << arg_name;
  58. (void)cell_input_args_.emplace(arg_name);
  59. }
  60. }
  61. bool DynamicParser::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node) {
  62. MS_LOG(DEBUG) << "Parse if/while expr";
  63. py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST);
  64. const auto &node_name = ParseNodeName(ast, test_node, parse::AST_MAIN_TYPE_EXPR);
  65. if (node_name == parse::NAMED_PRIMITIVE_COMPARE) {
  66. py::object left_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_LEFT);
  67. py::list comparators_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_COMPARATORS);
  68. if (comparators_node.empty()) {
  69. MS_LOG(DEBUG) << "Get comparators node failed!";
  70. return false;
  71. }
  72. auto left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR);
  73. auto right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR);
  74. // while self.a > self.b and changed self.a or self.b
  75. if (left == parse::NAMED_PRIMITIVE_ATTRIBUTE && right == parse::NAMED_PRIMITIVE_ATTRIBUTE) {
  76. auto left_value = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
  77. std::string left_variable;
  78. if (py::hasattr(left_node, "attr") && py::hasattr(left_value, "id")) {
  79. left_variable = py::cast<std::string>(left_value.attr("id")) + py::cast<std::string>(left_node.attr("attr"));
  80. }
  81. auto right_value = parse::python_adapter::GetPyObjAttr(comparators_node[0], parse::NAMED_PRIMITIVE_VALUE);
  82. std::string right_variable;
  83. if (py::hasattr(comparators_node[0], "attr") && py::hasattr(right_value, "id")) {
  84. right_variable =
  85. py::cast<std::string>(right_value.attr("id")) + py::cast<std::string>(comparators_node[0].attr("attr"));
  86. }
  87. return ParseBodyContext(ast, node, {left_variable, right_variable});
  88. }
  89. // if a[0]
  90. if (left == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
  91. py::object value_in_subscript = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
  92. left = ParseNodeName(ast, value_in_subscript, parse::AST_MAIN_TYPE_EXPR);
  93. }
  94. MS_LOG(DEBUG) << "Left is " << left << " Right is " << right;
  95. if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() ||
  96. unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) {
  97. return true;
  98. }
  99. }
  100. // if flag:
  101. if (node_name == parse::NAMED_PRIMITIVE_NAME) {
  102. std::string id = py::cast<std::string>(test_node.attr("id"));
  103. if (cell_input_args_.find(id) != cell_input_args_.end()) {
  104. return true;
  105. }
  106. }
  107. return false;
  108. }
  109. bool DynamicParser::ParseAssignExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node) {
  110. MS_LOG(DEBUG) << "Parse assign expr";
  111. py::object value_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_VALUE);
  112. const auto &node_name = ParseNodeName(ast, value_node, parse::AST_MAIN_TYPE_EXPR);
  113. if (node_name == parse::NAMED_PRIMITIVE_CALL) {
  114. py::object func_node = parse::python_adapter::GetPyObjAttr(value_node, parse::NAMED_PRIMITIVE_FUNC);
  115. const auto &func_name = ParseNodeName(ast, func_node, parse::AST_MAIN_TYPE_EXPR);
  116. if (func_name == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
  117. py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE);
  118. py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE);
  119. if (py::isinstance<py::none>(value_in_slice_node)) {
  120. MS_LOG(DEBUG) << "Parse value node is none!";
  121. return false;
  122. }
  123. const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR);
  124. std::string id;
  125. if (py::hasattr(value_in_slice_node, "id")) {
  126. id = py::cast<std::string>(value_in_slice_node.attr("id"));
  127. }
  128. if (cell_input_args_.find(node_name_in_slice_node) != cell_input_args_.end() ||
  129. (!id.empty() && cell_input_args_.find(id) != cell_input_args_.end())) {
  130. return true;
  131. }
  132. }
  133. }
  134. return false;
  135. }
  136. bool DynamicParser::ParseAugAssignExprNode(const std::shared_ptr<parse::ParseFunctionAst> &, const py::object &node,
  137. const std::vector<std::string> &compare_prim) {
  138. MS_LOG(DEBUG) << "Parse augassign expr";
  139. bool ret = false;
  140. if (compare_prim.empty()) {
  141. return ret;
  142. }
  143. py::object target_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TARGET);
  144. if (py::isinstance<py::none>(target_node)) {
  145. MS_LOG(DEBUG) << "Parse target node is none!";
  146. return ret;
  147. }
  148. py::object value_node = parse::python_adapter::GetPyObjAttr(target_node, parse::NAMED_PRIMITIVE_VALUE);
  149. if (py::isinstance<py::none>(value_node)) {
  150. MS_LOG(DEBUG) << "Parse value node is none!";
  151. return ret;
  152. }
  153. std::string assign_prim;
  154. if (py::hasattr(target_node, "attr") && py::hasattr(value_node, "id")) {
  155. assign_prim = py::cast<std::string>(value_node.attr("id")) + py::cast<std::string>(target_node.attr("attr"));
  156. }
  157. auto iter = std::find(compare_prim.begin(), compare_prim.end(), assign_prim);
  158. if (iter != compare_prim.end()) {
  159. ret = true;
  160. }
  161. return ret;
  162. }
  163. bool DynamicParser::ParseForExprNode(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &node) {
  164. MS_LOG(DEBUG) << "Parse for expr";
  165. py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY);
  166. if (py::isinstance<py::none>(body_node)) {
  167. MS_LOG(DEBUG) << "Parse body of for expression is none!";
  168. return false;
  169. }
  170. py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN);
  171. size_t count = LongToSize(pcount);
  172. MS_LOG(DEBUG) << "The for nodes count in body is " << count;
  173. for (size_t i = 0; i < count; ++i) {
  174. auto it = py::cast<py::list>(body_node)[i];
  175. const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT);
  176. if (node_name == parse::NAMED_PRIMITIVE_ASSIGN && ParseAssignExprNode(ast, it)) {
  177. return true;
  178. }
  179. }
  180. return false;
  181. }
  182. bool DynamicParser::ParseBodyContext(const std::shared_ptr<parse::ParseFunctionAst> &ast, const py::object &fn_node,
  183. const std::vector<std::string> &compare_prim) {
  184. MS_EXCEPTION_IF_NULL(ast);
  185. py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY);
  186. if (py::isinstance<py::none>(func_obj)) {
  187. MS_LOG(DEBUG) << "Parse body of cell is none!";
  188. return false;
  189. }
  190. py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN);
  191. size_t count = IntToSize(pcount);
  192. MS_LOG(DEBUG) << "The nodes count in body is " << count;
  193. bool ret = false;
  194. for (size_t i = 0; i < count; ++i) {
  195. auto node = py::cast<py::list>(func_obj)[i];
  196. const auto &node_name = ParseNodeName(ast, node, parse::AST_MAIN_TYPE_STMT);
  197. if (node_name == parse::NAMED_PRIMITIVE_ASSIGN) {
  198. ret = ParseAssignExprNode(ast, node);
  199. } else if (node_name == parse::NAMED_PRIMITIVE_AUGASSIGN) {
  200. ret = ParseAugAssignExprNode(ast, node, compare_prim);
  201. } else if (node_name == parse::NAMED_PRIMITIVE_FOR) {
  202. ret = ParseForExprNode(ast, node);
  203. } else if (node_name == parse::NAMED_PRIMITIVE_IF || node_name == parse::NAMED_PRIMITIVE_WHILE) {
  204. ret = ParseIfWhileExprNode(ast, node);
  205. }
  206. if (ret) {
  207. MS_LOG(INFO) << "Current cell is dynamic!";
  208. break;
  209. }
  210. }
  211. return ret;
  212. }
  213. std::string DynamicParser::GetCellInfo(const py::object &cell) {
  214. if (py::isinstance<Cell>(cell)) {
  215. auto c_cell = py::cast<CellPtr>(cell);
  216. MS_EXCEPTION_IF_NULL(c_cell);
  217. auto cell_info = c_cell->ToString();
  218. return cell_info;
  219. }
  220. return "";
  221. }
  222. bool DynamicParser::IsDynamicCell(const py::object &cell) {
  223. std::string cell_info = GetCellInfo(cell);
  224. if (ignore_judge_dynamic_cell.find(cell_info) != ignore_judge_dynamic_cell.end()) {
  225. return false;
  226. }
  227. // Using ast parse to check whether the construct of cell will be changed
  228. auto ast = std::make_shared<parse::ParseFunctionAst>(cell);
  229. bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD);
  230. if (!success) {
  231. MS_LOG(ERROR) << "Parse code to ast tree failed";
  232. return false;
  233. }
  234. py::object fn_node = ast->GetAstNode();
  235. // get the name of input args as the initialize of dynamic_variables
  236. ParseInputArgs(ast, fn_node);
  237. // parse body context
  238. bool ret = false;
  239. ret = ParseBodyContext(ast, fn_node);
  240. cell_input_args_.clear();
  241. return ret;
  242. }
  243. } // namespace mindspore::parse