|
|
|
@@ -248,17 +248,18 @@ class AdjustAllReduceMulAdd : public AnfVisitor { |
|
|
|
if (addn->size() != 2) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); |
|
|
|
if (x_ == nullptr || y_ == nullptr || z_ == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
auto addn_op_node = addn->input(0); |
|
|
|
auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0); |
|
|
|
auto fg = node->func_graph(); |
|
|
|
AnfNodePtr tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), z_, x_}, fg); |
|
|
|
AnfNodePtr add = NewCNode({NewValueNode(prim::kPrimAddN), tuple}, fg); |
|
|
|
AnfNodePtr all_reduce = NewCNode({NewValueNode(prim::kPrimAllReduce), add}, fg); |
|
|
|
return NewCNode({NewValueNode(prim::kPrimMul), all_reduce, y_}, fg); |
|
|
|
AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg); |
|
|
|
AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); |
|
|
|
AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg); |
|
|
|
return NewCNode({mul_, all_reduce, y_}, fg); |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override { |
|
|
|
@@ -269,6 +270,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor { |
|
|
|
AnfVisitor::Match(prim::kPrimMul)(node); |
|
|
|
level_ = 0; |
|
|
|
if (is_reduce_match_) { |
|
|
|
mul_ = node->cast<CNodePtr>()->input(0); |
|
|
|
y_ = tmp_; |
|
|
|
} else { |
|
|
|
z_ = node; |
|
|
|
@@ -280,6 +282,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor { |
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
if (cnode->size() > 1) { |
|
|
|
all_reduce_ = cnode->input(0); |
|
|
|
x_ = cnode->input(1); |
|
|
|
is_reduce_match_ = true; |
|
|
|
} |
|
|
|
@@ -302,6 +305,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor { |
|
|
|
int level_{0}; |
|
|
|
bool is_reduce_match_{false}; |
|
|
|
AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; |
|
|
|
AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}; |
|
|
|
}; |
|
|
|
|
|
|
|
class ArithmeticSimplify { |
|
|
|
|