You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

arithmetic_simplify.h 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
  17. #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
  18. #include <vector>
  19. #include <memory>
  20. #include <algorithm>
  21. #include "optimizer/optimizer.h"
  22. #include "optimizer/irpass.h"
  23. #include "optimizer/irpass/prim_eliminate.h"
  24. #include "ir/visitor.h"
  25. #include "operator/ops.h"
  26. namespace mindspore {
  27. namespace opt {
  28. namespace irpass {
// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0}
// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1}
// Simplifies scalar multiplication by the constants 0 and 1:
//   0 * X (or X * 0) -> 0
//   1 * X (or X * 1) -> X
class MultiplyByZeroOrOne : public AnfVisitor {
 public:
  MultiplyByZeroOrOne() : zero_(MakeValue(0)), one_(MakeValue(1)) {}
  ~MultiplyByZeroOrOne() override = default;

  // Entry point: matches a ScalarMul node and returns the simplified
  // replacement, or nullptr when no simplification applies.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    Reset();
    AnfVisitor::Match(prim::kPrimScalarMul)(node);
    if (is_zero_) {
      // 0 * X == 0: the whole expression collapses to the constant 0.
      return NewValueNode(zero_);
    }
    if (is_one_) {
      // 1 * X == X: the non-constant operand recorded in x_.
      return x_;
    }
    return nullptr;
  }

  // Called once per operand of the matched ScalarMul.
  void Visit(const AnfNodePtr &node) override {
    if (is_one_ || node->isa<CNode>()) {
      // Either the constant 1 was already found (so this operand is X), or
      // the operand is a CNode and thus cannot be a constant; record it as X.
      x_ = node;
      return;
    }
    // Dispatches to Visit(ValueNodePtr) below, which tests for 0 / 1.
    AnfVisitor::Visit(node);
    if (!is_one_) {
      // The operand was not the constant 1, so keep it as the X candidate.
      // NOTE(review): this also overwrites x_ when the operand is 0, which is
      // harmless because the zero branch above ignores x_.
      x_ = node;
    }
  }

  // Sets is_zero_/is_one_ when a value node holds the matching constant.
  void Visit(const ValueNodePtr &vnode) override {
    auto value = vnode->value();
    if (*value == *zero_) {
      is_zero_ = true;
    } else if (*value == *one_) {
      is_one_ = true;
    }
  }

  // Clears per-match state so the visitor can be reused across nodes.
  void Reset() {
    x_ = nullptr;
    is_one_ = false;
    is_zero_ = false;
  }

 private:
  bool is_zero_{false}, is_one_{false};
  ValuePtr zero_, one_;    // cached Value objects for the constants 0 and 1
  AnfNodePtr x_{nullptr};  // the non-constant operand of the multiplication
};
  74. // {prim::kPrimScalarAdd, X, 0}
  75. // {prim::kPrimScalarAdd, 0, X}
  76. class AddByZero : public AnfVisitor {
  77. public:
  78. AddByZero() : zero_(MakeValue(0)) {}
  79. ~AddByZero() override = default;
  80. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  81. Reset();
  82. AnfVisitor::Match(prim::kPrimScalarAdd)(node);
  83. if (is_zero_) {
  84. return x_;
  85. }
  86. return nullptr;
  87. }
  88. void Visit(const AnfNodePtr &node) override {
  89. if (node->isa<ValueNode>() && *GetValueNode(node) == *zero_) {
  90. is_zero_ = true;
  91. return;
  92. }
  93. x_ = node;
  94. }
  95. void Reset() {
  96. x_ = nullptr;
  97. is_zero_ = false;
  98. }
  99. private:
  100. bool is_zero_{false};
  101. ValuePtr zero_;
  102. AnfNodePtr x_{nullptr};
  103. };
  104. // {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X},
  105. // {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}}
  106. class TensorAddByZero : public AnfVisitor {
  107. public:
  108. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  109. Reset();
  110. AnfVisitor::Match(prim::kPrimTensorAdd)(node);
  111. if (is_zero_) {
  112. return x_;
  113. }
  114. return nullptr;
  115. }
  116. void Visit(const AnfNodePtr &node) override {
  117. if (IsPrimitive(node, prim::kPrimZerosLike)) {
  118. is_zero_ = true;
  119. return;
  120. }
  121. x_ = node;
  122. }
  123. void Reset() {
  124. x_ = nullptr;
  125. is_zero_ = false;
  126. }
  127. private:
  128. bool is_zero_{false};
  129. AnfNodePtr x_{nullptr};
  130. };
  131. // {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y}
  132. class OptUpdateZeroTensor : public AnfVisitor {
  133. public:
  134. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  135. if (!IsPrimitiveCNode(node, prim::kPrimMomentum) || node->func_graph() == nullptr) {
  136. return nullptr;
  137. }
  138. // {PrimMomentum, {...}, Y, Z, Xs}
  139. auto &inputs = node->cast<CNodePtr>()->inputs();
  140. if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLike)) {
  141. return nullptr;
  142. }
  143. auto y = inputs[2];
  144. auto z = inputs[3];
  145. // {kPrimZerosLike, X}
  146. if (inputs[1]->cast<CNodePtr>()->size() != 2) {
  147. return nullptr;
  148. }
  149. // {prim::kPrimMakeTuple, Z, Y}
  150. return node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), z, y});
  151. }
  152. };
  153. // {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} ->
  154. // {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
  155. class ConstantDuplicateMul : public AnfVisitor {
  156. public:
  157. AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
  158. Reset();
  159. // {prim::kPrimMul, Tensor1, {...}}
  160. AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node);
  161. if (vnode_ == nullptr || cnode_ == nullptr) {
  162. return nullptr;
  163. }
  164. auto tensor1 = vnode_;
  165. auto mul = cnode_;
  166. Reset();
  167. // {prim::kPrimMul, Tensor2, {...}}
  168. AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul);
  169. if (vnode_ == nullptr || cnode_ == nullptr) {
  170. return nullptr;
  171. }
  172. auto tensor2 = vnode_;
  173. auto cnode = cnode_;
  174. auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0));
  175. auto fg = node->func_graph();
  176. auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg);
  177. return NewCNode({NewValueNode(PrimMul), cnode, ttmul}, fg);
  178. }
  179. void Visit(const AnfNodePtr &node) override {
  180. if (IsValueNode<tensor::Tensor>(node)) {
  181. vnode_ = node;
  182. }
  183. if (IsCNode(node)) {
  184. cnode_ = node->cast<CNodePtr>();
  185. }
  186. }
  187. void Reset() {
  188. vnode_ = nullptr;
  189. cnode_ = nullptr;
  190. }
  191. private:
  192. AnfNodePtr vnode_;
  193. CNodePtr cnode_;
  194. };
