You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

arithmetic_simplify.cc 8.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/optimizer/irpass/arithmetic_simplify.h"
  17. namespace mindspore {
  18. namespace opt {
  19. namespace irpass {
  20. AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  21. PatternNode x, y, z;
  22. PConstant one_(node, false, 1);
  23. PConstant one_scalar_(node, false, 1, true);
  24. PConstant zero_(node, false, 0);
  25. PConstant zero_scalar_(node, false, 0, true);
  26. PConstant const_(node);
  27. PConstant const_2(node);
  28. PConstant any_const(node);
  29. // if node has keep_alive attr, it would not be eliminated.
  30. if (node->isa<CNode>()) {
  31. auto cnode = node->cast<CNodePtr>();
  32. auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  33. if (prim->HasAttr("keep_alive") && GetValue<bool>(prim->GetAttr("keep_alive"))) {
  34. MS_LOG(INFO) << "keep node " << node->fullname_with_scope() << " alive";
  35. return nullptr;
  36. }
  37. }
  38. if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
  39. MATCH_REPLACE(node, x + zero_, x); // Add by zero
  40. MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero
  41. MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero
  42. MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), !one_.CheckFunc(IsParam, node)); // Multiply by one
  43. MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, one_scalar_, true), x); // Scalar Mul by one
  44. // Scalar Mul by zero
  45. MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, zero_scalar_, true), zero_scalar_.NewValue());
  46. }
  47. // Prim Eliminate (identity)
  48. MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x);
  49. if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
  50. return nullptr;
  51. }
  52. // ConstantDuplicateMul
  53. auto const_dup_lambda = [&node, &x, &const_, &const_2]() -> AnfNodePtr {
  54. auto new_mul_tensor = const_.MulByPatternConst(const_2, x.GetNode(node));
  55. auto mul_node = node->cast<CNodePtr>()->inputs()[0];
  56. if (new_mul_tensor == nullptr) {
  57. auto ttmul = NewCNode({mul_node, const_.GetNode(node), const_2.GetNode(node)}, node->func_graph());
  58. return NewCNode({mul_node, x.GetNode(node), ttmul}, node->func_graph());
  59. }
  60. auto new_cnode = NewCNode({mul_node, x.GetNode(node), new_mul_tensor}, node->func_graph());
  61. new_cnode->set_abstract(node->abstract());
  62. return new_cnode;
  63. };
  64. MATCH_REPLACE_LAMBDA(node, const_ * (const_2 * x), const_dup_lambda);
  65. if (node->func_graph() == nullptr) {
  66. return nullptr;
  67. }
  68. // OptUpdateZeroTensor: {kPrimMomentum, {kPrimZerosLike, x}, y, z, xs} -> {kPrimMakeTuple, z, y}
  69. MATCH_REPLACE(node, PPrimitive(prim::kPrimMomentum, PPrimitive(prim::kPrimZerosLike, x), y, z).MinExtraNodes(0),
  70. PPrimitive(prim::kPrimMakeTuple, z, y));
  71. // PowerOneEliminate
  72. MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimPow, x, one_scalar_), x,
  73. one_scalar_.CheckFunc(IsValueNode<Scalar>, node));
  74. return nullptr;
  75. }
  76. AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  77. if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
  78. return nullptr;
  79. }
  80. PatternNode x, y;
  81. PConstant zero_(node, false, 0);
  82. // Multiply by zero
  83. MATCH_REPLACE_IF(node, x * zero_, zero_.WithShapeAs(node),
  84. !zero_.CheckFunc(IsParam, node) && !x.CheckFunc(IsLoad, node) &&
  85. x.GetNode(node)->func_graph() == node->func_graph());
  86. auto zero_prim = PPrimitive(prim::kPrimZerosLike, y);
  87. MATCH_REPLACE_IF(node, x * zero_prim, zero_.WithShapeAs(node),
  88. !zero_prim.CheckFunc(IsParam, node) && x.GetNode(node)->func_graph() == node->func_graph());
  89. return nullptr;
  90. }
  91. // grad = AllReduce(grad) / worker_number
  92. // grad = grad + weight * decy
  93. // ->
  94. // grad = grad + weight * decy
  95. // grad = AllReduce(grad) / worker_number
  96. // {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
  97. // {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
