Merge pull request !3696 from Giancarlo/update_adjust_allreducetags/v0.7.0-beta
| @@ -95,37 +95,37 @@ AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePt | |||||
| // {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> | // {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> | ||||
| // {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} | // {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} | ||||
| AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | ||||
| Reset(); | |||||
| // {prim::kPrimAddN, Zs} | |||||
| if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto addn = node->cast<CNodePtr>(); | |||||
| if (addn->size() != 2) { | |||||
| return nullptr; | |||||
| } | |||||
| AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); | |||||
| if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto addn_maketuple = addn->input(1); | |||||
| auto fg = all_reduce_fg_; | |||||
| // addn inputs cross the graph, make the inputs same as allreduce node. | |||||
| if (z_->isa<CNode>() && fg != z_->func_graph()) { | |||||
| auto cnode_z = z_->cast<CNodePtr>(); | |||||
| z_ = NewCNode(cnode_z->inputs(), fg); | |||||
| } | |||||
| auto addn_op_node = addn->input(0); | |||||
| auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0); | |||||
| PatternNode x, y, z; | |||||
| auto all_reduce_pat = PPrimitive(prim::kPrimAllReduce, x); | |||||
| auto mul_pat = PBinOperation(prim::kPrimMul, all_reduce_pat, y, true); | |||||
| auto admktup_pat = PBinOperation(prim::kPrimMakeTuple, mul_pat, z, true); | |||||
| auto addn_pat = PPrimitive(prim::kPrimAddN, admktup_pat); | |||||
| auto adjust_lambda = [&node, &x, &y, &z, &addn_pat, &all_reduce_pat, &admktup_pat, &mul_pat, this]() -> AnfNodePtr { | |||||
| auto fg = all_reduce_pat.GetFuncGraph(); | |||||
| auto z_ = z.GetNode(node); | |||||
| // If addn inputs cross the graph, make the inputs same as allreduce node. | |||||
| if (z_->isa<CNode>() && fg != z_->func_graph()) { | |||||
| auto cnode_z = z_->cast<CNodePtr>(); | |||||
| z_ = NewCNode(cnode_z->inputs(), fg); | |||||
| } | |||||
| AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg); | |||||
| AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); | |||||
| AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg); | |||||
| AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg); | |||||
| ProcessDependEdge(fg, addn_maketuple, all_reduce); | |||||
| return mul; | |||||
| auto addn_cnode = addn_pat.GetOriginalNode()->cast<CNodePtr>(); | |||||
| auto addn_op_node = addn_cnode->input(0); | |||||
| auto make_tuple_op_node = addn_cnode->input(1)->cast<CNodePtr>()->input(0); | |||||
| auto all_reduce_prim = all_reduce_pat.GetOriginalNode()->cast<CNodePtr>()->input(0); | |||||
| mul_cnode_ = mul_pat.GetOriginalNode(); | |||||
| auto mul_prim = mul_cnode_->cast<CNodePtr>()->input(0); | |||||
| auto addn_maketuple = admktup_pat.GetOriginalNode(); | |||||
| AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x.GetNode(node)}, fg); | |||||
| AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); | |||||
| AnfNodePtr all_reduce = NewCNode({all_reduce_prim, add}, fg); | |||||
| AnfNodePtr mul = NewCNode({mul_prim, all_reduce, y.GetNode(node)}, fg); | |||||
| ProcessDependEdge(fg, addn_maketuple, all_reduce); | |||||
| return mul; | |||||
| }; | |||||
| MATCH_REPLACE_LAMBDA(node, addn_pat, adjust_lambda); | |||||
| return nullptr; | |||||
| } | } | ||||
| void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, | void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, | ||||
| @@ -146,48 +146,6 @@ void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfN | |||||
| } | } | ||||
| } | } | ||||
| void AdjustAllReduceMulAdd::Visit(const AnfNodePtr &node) { | |||||
| if (level_ == 0) { | |||||
| level_ = 1; | |||||
| is_reduce_match_ = false; | |||||
| // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y} | |||||
| AnfVisitor::Match(prim::kPrimMul)(node); | |||||
| level_ = 0; | |||||
| if (is_reduce_match_) { | |||||
| mul_ = node->cast<CNodePtr>()->input(0); | |||||
| mul_cnode_ = node->cast<CNodePtr>(); | |||||
| y_ = tmp_; | |||||
| } else { | |||||
| z_ = node; | |||||
| } | |||||
| } | |||||
| if (level_ == 1) { | |||||
| // {prim::kPrimAllReduce, X} | |||||
| if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| if (cnode->size() > 1) { | |||||
| all_reduce_ = cnode->input(0); | |||||
| x_ = cnode->input(1); | |||||
| is_reduce_match_ = true; | |||||
| all_reduce_fg_ = cnode->func_graph(); | |||||
| } | |||||
| } else { | |||||
| tmp_ = node; | |||||
| } | |||||
| } | |||||
| } | |||||
| void AdjustAllReduceMulAdd::Reset() { | |||||
| level_ = 0; | |||||
| is_reduce_match_ = false; | |||||
| x_ = nullptr; | |||||
| y_ = nullptr; | |||||
| z_ = nullptr; | |||||
| tmp_ = nullptr; | |||||
| all_reduce_fg_ = nullptr; | |||||
| } | |||||
| } // namespace irpass | } // namespace irpass | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -38,20 +38,14 @@ namespace irpass { | |||||
| // {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> | // {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> | ||||
| // {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} | // {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} | ||||
| class AdjustAllReduceMulAdd : public AnfVisitor { | |||||
| class AdjustAllReduceMulAdd : public OptimizerCaller { | |||||
| public: | public: | ||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; | AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; | ||||
| void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node); | void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node); | ||||
| void Visit(const AnfNodePtr &node) override; | |||||
| void Reset(); | |||||
| private: | private: | ||||
| int level_{0}; | |||||
| bool is_reduce_match_{false}; | |||||
| AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; | |||||
| AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr}; | |||||
| FuncGraphPtr all_reduce_fg_{nullptr}; | |||||
| AnfNodePtr mul_cnode_{nullptr}; | |||||
| }; | }; | ||||
| class ArithmeticSimplify : public OptimizerCaller { | class ArithmeticSimplify : public OptimizerCaller { | ||||
| @@ -94,8 +94,8 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > { | |||||
| ~PBinOperation() = default; | ~PBinOperation() = default; | ||||
| AnfNodePtr GetNode(const AnfNodePtr &node) const { | AnfNodePtr GetNode(const AnfNodePtr &node) const { | ||||
| AnfNodePtr lhs = x_.GetNode(node->func_graph()); | |||||
| AnfNodePtr rhs = y_.GetNode(node->func_graph()); | |||||
| AnfNodePtr lhs = x_.GetNode(node); | |||||
| AnfNodePtr rhs = y_.GetNode(node); | |||||
| AnfNodePtrList list = {NewValueNode(prim_), lhs, rhs}; | AnfNodePtrList list = {NewValueNode(prim_), lhs, rhs}; | ||||
| return NewCNode(list, node->func_graph()); | return NewCNode(list, node->func_graph()); | ||||
| } | } | ||||
| @@ -113,25 +113,42 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > { | |||||
| if (!x_.TryCapture(inputs[2]) || !y_.TryCapture(inputs[1])) { | if (!x_.TryCapture(inputs[2]) || !y_.TryCapture(inputs[1])) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| captured_binop_node_ = node; | |||||
| return true; | return true; | ||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| captured_binop_node_ = node; | |||||
| return true; | return true; | ||||
| } | } | ||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| /// Returns the original node captured by this Binary Operation Pattern. | |||||
| /// Throws exception if a node was not captured before. | |||||
| AnfNodePtr GetOriginalNode() const { | |||||
| if (captured_binop_node_ == nullptr) { | |||||
| MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it."; | |||||
| } | |||||
| return captured_binop_node_; | |||||
| } | |||||
| void Reset() const { | void Reset() const { | ||||
| x_.Reset(); | x_.Reset(); | ||||
| y_.Reset(); | y_.Reset(); | ||||
| captured_binop_node_ = nullptr; | |||||
| } | } | ||||
| using Internal = const PBinOperation<T, T2> &; | |||||
| private: | private: | ||||
| const PrimitivePtr prim_; | const PrimitivePtr prim_; | ||||
| typename T::Internal x_; | typename T::Internal x_; | ||||
| typename T2::Internal y_; | typename T2::Internal y_; | ||||
| bool is_commutative_{false}; | bool is_commutative_{false}; | ||||
| mutable AnfNodePtr captured_binop_node_{nullptr}; | |||||
| }; | }; | ||||
| /// | /// | ||||
| @@ -265,10 +282,11 @@ class PCNode : public PBase<PCNode<TArgs...> > { | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| using Internal = const PCNode<TArgs...> &; | |||||
| void Reset() const { | void Reset() const { | ||||
| tuple_utils::PTupleResetCapture reset; | tuple_utils::PTupleResetCapture reset; | ||||
| tuple_utils::apply_func_tuple(&reset, args_); | tuple_utils::apply_func_tuple(&reset, args_); | ||||
| has_min_extra_nodes_ = false; | |||||
| extra_nodes_.clear(); | extra_nodes_.clear(); | ||||
| } | } | ||||
| @@ -316,6 +334,9 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > { | |||||
| AnfNodePtrList tokens(inputs.begin() + 1, inputs.end()); | AnfNodePtrList tokens(inputs.begin() + 1, inputs.end()); | ||||
| tuple_utils::PTupleCapture capture_func(tokens); | tuple_utils::PTupleCapture capture_func(tokens); | ||||
| tuple_utils::apply_func_tuple(&capture_func, args_); | tuple_utils::apply_func_tuple(&capture_func, args_); | ||||
| if (capture_func.captured_) { | |||||
| captured_prim_node_ = node; | |||||
| } | |||||
| return capture_func.captured_; | return capture_func.captured_; | ||||
| } | } | ||||
| return false; | return false; | ||||
| @@ -329,9 +350,11 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > { | |||||
| tuple_utils::apply_func_tuple(&capture_func, args_); | tuple_utils::apply_func_tuple(&capture_func, args_); | ||||
| // If it could capture the initial set of nodes specified in the Pattern | // If it could capture the initial set of nodes specified in the Pattern | ||||
| // and there are enough extra inputs to add | // and there are enough extra inputs to add | ||||
| if (capture_func.captured_ && inputs.size() > pattern_arg_len + 1) { | |||||
| extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end()); | |||||
| return true; | |||||
| if (capture_func.captured_) { | |||||
| captured_prim_node_ = node; | |||||
| if (inputs.size() > pattern_arg_len + 1) { | |||||
| extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end()); | |||||
| } | |||||
| } | } | ||||
| return capture_func.captured_; | return capture_func.captured_; | ||||
| } | } | ||||
| @@ -349,19 +372,42 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > { | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| /// Returns the FuncGraph of the original node captured by this Primitive Pattern. | |||||
| /// Throws exception if a node was not captured before. | |||||
| FuncGraphPtr GetFuncGraph() const { | |||||
| if (captured_prim_node_ == nullptr) { | |||||
| MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get its FuncGraph."; | |||||
| } | |||||
| return captured_prim_node_->func_graph(); | |||||
| } | |||||
| /// Returns the original node captured by this Primitive Pattern. | |||||
| /// Throws exception if a node was not captured before. | |||||
| AnfNodePtr GetOriginalNode() const { | |||||
| if (captured_prim_node_ == nullptr) { | |||||
| MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it."; | |||||
| } | |||||
| return captured_prim_node_; | |||||
| } | |||||
| void Reset() const { | void Reset() const { | ||||
| tuple_utils::PTupleResetCapture reset; | tuple_utils::PTupleResetCapture reset; | ||||
| tuple_utils::apply_func_tuple(&reset, args_); | tuple_utils::apply_func_tuple(&reset, args_); | ||||
| has_min_extra_nodes_ = false; | |||||
| extra_nodes_.clear(); | extra_nodes_.clear(); | ||||
| captured_prim_node_ = nullptr; | |||||
| } | } | ||||
| using Internal = const PPrimitive<TArgs...> &; | |||||
| private: | private: | ||||
| const PrimitivePtr prim_; | const PrimitivePtr prim_; | ||||
| std::tuple<typename TArgs::Internal...> args_; | std::tuple<typename TArgs::Internal...> args_; | ||||
| mutable AnfNodePtrList extra_nodes_; | mutable AnfNodePtrList extra_nodes_; | ||||
| mutable bool has_min_extra_nodes_{false}; | mutable bool has_min_extra_nodes_{false}; | ||||
| mutable size_t min_extra_nodes_{0}; | mutable size_t min_extra_nodes_{0}; | ||||
| mutable AnfNodePtr captured_prim_node_{nullptr}; | |||||
| }; | }; | ||||
| /// | /// | ||||