diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h index 0d48fc1463..ff6e4f6170 100644 --- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h +++ b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h @@ -248,17 +248,18 @@ class AdjustAllReduceMulAdd : public AnfVisitor { if (addn->size() != 2) { return nullptr; } - AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); if (x_ == nullptr || y_ == nullptr || z_ == nullptr) { return nullptr; } + auto addn_op_node = addn->input(0); + auto make_tuple_op_node = addn->input(1)->cast()->input(0); auto fg = node->func_graph(); - AnfNodePtr tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), z_, x_}, fg); - AnfNodePtr add = NewCNode({NewValueNode(prim::kPrimAddN), tuple}, fg); - AnfNodePtr all_reduce = NewCNode({NewValueNode(prim::kPrimAllReduce), add}, fg); - return NewCNode({NewValueNode(prim::kPrimMul), all_reduce, y_}, fg); + AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg); + AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); + AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg); + return NewCNode({mul_, all_reduce, y_}, fg); } void Visit(const AnfNodePtr &node) override { @@ -269,6 +270,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor { AnfVisitor::Match(prim::kPrimMul)(node); level_ = 0; if (is_reduce_match_) { + mul_ = node->cast()->input(0); y_ = tmp_; } else { z_ = node; @@ -280,6 +282,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor { if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { auto cnode = node->cast(); if (cnode->size() > 1) { + all_reduce_ = cnode->input(0); x_ = cnode->input(1); is_reduce_match_ = true; } @@ -302,6 +305,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor { int level_{0}; bool is_reduce_match_{false}; AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; + AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}; }; class ArithmeticSimplify {