// Rewrites AddN(MakeTuple(Mul(AllReduce(x), y), z)) into
// Mul(AllReduce(AddN(MakeTuple(z, x))), y), moving the AllReduce after the add
// (see the derivation in the comment above). Returns nullptr when the pattern
// does not match or when the rewrite is not applicable.
AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  PatternNode x, y, z;
  // Pattern: AddN(MakeTuple(Mul(AllReduce(x), y), z))
  auto all_reduce_pat = PPrimitive(prim::kPrimAllReduce, x);
  auto mul_pat = PBinOperation(prim::kPrimMul, all_reduce_pat, y, true);
  auto admktup_pat = PBinOperation(prim::kPrimMakeTuple, mul_pat, z, true);
  auto addn_pat = PPrimitive(prim::kPrimAddN, admktup_pat);
  auto adjust_lambda = [&node, &x, &y, &z, &addn_pat, &all_reduce_pat, &admktup_pat, &mul_pat, this]() -> AnfNodePtr {
    auto fg = all_reduce_pat.GetFuncGraph();
    auto z_ = z.GetNode(node);
    auto x_ = x.GetNode(node);
    // If addn inputs cross the graph, make the inputs same as allreduce node.
    if (z_->isa<CNode>() && fg != z_->func_graph()) {
      auto cnode_z = z_->cast<CNodePtr>();
      z_ = NewCNode(cnode_z->inputs(), fg);
    }
    // Collect the operator nodes of the matched sub-graph so they can be
    // reused when building the rewritten expression below.
    auto addn_cnode = addn_pat.GetOriginalNode()->cast<CNodePtr>();
    auto addn_op_node = addn_cnode->input(0);
    auto make_tuple_op_node = addn_cnode->input(1)->cast<CNodePtr>()->input(0);
    auto all_reduce_prim = all_reduce_pat.GetOriginalNode()->cast<CNodePtr>()->input(0);
    // NOTE: mul_cnode_ is a member; it is consumed later by ProcessDependEdge.
    mul_cnode_ = mul_pat.GetOriginalNode();
    auto mul_prim = mul_cnode_->cast<CNodePtr>()->input(0);
    auto addn_maketuple = admktup_pat.GetOriginalNode();
    // Extract the shapes of x and z, from the abstract for intermediate nodes
    // or from the tensor value for constants; bail out if either is not a tensor.
    ShapeVector x_shape, z_shape;
    if (!x_->isa<ValueNode>()) {
      if ((x_->abstract() == nullptr) || !x_->abstract()->isa<abstract::AbstractTensor>()) {
        return nullptr;
      }
      auto x_abstract = x_->abstract()->cast<abstract::AbstractTensorPtr>();
      x_shape = x_abstract->shape()->shape();
    } else {
      ValuePtr x_value = x_->cast<ValueNodePtr>()->value();
      if (!x_value->isa<tensor::Tensor>()) {
        return nullptr;
      }
      auto x_tensor = GetValueNode<tensor::TensorPtr>(x_->cast<ValueNodePtr>());
      x_shape = x_tensor->shape();
    }
    if (!z_->isa<ValueNode>()) {
      if ((z_->abstract() == nullptr) || !z_->abstract()->isa<abstract::AbstractTensor>()) {
        return nullptr;
      }
      auto z_abstract = z_->abstract()->cast<abstract::AbstractTensorPtr>();
      z_shape = z_abstract->shape()->shape();
    } else {
      ValuePtr z_value = z_->cast<ValueNodePtr>()->value();
      if (!z_value->isa<tensor::Tensor>()) {
        return nullptr;
      }
      auto z_tensor = GetValueNode<tensor::TensorPtr>(z_->cast<ValueNodePtr>());
      z_shape = z_tensor->shape();
    }
    if (x_shape != z_shape) {
      // AddN requires x_ and z_ have the same shape.
      // If broadcasting TensorAdd is supported then can use this
      return nullptr;
    }
    // Build the rewritten expression: Mul(AllReduce(AddN(MakeTuple(z, x))), y).
    AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
    AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
    AnfNodePtr all_reduce = NewCNode({all_reduce_prim, add}, fg);
    AnfNodePtr mul = NewCNode({mul_prim, all_reduce, y.GetNode(node)}, fg);
    // Re-route any other MakeTuple users of the old Mul (dynamic loss scale) to
    // the new AllReduce node.
    ProcessDependEdge(fg, addn_maketuple, all_reduce);
    return mul;
  };
  MATCH_REPLACE_LAMBDA(node, addn_pat, adjust_lambda);
  return nullptr;
}
  164. void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple,
  165. const AnfNodePtr &new_node) {
  166. // If has dynamic loss scale.
  167. MS_EXCEPTION_IF_NULL(fg);
  168. auto manager = fg->manager();
  169. MS_EXCEPTION_IF_NULL(manager);
  170. auto &users_map = manager->node_users();
  171. auto it = users_map.find(mul_cnode_);
  172. if (it != users_map.end()) {
  173. auto users = it->second;
  174. for (auto &user_pair : users) {
  175. auto node = user_pair.first;
  176. if (node != addn_maketuple && IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
  177. manager->SetEdge(node, user_pair.second, new_node);
  178. }
  179. }
  180. }
  181. }
  182. } // namespace irpass
  183. } // namespace opt
  184. } // namespace mindspore