You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

arithmetic_simplify.h 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
  17. #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
  18. #include <vector>
  19. #include <memory>
  20. #include <algorithm>
  21. #include "optimizer/optimizer.h"
  22. #include "optimizer/irpass.h"
  23. #include "optimizer/irpass/prim_eliminate.h"
  24. #include "ir/visitor.h"
  25. #include "operator/ops.h"
  26. namespace mindspore {
  27. namespace opt {
  28. namespace irpass {
  29. // {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0}
  30. // {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1}
  31. class MultiplyByZeroOrOne : public AnfVisitor {
  32. public:
  33. MultiplyByZeroOrOne() : zero_(MakeValue(0)), one_(MakeValue(1)) {}
  34. ~MultiplyByZeroOrOne() override = default;
  35. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  36. Reset();
  37. AnfVisitor::Match(prim::kPrimScalarMul)(node);
  38. if (is_zero_) {
  39. return NewValueNode(zero_);
  40. }
  41. if (is_one_) {
  42. return x_;
  43. }
  44. return nullptr;
  45. }
  46. void Visit(const AnfNodePtr &node) override {
  47. if (is_one_ || node->isa<CNode>()) {
  48. x_ = node;
  49. return;
  50. }
  51. AnfVisitor::Visit(node);
  52. if (!is_one_) {
  53. x_ = node;
  54. }
  55. }
  56. void Visit(const ValueNodePtr &vnode) override {
  57. auto value = vnode->value();
  58. if (*value == *zero_) {
  59. is_zero_ = true;
  60. } else if (*value == *one_) {
  61. is_one_ = true;
  62. }
  63. }
  64. void Reset() {
  65. x_ = nullptr;
  66. is_one_ = false;
  67. is_zero_ = false;
  68. }
  69. private:
  70. bool is_zero_{false}, is_one_{false};
  71. ValuePtr zero_, one_;
  72. AnfNodePtr x_{nullptr};
  73. };
  74. // {prim::kPrimScalarAdd, X, 0}
  75. // {prim::kPrimScalarAdd, 0, X}
  76. class AddByZero : public AnfVisitor {
  77. public:
  78. AddByZero() : zero_(MakeValue(0)) {}
  79. ~AddByZero() override = default;
  80. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  81. Reset();
  82. AnfVisitor::Match(prim::kPrimScalarAdd)(node);
  83. if (is_zero_) {
  84. return x_;
  85. }
  86. return nullptr;
  87. }
  88. void Visit(const AnfNodePtr &node) override {
  89. if (node->isa<ValueNode>() && *GetValueNode(node) == *zero_) {
  90. is_zero_ = true;
  91. return;
  92. }
  93. x_ = node;
  94. }
  95. void Reset() {
  96. x_ = nullptr;
  97. is_zero_ = false;
  98. }
  99. private:
  100. bool is_zero_{false};
  101. ValuePtr zero_;
  102. AnfNodePtr x_{nullptr};
  103. };
  104. // {prim::kPrimTensorAdd, {PrimZerosLikeTensor, Y}, X},
  105. // {prim::kPrimTensorAdd, X, {PrimZerosLikeTensor, Y}}
  106. class TensorAddByZero : public AnfVisitor {
  107. public:
  108. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  109. Reset();
  110. AnfVisitor::Match(prim::kPrimTensorAdd)(node);
  111. if (is_zero_) {
  112. return x_;
  113. }
  114. return nullptr;
  115. }
  116. void Visit(const AnfNodePtr &node) override {
  117. if (IsPrimitive(node, prim::kPrimZerosLikeTensor)) {
  118. is_zero_ = true;
  119. return;
  120. }
  121. x_ = node;
  122. }
  123. void Reset() {
  124. x_ = nullptr;
  125. is_zero_ = false;
  126. }
  127. private:
  128. bool is_zero_{false};
  129. AnfNodePtr x_{nullptr};
  130. };
  131. // {PrimMomentum, {PrimZerosLikeTensor, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y}
  132. class OptUpdateZeroTensor : public AnfVisitor {
  133. public:
  134. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  135. if (!IsPrimitiveCNode(node, prim::kPrimMomentum) || node->func_graph() == nullptr) {
  136. return nullptr;
  137. }
  138. // {PrimMomentum, {...}, Y, Z, Xs}
  139. auto &inputs = node->cast<CNodePtr>()->inputs();
  140. if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLikeTensor)) {
  141. return nullptr;
  142. }
  143. auto y = inputs[2];
  144. auto z = inputs[3];
  145. // {PrimZerosLikeTensor, X}
  146. if (inputs[1]->cast<CNodePtr>()->size() != 2) {
  147. return nullptr;
  148. }
  149. // {prim::kPrimMakeTuple, Z, Y}
  150. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), z, y});
  151. }
  152. };
  153. // {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} ->
  154. // {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
  155. class ConstantDuplicateMul : public AnfVisitor {
  156. public:
  157. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  158. Reset();
  159. // {prim::kPrimMul, Tensor1, {...}}
  160. AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node);
  161. if (vnode_ == nullptr || cnode_ == nullptr) {
  162. return nullptr;
  163. }
  164. auto tensor1 = vnode_;
  165. auto mul = cnode_;
  166. Reset();
  167. // {prim::kPrimMul, Tensor2, {...}}
  168. AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul);
  169. if (vnode_ == nullptr || cnode_ == nullptr) {
  170. return nullptr;
  171. }
  172. auto tensor2 = vnode_;
  173. auto cnode = cnode_;
  174. auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0));
  175. auto fg = node->func_graph();
  176. auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg);
  177. return NewCNode({NewValueNode(PrimMul), cnode, ttmul}, fg);
  178. }
  179. void Visit(const AnfNodePtr &node) override {
  180. if (IsValueNode<tensor::Tensor>(node)) {
  181. vnode_ = node;
  182. }
  183. if (IsCNode(node)) {
  184. cnode_ = node->cast<CNodePtr>();
  185. }
  186. }
  187. void Reset() {
  188. vnode_ = nullptr;
  189. cnode_ = nullptr;
  190. }
  191. private:
  192. AnfNodePtr vnode_;
  193. CNodePtr cnode_;
  194. };
  195. class ArithmeticSimplify {
  196. public:
  197. ArithmeticSimplify()
  198. : multiply_by_zero_or_one_(),
  199. add_by_zero_(),
  200. tensor_add_by_zero_(),
  201. identity_(prim::kPrimIdentity),
  202. opt_update_zero_tensor_(),
  203. constant_duplicate_mul_() {
  204. eliminaters_.emplace_back(multiply_by_zero_or_one_);
  205. eliminaters_.emplace_back(add_by_zero_);
  206. eliminaters_.emplace_back(tensor_add_by_zero_);
  207. eliminaters_.emplace_back(identity_);
  208. eliminaters_.emplace_back(opt_update_zero_tensor_);
  209. eliminaters_.emplace_back(constant_duplicate_mul_);
  210. }
  211. ~ArithmeticSimplify() = default;
  212. AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
  213. AnfNodePtr new_node;
  214. for (auto &eliminater : eliminaters_) {
  215. new_node = eliminater(optimizer, node);
  216. if (new_node != nullptr) {
  217. return new_node;
  218. }
  219. }
  220. return nullptr;
  221. }
  222. private:
  223. MultiplyByZeroOrOne multiply_by_zero_or_one_;
  224. AddByZero add_by_zero_;
  225. TensorAddByZero tensor_add_by_zero_;
  226. PrimEliminater identity_;
  227. OptUpdateZeroTensor opt_update_zero_tensor_;
  228. ConstantDuplicateMul constant_duplicate_mul_;
  229. std::vector<TransformFuncType> eliminaters_{};
  230. };
  231. } // namespace irpass
  232. } // namespace opt
  233. } // namespace mindspore
  234. #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_