You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pattern.cc 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/optimizer/pattern.h"
  17. #include "pybind_api/api_register.h"
  18. namespace mindspore {
  19. namespace opt {
  20. namespace python_pass {
  21. int Pattern::g_id_ = 0;
  22. MatchResultPtr Prim::match(const AnfNodePtr &node) {
  23. if (!IsValueNode<Primitive>(node)) {
  24. return nullptr;
  25. }
  26. MatchResultPtr res = std::make_shared<MatchResult>();
  27. // iterate over all primitives
  28. for (auto &iter : primitives_) {
  29. if (IsPrimitive(node, iter) || iter->name() == "*") {
  30. matched_prim_ = iter;
  31. res->add_entry(shared_from_base<Prim>(), node);
  32. return res;
  33. }
  34. }
  35. return nullptr;
  36. }
  37. MatchResultPtr Call::match(const AnfNodePtr &node) {
  38. if (!IsPrimitiveCNode(node)) {
  39. return nullptr;
  40. }
  41. MatchResultPtr res = std::make_shared<MatchResult>();
  42. // IsPrimitiveCNode
  43. auto cnode = node->cast<CNodePtr>();
  44. MS_EXCEPTION_IF_NULL(cnode);
  45. // Check Primitive ValueNode
  46. if (prim_pattern_ != nullptr) {
  47. // Passed in prim_pattern
  48. auto prim_value_res = prim_pattern_->match(cnode->input(0));
  49. if (prim_value_res == nullptr) {
  50. return nullptr;
  51. }
  52. res->merge(prim_value_res);
  53. } else if (prim_ != nullptr) {
  54. // Passed in primitive/primitive str
  55. if (!IsPrimitive(cnode->input(0), prim_)) {
  56. return nullptr;
  57. }
  58. } else {
  59. MS_LOG(EXCEPTION) << "Uninitialized CallWith pattern.";
  60. }
  61. // Check inputs
  62. auto p_inputs_size = inputs_.size();
  63. auto node_inputs_size = cnode->size() - 1;
  64. if (p_inputs_size != 0 && p_inputs_size != node_inputs_size) {
  65. return nullptr;
  66. }
  67. // If inputs is not specified, add node without looking into its inputs
  68. if (p_inputs_size == 0) {
  69. res->add_entry(shared_from_base<Call>(), cnode->input(0));
  70. return res;
  71. }
  72. bool failed = false;
  73. for (std::size_t i = 0; i < node_inputs_size; i++) {
  74. auto pattern = inputs_[i];
  75. auto input = cnode->input(i + 1);
  76. auto input_match_result = pattern->match(input);
  77. if (input_match_result == nullptr) {
  78. failed = true;
  79. break;
  80. }
  81. res->merge(input_match_result);
  82. }
  83. if (!failed) {
  84. res->add_entry(shared_from_base<Call>(), cnode->input(0));
  85. return res;
  86. }
  87. return nullptr;
  88. }
  89. MatchResultPtr OneOf::match(const AnfNodePtr &node) {
  90. for (auto &iter : patterns_) {
  91. auto res = iter->match(node);
  92. if (res != nullptr) {
  93. res->add_entry(shared_from_base<OneOf>(), node);
  94. return res;
  95. }
  96. }
  97. return nullptr;
  98. }
  99. MatchResultPtr NoneOf::match(const AnfNodePtr &node) {
  100. for (auto &iter : patterns_) {
  101. auto res = iter->match(node);
  102. if (res != nullptr) {
  103. return nullptr;
  104. }
  105. }
  106. auto res = std::make_shared<MatchResult>();
  107. res->add_entry(shared_from_base<NoneOf>(), node);
  108. return res;
  109. }
  110. MatchResultPtr Any::match(const AnfNodePtr &node) {
  111. MatchResultPtr res = std::make_shared<MatchResult>();
  112. res->add_entry(shared_from_base<Any>(), node);
  113. return res;
  114. }
  115. MatchResultPtr Imm::match(const AnfNodePtr &node) {
  116. if (!IsValueNode<Int32Imm>(node)) {
  117. return nullptr;
  118. }
  119. // Check value
  120. auto value_node = node->cast<ValueNodePtr>();
  121. MS_EXCEPTION_IF_NULL(value_node);
  122. auto value_ptr = value_node->value()->cast<Int32ImmPtr>();
  123. MS_EXCEPTION_IF_NULL(value_ptr);
  124. if ((int32_t)value_ptr->value() == value_) {
  125. MatchResultPtr res = std::make_shared<MatchResult>();
  126. res->add_entry(shared_from_base<Imm>(), node);
  127. return res;
  128. }
  129. return nullptr;
  130. }
  131. AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) {
  132. auto entry = match_result_.find(pattern);
  133. if (entry == match_result_.end()) {
  134. return nullptr;
  135. }
  136. return entry->second;
  137. }
  138. void MatchResult::merge(const MatchResultPtr &other_result) {
  139. auto other_result_map = other_result->result();
  140. // add/update entries in other_result
  141. for (auto &iter : other_result_map) {
  142. match_result_[iter.first] = iter.second;
  143. }
  144. }
  145. REGISTER_PYBIND_DEFINE(
  146. Pattern, ([](const py::module *m) {
  147. (void)py::class_<Pattern, std::shared_ptr<Pattern>>(*m, "Pattern").def(py::init<>());
  148. (void)py::class_<OneOf, std::shared_ptr<OneOf>, Pattern>(*m, "OneOf_").def(py::init<vector<PatternPtr>>());
  149. (void)py::class_<Prim, std::shared_ptr<Prim>, Pattern>(*m, "Prim_", py::dynamic_attr())
  150. .def(py::init<vector<PrimitivePyPtr>, string>())
  151. .def(py::init<vector<string>, string>());
  152. (void)py::class_<Call, std::shared_ptr<Call>, Pattern>(*m, "Call_")
  153. .def(py::init<PatternPtr, vector<PatternPtr>>())
  154. .def(py::init<PrimitivePyPtr, vector<PatternPtr>>())
  155. .def(py::init<string, vector<PatternPtr>>());
  156. (void)py::class_<NoneOf, std::shared_ptr<NoneOf>, Pattern>(*m, "NoneOf_").def(py::init<vector<PatternPtr>>());
  157. (void)py::class_<Any, std::shared_ptr<Any>, Pattern>(*m, "Any").def(py::init<>());
  158. (void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_")
  159. .def(py::init<tensor::TensorPtr>());
  160. (void)py::class_<NewParameter, std::shared_ptr<NewParameter>, Pattern>(*m, "NewParameter_")
  161. .def(py::init<string, tensor::TensorPtr, bool, bool>());
  162. (void)py::class_<Imm, std::shared_ptr<Imm>, Pattern>(*m, "Imm").def(py::init<int>());
  163. }));
  164. } // namespace python_pass
  165. } // namespace opt
  166. } // namespace mindspore