You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pattern.cc 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/optimizer/pattern.h"
  17. #include "pybind_api/api_register.h"
  18. #include "pybind_api/export_flags.h"
  19. namespace mindspore {
  20. namespace opt {
  21. namespace python_pass {
  22. int Pattern::g_id_ = 0;
  23. MatchResultPtr IsPrimTypeOf::match(const AnfNodePtr &node) {
  24. if (!IsValueNode<Primitive>(node)) {
  25. return nullptr;
  26. }
  27. MatchResultPtr res = std::make_shared<MatchResult>();
  28. if (IsValueNode<Primitive>(node)) {
  29. // iterate over all primitives
  30. for (auto &iter : primitives_) {
  31. if (IsPrimitive(node, iter) || iter->name() == "*") {
  32. matched_prim_ = iter;
  33. res->add_entry(shared_from_base<IsPrimTypeOf>(), node);
  34. return res;
  35. }
  36. }
  37. }
  38. return nullptr;
  39. }
  40. MatchResultPtr CallWith::match(const AnfNodePtr &node) {
  41. if (!IsPrimitiveCNode(node)) {
  42. return nullptr;
  43. }
  44. MatchResultPtr res = std::make_shared<MatchResult>();
  45. // IsPrimitiveCNode
  46. auto cnode = node->cast<CNodePtr>();
  47. MS_EXCEPTION_IF_NULL(cnode);
  48. // Check Primitive ValueNode
  49. if (prim_pattern_ != nullptr) {
  50. // Passed in prim_pattern
  51. auto prim_value_res = prim_pattern_->match(cnode->input(0));
  52. if (prim_value_res == nullptr) {
  53. return nullptr;
  54. }
  55. res->merge(prim_value_res);
  56. } else if (prim_ != nullptr) {
  57. // Passed in primitive/primitive str
  58. if (!IsPrimitive(cnode->input(0), prim_)) {
  59. return nullptr;
  60. }
  61. } else {
  62. MS_LOG(EXCEPTION) << "Uninitialized CallWith pattern.";
  63. }
  64. // Check inputs
  65. auto p_inputs_size = inputs_.size();
  66. auto node_inputs_size = cnode->size() - 1;
  67. if (p_inputs_size != 0 && p_inputs_size != node_inputs_size) {
  68. return nullptr;
  69. }
  70. // If inputs is not specified, add node without looking into its inputs
  71. if (p_inputs_size == 0) {
  72. res->add_entry(shared_from_base<CallWith>(), cnode->input(0));
  73. return res;
  74. }
  75. bool failed = false;
  76. for (std::size_t i = 0; i < node_inputs_size; i++) {
  77. auto pattern = inputs_[i];
  78. auto input = cnode->input(i + 1);
  79. auto input_match_result = pattern->match(input);
  80. if (input_match_result == nullptr) {
  81. failed = true;
  82. break;
  83. }
  84. res->merge(input_match_result);
  85. }
  86. if (!failed) {
  87. res->add_entry(shared_from_base<CallWith>(), cnode->input(0));
  88. return res;
  89. }
  90. return nullptr;
  91. }
  92. MatchResultPtr IsIn::match(const AnfNodePtr &node) {
  93. for (auto &iter : patterns_) {
  94. auto res = iter->match(node);
  95. if (res != nullptr) {
  96. return res;
  97. }
  98. }
  99. return nullptr;
  100. }
  101. MatchResultPtr IsNot::match(const AnfNodePtr &node) {
  102. for (auto &iter : patterns_) {
  103. auto res = iter->match(node);
  104. if (res != nullptr) {
  105. return nullptr;
  106. }
  107. }
  108. auto res = std::make_shared<MatchResult>();
  109. res->add_entry(shared_from_base<IsNot>(), node);
  110. return res;
  111. }
  112. MatchResultPtr AnyPattern::match(const AnfNodePtr &node) {
  113. MatchResultPtr res = std::make_shared<MatchResult>();
  114. res->add_entry(shared_from_base<AnyPattern>(), node);
  115. return res;
  116. }
  117. AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) {
  118. auto entry = match_result_.find(pattern);
  119. if (entry == match_result_.end()) {
  120. return nullptr;
  121. }
  122. return entry->second;
  123. }
  124. void MatchResult::merge(const MatchResultPtr &other_result) {
  125. auto other_result_map = other_result->_result();
  126. // add/update entries in other_result
  127. for (auto &iter : other_result_map) {
  128. match_result_[iter.first] = iter.second;
  129. }
  130. }
  131. REGISTER_PYBIND_DEFINE(
  132. Pattern, ([](const py::module *m) {
  133. (void)py::class_<Pattern, std::shared_ptr<Pattern>>(*m, "Pattern").def(py::init<>());
  134. (void)py::class_<IsIn, std::shared_ptr<IsIn>, Pattern>(*m, "IsIn_").def(py::init<vector<PatternPtr>>());
  135. (void)py::class_<IsPrimTypeOf, std::shared_ptr<IsPrimTypeOf>, Pattern>(*m, "IsPrimTypeOf_", py::dynamic_attr())
  136. .def(py::init<vector<PrimitivePyPtr>, string, bool>())
  137. .def(py::init<vector<string>, string, bool>());
  138. (void)py::class_<CallWith, std::shared_ptr<CallWith>, Pattern>(*m, "CallWith_")
  139. .def(py::init<PatternPtr, vector<PatternPtr>, bool>())
  140. .def(py::init<PrimitivePyPtr, vector<PatternPtr>, bool>())
  141. .def(py::init<string, vector<PatternPtr>, bool>());
  142. (void)py::class_<IsNot, std::shared_ptr<IsNot>, Pattern>(*m, "IsNot_").def(py::init<vector<PatternPtr>>());
  143. (void)py::class_<AnyPattern, std::shared_ptr<AnyPattern>, Pattern>(*m, "AnyPattern").def(py::init<>());
  144. (void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_")
  145. .def(py::init<tensor::TensorPtr>());
  146. }));
  147. } // namespace python_pass
  148. } // namespace opt
  149. } // namespace mindspore