You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

merge_addn.h 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_
  17. #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_
  18. #include <vector>
  19. #include <algorithm>
  20. #include "optimizer/irpass.h"
  21. #include "optimizer/optimizer.h"
  22. #include "ir/visitor.h"
  23. #include "operator/ops.h"
  24. namespace mindspore {
  25. namespace opt {
  26. namespace irpass {
  27. // {PrimAddN, {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}} ->
  28. // {{PrimAddNClass}, {prim::kPrimMakeTuple, Xs, Ys}}
  29. // {PrimAddN, {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}} ->
  30. // {{PrimAddNClass}, {prim::kPrimMakeTuple, Ys, Xs}}
  31. class MergeAddN : public AnfVisitor {
  32. public:
  33. MergeAddN() : PrimAddN_(prim::GetPythonOps("AddN", "mindspore.ops.operations")) {}
  34. ~MergeAddN() override = default;
  35. AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
  36. Reset();
  37. optimizer_ = optimizer;
  38. is_outer_ = true;
  39. AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);
  40. if (!is_match_ || node->func_graph() == nullptr) {
  41. return nullptr;
  42. }
  43. auto fg = node->func_graph();
  44. // {PrimAddNClass}
  45. auto addn_node = fg->NewCNode({NewValueNode(PrimAddN_)});
  46. // {prim::kPrimMakeTuple, Xs, Ys}, {prim::kPrimMakeTuple, Ys, Xs}
  47. (void)args_.insert(args_.begin(), NewValueNode(prim::kPrimMakeTuple));
  48. auto make_node = fg->NewCNode(args_);
  49. return fg->NewCNode({addn_node, make_node});
  50. }
  51. void Visit(const CNodePtr &cnode) override {
  52. if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
  53. return;
  54. }
  55. auto &inputs = cnode->inputs();
  56. if (is_outer_) {
  57. (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Ys_));
  58. is_outer_ = false;
  59. is_inner_ = true;
  60. // {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}
  61. AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(inputs[1]);
  62. if (is_match_) {
  63. if (!is_unique(inputs[1])) {
  64. is_match_ = false;
  65. return;
  66. }
  67. (void)Ys_.erase(Ys_.begin());
  68. (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_));
  69. (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_));
  70. return;
  71. }
  72. // {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}
  73. AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(inputs.back());
  74. if (is_match_) {
  75. if (!is_unique(inputs.back())) {
  76. is_match_ = false;
  77. return;
  78. }
  79. Ys_.pop_back();
  80. (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_));
  81. (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_));
  82. return;
  83. }
  84. return;
  85. }
  86. if (is_inner_) {
  87. is_match_ = true;
  88. (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));
  89. }
  90. }
  91. bool is_unique(const AnfNodePtr &node) {
  92. auto mng = optimizer_->resource()->manager();
  93. auto &node_users = mng->node_users();
  94. if (node_users.find(node) == node_users.end()) {
  95. return false;
  96. }
  97. size_t n_use = node_users[node].size();
  98. return n_use == 1;
  99. }
  100. void Reset() {
  101. Xs_.clear();
  102. Ys_.clear();
  103. args_.clear();
  104. is_inner_ = false;
  105. is_outer_ = false;
  106. is_match_ = false;
  107. }
  108. private:
  109. ValuePtr PrimAddN_;
  110. OptimizerPtr optimizer_{nullptr};
  111. std::vector<AnfNodePtr> Xs_{}, Ys_{}, args_{};
  112. bool is_inner_{false}, is_outer_{false}, is_match_{false};
  113. };
  114. // {PrimAddN, {kPrimMakeTuple, Xs}}
  115. class AddNZeroFilter : public AnfVisitor {
  116. public:
  117. AddNZeroFilter() : PrimAddN_(prim::GetPythonOps("AddN", "mindspore.ops.operations")) {}
  118. ~AddNZeroFilter() override = default;
  119. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  120. Reset();
  121. AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);
  122. if (filtered_Xs_.empty() || node->func_graph() == nullptr) {
  123. return nullptr;
  124. }
  125. // if only two node in filtered_nodes, {make_tuple, x}. return x.
  126. if (filtered_Xs_.size() == 2) {
  127. return filtered_Xs_[1];
  128. }
  129. // if only one node in filtered_nodes, all node is zerolike, return one of the input.
  130. if (filtered_Xs_.size() == 1 && Xs_.size() > 0) {
  131. return Xs_[0];
  132. }
  133. if (!has_zero_like_) {
  134. return nullptr;
  135. }
  136. auto fg = node->func_graph();
  137. auto addn = fg->NewCNode({NewValueNode(PrimAddN_)});
  138. auto make_tuple = fg->NewCNode(filtered_Xs_);
  139. return fg->NewCNode({addn, make_tuple});
  140. }
  141. void Visit(const CNodePtr &cnode) override {
  142. if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
  143. return;
  144. }
  145. auto &inputs = cnode->inputs();
  146. (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));
  147. // {kPrimMakeTuple, X1, X2, ...}
  148. filtered_Xs_.push_back(NewValueNode(prim::kPrimMakeTuple));
  149. for (auto &x : Xs_) {
  150. if (!IsPrimitiveCNode(x, prim::kPrimZerosLikeTensor)) {
  151. filtered_Xs_.push_back(x);
  152. } else {
  153. has_zero_like_ = true;
  154. }
  155. }
  156. }
  157. void Reset() {
  158. Xs_.clear();
  159. filtered_Xs_.clear();
  160. has_zero_like_ = false;
  161. }
  162. private:
  163. ValuePtr PrimAddN_;
  164. std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
  165. bool has_zero_like_{false};
  166. };
  167. } // namespace irpass
  168. } // namespace opt
  169. } // namespace mindspore
  170. #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_