| @@ -248,17 +248,18 @@ class AdjustAllReduceMulAdd : public AnfVisitor { | |||||
| if (addn->size() != 2) { | if (addn->size() != 2) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); | AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); | ||||
| if (x_ == nullptr || y_ == nullptr || z_ == nullptr) { | if (x_ == nullptr || y_ == nullptr || z_ == nullptr) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto addn_op_node = addn->input(0); | |||||
| auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0); | |||||
| auto fg = node->func_graph(); | auto fg = node->func_graph(); | ||||
| AnfNodePtr tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), z_, x_}, fg); | |||||
| AnfNodePtr add = NewCNode({NewValueNode(prim::kPrimAddN), tuple}, fg); | |||||
| AnfNodePtr all_reduce = NewCNode({NewValueNode(prim::kPrimAllReduce), add}, fg); | |||||
| return NewCNode({NewValueNode(prim::kPrimMul), all_reduce, y_}, fg); | |||||
| AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg); | |||||
| AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); | |||||
| AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg); | |||||
| return NewCNode({mul_, all_reduce, y_}, fg); | |||||
| } | } | ||||
| void Visit(const AnfNodePtr &node) override { | void Visit(const AnfNodePtr &node) override { | ||||
| @@ -269,6 +270,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor { | |||||
| AnfVisitor::Match(prim::kPrimMul)(node); | AnfVisitor::Match(prim::kPrimMul)(node); | ||||
| level_ = 0; | level_ = 0; | ||||
| if (is_reduce_match_) { | if (is_reduce_match_) { | ||||
| mul_ = node->cast<CNodePtr>()->input(0); | |||||
| y_ = tmp_; | y_ = tmp_; | ||||
| } else { | } else { | ||||
| z_ = node; | z_ = node; | ||||
| @@ -280,6 +282,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor { | |||||
| if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { | if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| if (cnode->size() > 1) { | if (cnode->size() > 1) { | ||||
| all_reduce_ = cnode->input(0); | |||||
| x_ = cnode->input(1); | x_ = cnode->input(1); | ||||
| is_reduce_match_ = true; | is_reduce_match_ = true; | ||||
| } | } | ||||
| @@ -302,6 +305,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor { | |||||
| int level_{0}; | int level_{0}; | ||||
| bool is_reduce_match_{false}; | bool is_reduce_match_{false}; | ||||
| AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; | AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; | ||||
| AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}; | |||||
| }; | }; | ||||
| class ArithmeticSimplify { | class ArithmeticSimplify { | ||||