You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

arithmetic_simplify.h 8.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
  17. #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
  18. #include <algorithm>
  19. #include <memory>
  20. #include <vector>
  21. #include "ir/optimizer_caller.h"
  22. #include "ir/visitor.h"
  23. #include "operator/ops.h"
  24. #include "optimizer/irpass.h"
  25. #include "optimizer/irpass/prim_eliminate.h"
  26. #include "optimizer/optimizer.h"
  27. namespace mindspore {
  28. namespace opt {
  29. namespace irpass {
  30. // {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0}
  31. // {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1}
  32. class MultiplyByZeroOrOne : public AnfVisitor {
  33. public:
  34. MultiplyByZeroOrOne() : zero_(MakeValue(0)), one_(MakeValue(1)) {}
  35. ~MultiplyByZeroOrOne() override = default;
  36. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
  37. void Visit(const AnfNodePtr &node) override;
  38. void Visit(const ValueNodePtr &vnode) override;
  39. void Reset();
  40. private:
  41. bool is_zero_{false}, is_one_{false};
  42. ValuePtr zero_, one_;
  43. AnfNodePtr x_{nullptr};
  44. };
  45. // Support class used for checking if all values of a Tensor are equal `check_value_`
  46. // Supported data types: double, float/float32, int/int32
  47. class CheckTensorConstant {
  48. public:
  49. explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {}
  50. ~CheckTensorConstant() = default;
  51. bool IsTensorConstant(const ValuePtr &value);
  52. bool IsTensorScalarConstant(const ValuePtr &value);
  53. private:
  54. int check_value_;
  55. };
  56. class TensorMultiplyBase : public AnfVisitor {
  57. protected:
  58. void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false);
  59. // Make a new tensor (when possible) with the same shape as of `node`
  60. // If x is nullptr then fill new tensor will "0"
  61. // If x is a tensor with empty shape then fill new tensor with the single value of x
  62. // If x is a tensor with same shape as `node` then return x as result
  63. AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr);
  64. AnfNodePtr x_{nullptr};
  65. };
  66. // {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
  67. class TensorMultiplyByZero : public TensorMultiplyBase {
  68. public:
  69. TensorMultiplyByZero() : zero_(MakeValue(0)) {}
  70. ~TensorMultiplyByZero() override = default;
  71. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
  72. void Visit(const AnfNodePtr &node) override;
  73. void Visit(const ValueNodePtr &vnode) override;
  74. void Reset();
  75. private:
  76. bool is_zero_{false};
  77. ValuePtr zero_;
  78. };
  79. // {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
  80. class TensorMultiplyByOne : public TensorMultiplyBase {
  81. public:
  82. TensorMultiplyByOne() {}
  83. ~TensorMultiplyByOne() override = default;
  84. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
  85. void Visit(const AnfNodePtr &node) override;
  86. void Visit(const ValueNodePtr &vnode) override;
  87. void Reset();
  88. private:
  89. bool is_one_{false};
  90. };
  91. // {prim::kPrimScalarAdd, X, 0}
  92. // {prim::kPrimScalarAdd, 0, X}
  93. class AddByZero : public AnfVisitor {
  94. public:
  95. AddByZero() : zero_(MakeValue(0)) {}
  96. ~AddByZero() override = default;
  97. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
  98. void Visit(const AnfNodePtr &node) override;
  99. void Reset();
  100. private:
  101. bool is_zero_{false};
  102. ValuePtr zero_;
  103. AnfNodePtr x_{nullptr};
  104. };
  105. // {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X},
  106. // {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}}
  107. class TensorAddByZero : public AnfVisitor {
  108. public:
  109. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
  110. void Visit(const AnfNodePtr &node) override;
  111. void Visit(const ValueNodePtr &vnode) override;
  112. void Reset();
  113. private:
  114. bool is_zero_{false};
  115. AnfNodePtr x_{nullptr};
  116. };
  117. // {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y}
  118. class OptUpdateZeroTensor : public AnfVisitor {
  119. public:
  120. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
  121. };
  122. // {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} ->
  123. // {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
  124. class ConstantDuplicateMul : public AnfVisitor {
  125. public:
  126. // Support function to multiply two constant tensors: partially support broadcasting shapes
  127. template <typename T>
  128. void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
  129. int out_data_size);
  130. AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3);
  131. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
  132. void Visit(const AnfNodePtr &node) override;
  133. void Reset();
  134. private:
  135. AnfNodePtr vnode_;
  136. AnfNodePtr c_p_node_;
  137. };
  138. class PowerOneEliminate : public AnfVisitor {
  139. public:
  140. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
  141. };
  142. // grad = AllReduce(grad) / worker_number
  143. // grad = grad + weight * decy
  144. // ->
  145. // grad = grad + weight * decy
  146. // grad = AllReduce(grad) / worker_number
  147. // {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
  148. // {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
  149. class AdjustAllReduceMulAdd : public AnfVisitor {
  150. public:
  151. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
  152. void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node);
  153. void Visit(const AnfNodePtr &node) override;
  154. void Reset();
  155. private:
  156. int level_{0};
  157. bool is_reduce_match_{false};
  158. AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
  159. AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr};
  160. FuncGraphPtr all_reduce_fg_{nullptr};
  161. };
  162. class ArithmeticSimplify : public OptimizerCaller {
  163. public:
  164. ArithmeticSimplify()
  165. : multiply_by_zero_or_one_(std::make_shared<MultiplyByZeroOrOne>()),
  166. tensor_multiply_by_one_(std::make_shared<TensorMultiplyByOne>()),
  167. add_by_zero_(std::make_shared<AddByZero>()),
  168. tensor_add_by_zero_(std::make_shared<TensorAddByZero>()),
  169. identity_(std::make_shared<PrimEliminater>(prim::kPrimIdentity)),
  170. opt_update_zero_tensor_(std::make_shared<OptUpdateZeroTensor>()),
  171. constant_duplicate_mul_(std::make_shared<ConstantDuplicateMul>()),
  172. power_one_(std::make_shared<PowerOneEliminate>()) {
  173. eliminaters_.emplace_back(multiply_by_zero_or_one_);
  174. eliminaters_.emplace_back(tensor_multiply_by_one_);
  175. eliminaters_.emplace_back(add_by_zero_);
  176. eliminaters_.emplace_back(tensor_add_by_zero_);
  177. eliminaters_.emplace_back(identity_);
  178. eliminaters_.emplace_back(opt_update_zero_tensor_);
  179. eliminaters_.emplace_back(constant_duplicate_mul_);
  180. eliminaters_.emplace_back(power_one_);
  181. }
  182. ~ArithmeticSimplify() = default;
  183. AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
  184. private:
  185. OptimizerCallerPtr multiply_by_zero_or_one_;
  186. OptimizerCallerPtr tensor_multiply_by_one_;
  187. OptimizerCallerPtr add_by_zero_;
  188. OptimizerCallerPtr tensor_add_by_zero_;
  189. OptimizerCallerPtr identity_;
  190. OptimizerCallerPtr opt_update_zero_tensor_;
  191. OptimizerCallerPtr constant_duplicate_mul_;
  192. OptimizerCallerPtr power_one_;
  193. std::vector<OptimizerCallerPtr> eliminaters_{};
  194. };
  195. // Arithmetic Simplifications should be done after step_parallel.
  196. // eg: Mul(0, weight) where weight is a parameter will be simplified to a constant tensor
  197. // with shape(weight), but after step_parallel, shape of weight may be changed, so the
  198. // shape of the constant tensor should also be changed. So this pass is seperated from
  199. // ArithmeticSimplify and deferred until step_parallel.
  200. class ArithmeticSimplify2 : public OptimizerCaller {
  201. public:
  202. ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared<TensorMultiplyByZero>()) {
  203. eliminaters_.emplace_back(tensor_multiply_by_zero_);
  204. }
  205. ~ArithmeticSimplify2() = default;
  206. AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
  207. private:
  208. OptimizerCallerPtr tensor_multiply_by_zero_;
  209. std::vector<OptimizerCallerPtr> eliminaters_{};
  210. };
  211. } // namespace irpass
  212. } // namespace opt
  213. } // namespace mindspore
  214. #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_