// grad = AllReduce(grad) / worker_number
// grad = grad + weight * decy
// ->
// grad = grad + weight * decy
// grad = AllReduce(grad) / worker_number
// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
// Reorders an AllReduce/Mul followed by an AddN so that the addition happens
// first and the AllReduce runs on the summed result.
class AdjustAllReduceMulAdd : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    Reset();
    // {prim::kPrimAddN, Zs}
    if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
      return nullptr;
    }
    auto addn = node->cast<CNodePtr>();
    // Only handle AddN with exactly one operand (op node + one input).
    if (addn->size() != 2) {
      return nullptr;
    }
    // Visits both MakeTuple elements via Visit() below, which fills
    // x_, y_, z_, mul_, all_reduce_ and all_reduce_fg_ on a match.
    AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
    if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) {
      return nullptr;
    }
    auto addn_maketuple = addn->input(1);
    // Build the replacement inside the graph that owns the AllReduce node.
    auto fg = all_reduce_fg_;
    // addn inputs cross the graph, make the inputs same as allreduce node.
    if (z_->isa<CNode>() && fg != z_->func_graph()) {
      auto cnode_z = z_->cast<CNodePtr>();
      // Clone Z into fg so the rebuilt expression stays within one graph.
      z_ = NewCNode(cnode_z->inputs(), fg);
    }
    auto addn_op_node = addn->input(0);
    auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
    // Rebuild as: Mul(AllReduce(AddN(MakeTuple(Z, X))), Y).
    AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
    AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
    AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
    AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg);
    // Redirect other MakeTuple users of the old Mul to the new AllReduce.
    ProcessDependEdge(fg, addn_maketuple, all_reduce);
    return mul;
  }

  // Repoints users of the matched Mul node: any MakeTuple user other than the
  // AddN's own MakeTuple is re-edged to new_node (the rebuilt AllReduce).
  void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node) {
    // If has dynamic loss scale.
    auto &users_map = fg->manager()->node_users();
    auto it = users_map.find(mul_cnode_);
    if (it != users_map.end()) {
      auto users = it->second;
      for (auto &user_pair : users) {
        auto node = user_pair.first;
        if (node != addn_maketuple) {
          if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
            fg->manager()->SetEdge(node, user_pair.second, new_node);
          }
        }
      }
    }
  }

  // Two-level visitor. level_ == 0: inspecting the MakeTuple elements; an
  // element that matches {Mul, {AllReduce, X}, Y} fills mul_/y_, anything
  // else becomes z_. level_ == 1: inspecting the Mul's operands; an AllReduce
  // operand fills all_reduce_/x_/all_reduce_fg_, the other operand is
  // parked in tmp_ (promoted to y_ after a successful match).
  void Visit(const AnfNodePtr &node) override {
    if (level_ == 0) {
      level_ = 1;
      is_reduce_match_ = false;
      // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
      AnfVisitor::Match(prim::kPrimMul)(node);
      level_ = 0;
      if (is_reduce_match_) {
        mul_ = node->cast<CNodePtr>()->input(0);
        mul_cnode_ = node->cast<CNodePtr>();
        y_ = tmp_;
      } else {
        z_ = node;
      }
    }
    if (level_ == 1) {
      // {prim::kPrimAllReduce, X}
      if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
        auto cnode = node->cast<CNodePtr>();
        if (cnode->size() > 1) {
          all_reduce_ = cnode->input(0);
          x_ = cnode->input(1);
          is_reduce_match_ = true;
          all_reduce_fg_ = cnode->func_graph();
        }
      } else {
        tmp_ = node;
      }
    }
  }

  // Clears all per-match state so the visitor can be reused across nodes.
  void Reset() {
    level_ = 0;
    is_reduce_match_ = false;
    x_ = nullptr;
    y_ = nullptr;
    z_ = nullptr;
    tmp_ = nullptr;
    all_reduce_fg_ = nullptr;
  }

 private:
  int level_{0};               // current visiting depth: 0 = tuple elems, 1 = Mul operands
  bool is_reduce_match_{false};
  AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
  AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr};
  FuncGraphPtr all_reduce_fg_{nullptr};  // graph owning the matched AllReduce
};
  296. class ArithmeticSimplify {
  297. public:
  298. ArithmeticSimplify()
  299. : multiply_by_zero_or_one_(),
  300. add_by_zero_(),
  301. tensor_add_by_zero_(),
  302. identity_(prim::kPrimIdentity),
  303. opt_update_zero_tensor_(),
  304. constant_duplicate_mul_() {
  305. eliminaters_.emplace_back(multiply_by_zero_or_one_);
  306. eliminaters_.emplace_back(add_by_zero_);
  307. eliminaters_.emplace_back(tensor_add_by_zero_);
  308. eliminaters_.emplace_back(identity_);
  309. eliminaters_.emplace_back(opt_update_zero_tensor_);
  310. eliminaters_.emplace_back(constant_duplicate_mul_);
  311. }
  312. ~ArithmeticSimplify() = default;
  313. AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
  314. AnfNodePtr new_node;
  315. for (auto &eliminater : eliminaters_) {
  316. new_node = eliminater(optimizer, node);
  317. if (new_node != nullptr) {
  318. return new_node;
  319. }
  320. }
  321. return nullptr;
  322. }
  323. private:
  324. MultiplyByZeroOrOne multiply_by_zero_or_one_;
  325. AddByZero add_by_zero_;
  326. TensorAddByZero tensor_add_by_zero_;
  327. PrimEliminater identity_;
  328. OptUpdateZeroTensor opt_update_zero_tensor_;
  329. ConstantDuplicateMul constant_duplicate_mul_;
  330. std::vector<TransformFuncType> eliminaters_{};
  331. };
  332. } // namespace irpass
  333. } // namespace opt
  334. } // namespace mindspore
  335. #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_