You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

branch_culling.h 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_
  17. #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_
  18. #include <vector>
  19. #include <algorithm>
  20. #include "optimizer/optimizer.h"
  21. #include "optimizer/irpass.h"
  22. #include "ir/visitor.h"
  23. #include "ir/func_graph.h"
  24. #include "ir/func_graph_cloner.h"
  25. #include "operator/ops.h"
  26. namespace mindspore {
  27. namespace opt {
  28. namespace irpass {
  29. // {prim::kPrimSwitch, true, X, Y}
  30. // {prim::kPrimSwitch, false, X, Y}
  31. class SwitchSimplify : public AnfVisitor {
  32. public:
  33. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  34. Reset();
  35. auto getx = [this](const AnfNodePtr &node) -> bool {
  36. this->x_ = node;
  37. return true;
  38. };
  39. auto gety = [this](const AnfNodePtr &node) -> bool {
  40. this->y_ = node;
  41. return true;
  42. };
  43. AnfVisitor::Match(prim::kPrimSwitch, {IsValueNode<BoolImm>, getx, gety})(node);
  44. // simplify the switch
  45. if (is_match_) {
  46. if (cond_) {
  47. return x_;
  48. }
  49. return y_;
  50. }
  51. return nullptr;
  52. }
  53. void Visit(const AnfNodePtr &node) override {
  54. if (!is_match_ && IsValueNode<BoolImm>(node)) {
  55. cond_ = GetValue<bool>(GetValueNode(node));
  56. is_match_ = true;
  57. }
  58. }
  59. void Reset() {
  60. x_ = nullptr;
  61. y_ = nullptr;
  62. cond_ = false;
  63. is_match_ = false;
  64. }
  65. private:
  66. bool is_match_{false}, cond_{false};
  67. AnfNodePtr x_{nullptr}, y_{nullptr};
  68. };
  69. // {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} =>
  70. // {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}}
  71. class FloatTupleGetItemSwitch : public AnfVisitor {
  72. public:
  73. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  74. Reset();
  75. AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node);
  76. auto fg = node->func_graph();
  77. if (Xs_.empty() || c_ == nullptr || fg == nullptr) {
  78. return nullptr;
  79. }
  80. auto true_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), Xs_[1], c_});
  81. auto false_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), Xs_[2], c_});
  82. return fg->NewCNode({NewValueNode(prim::kPrimSwitch), Xs_[0], true_node, false_node});
  83. }
  84. void Visit(const CNodePtr &cnode) override {
  85. // {prim::kPrimSwith, X1, X2, X3}
  86. if (!IsPrimitiveCNode(cnode, prim::kPrimSwitch) || cnode->size() != 4) {
  87. return;
  88. }
  89. // copy X1, X2, X3
  90. auto &inputs = cnode->inputs();
  91. (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));
  92. }
  93. void Visit(const ValueNodePtr &vnode) override { c_ = vnode; }
  94. void Reset() {
  95. Xs_.clear();
  96. c_ = nullptr;
  97. }
  98. private:
  99. AnfNodePtr c_{nullptr};
  100. std::vector<AnfNodePtr> Xs_{};
  101. };
  102. // {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} =>
  103. // {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}}
  104. class FloatEnvGetItemSwitch : public AnfVisitor {
  105. public:
  106. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  107. is_match_ = false;
  108. AnfVisitor::Match(prim::kPrimEnvGetItem, {IsCNode, IsNode, IsNode})(node);
  109. if (!is_match_) {
  110. return nullptr;
  111. }
  112. // {prim::kPrimEnvGetItem, {...}, X4, X5}
  113. auto cnode = node->cast<CNodePtr>();
  114. auto sw_node = cnode->input(1)->cast<CNodePtr>();
  115. auto x4 = cnode->input(2);
  116. auto x5 = cnode->input(3);
  117. is_match_ = false;
  118. AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsNode, IsNode})(sw_node);
  119. if (!is_match_) {
  120. return nullptr;
  121. }
  122. // {prim::kPrimSwitch, X1, X2, X3}
  123. auto x1 = sw_node->input(1);
  124. auto x2 = sw_node->input(2);
  125. auto x3 = sw_node->input(3);
  126. auto fg = node->func_graph();
  127. if (fg == nullptr) {
  128. return nullptr;
  129. }
  130. auto true_node = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x2, x4, x5});
  131. auto false_node = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x3, x4, x5});
  132. return fg->NewCNode({NewValueNode(prim::kPrimSwitch), x1, true_node, false_node});
  133. }
  134. void Visit(const AnfNodePtr &) override { is_match_ = true; }
  135. private:
  136. bool is_match_{false};
  137. };
  138. namespace internal {
  139. FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond);
  140. FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond);
  141. AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node,
  142. const AbstractBasePtr &true_graph_output_abs,
  143. const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond);
  144. } // namespace internal
  145. // {{prim::kPrimSwitch, X, G1, G2}, Xs}
  146. class ConvertSwitchReplacement : public AnfVisitor {
  147. public:
  148. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  149. if (!node->isa<CNode>() || node->func_graph() == nullptr) {
  150. return nullptr;
  151. }
  152. Reset();
  153. auto cnode = node->cast<CNodePtr>();
  154. if (cnode->size() < 1) {
  155. return nullptr;
  156. }
  157. // {prim::kPrimSwitch, X, G1, G2}
  158. AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode<FuncGraph>, IsValueNode<FuncGraph>})(cnode->input(0));
  159. if (g2_ == nullptr || g1_->output() == nullptr || g2_->output() == nullptr) {
  160. return nullptr;
  161. }
  162. auto true_output = g1_->output()->abstract();
  163. auto false_output = g2_->output()->abstract();
  164. auto trans_g1 = internal::TransformGraphCondTrueBranchNodes(g1_, x_);
  165. auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_);
  166. std::vector<AnfNodePtr> params;
  167. auto fg = node->func_graph();
  168. auto cloned_g1 = InlineClone(trans_g1, fg, params);
  169. auto cloned_g2 = InlineClone(trans_g2, fg, params);
  170. return internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_);
  171. }
  172. void Visit(const AnfNodePtr &node) override {
  173. if (x_ == nullptr) {
  174. x_ = node;
  175. return;
  176. }
  177. AnfVisitor::Visit(node);
  178. }
  179. void Visit(const ValueNodePtr &vnode) override {
  180. auto g = GetValueNode<FuncGraphPtr>(vnode);
  181. if (g1_ == nullptr) {
  182. g1_ = g;
  183. } else {
  184. g2_ = g;
  185. }
  186. }
  187. void Reset() {
  188. x_ = nullptr;
  189. g1_ = nullptr;
  190. g2_ = nullptr;
  191. }
  192. private:
  193. AnfNodePtr x_{nullptr};
  194. FuncGraphPtr g1_{nullptr}, g2_{nullptr};
  195. };
  196. } // namespace irpass
  197. } // namespace opt
  198. } // namespace mindspore
  199. #endif // #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_