Merge pull request !3696 from Giancarlo/update_adjust_allreducetags/v0.7.0-beta
| @@ -95,37 +95,37 @@ AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePt | |||
| // {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> | |||
| // {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} | |||
| AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||
| Reset(); | |||
| // {prim::kPrimAddN, Zs} | |||
| if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { | |||
| return nullptr; | |||
| } | |||
| auto addn = node->cast<CNodePtr>(); | |||
| if (addn->size() != 2) { | |||
| return nullptr; | |||
| } | |||
| AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); | |||
| if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto addn_maketuple = addn->input(1); | |||
| auto fg = all_reduce_fg_; | |||
| // addn inputs cross the graph, make the inputs same as allreduce node. | |||
| if (z_->isa<CNode>() && fg != z_->func_graph()) { | |||
| auto cnode_z = z_->cast<CNodePtr>(); | |||
| z_ = NewCNode(cnode_z->inputs(), fg); | |||
| } | |||
| auto addn_op_node = addn->input(0); | |||
| auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0); | |||
| PatternNode x, y, z; | |||
| auto all_reduce_pat = PPrimitive(prim::kPrimAllReduce, x); | |||
| auto mul_pat = PBinOperation(prim::kPrimMul, all_reduce_pat, y, true); | |||
| auto admktup_pat = PBinOperation(prim::kPrimMakeTuple, mul_pat, z, true); | |||
| auto addn_pat = PPrimitive(prim::kPrimAddN, admktup_pat); | |||
| auto adjust_lambda = [&node, &x, &y, &z, &addn_pat, &all_reduce_pat, &admktup_pat, &mul_pat, this]() -> AnfNodePtr { | |||
| auto fg = all_reduce_pat.GetFuncGraph(); | |||
| auto z_ = z.GetNode(node); | |||
| // If addn inputs cross the graph, make the inputs same as allreduce node. | |||
| if (z_->isa<CNode>() && fg != z_->func_graph()) { | |||
| auto cnode_z = z_->cast<CNodePtr>(); | |||
| z_ = NewCNode(cnode_z->inputs(), fg); | |||
| } | |||
| AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg); | |||
| AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); | |||
| AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg); | |||
| AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg); | |||
| ProcessDependEdge(fg, addn_maketuple, all_reduce); | |||
| return mul; | |||
| auto addn_cnode = addn_pat.GetOriginalNode()->cast<CNodePtr>(); | |||
| auto addn_op_node = addn_cnode->input(0); | |||
| auto make_tuple_op_node = addn_cnode->input(1)->cast<CNodePtr>()->input(0); | |||
| auto all_reduce_prim = all_reduce_pat.GetOriginalNode()->cast<CNodePtr>()->input(0); | |||
| mul_cnode_ = mul_pat.GetOriginalNode(); | |||
| auto mul_prim = mul_cnode_->cast<CNodePtr>()->input(0); | |||
| auto addn_maketuple = admktup_pat.GetOriginalNode(); | |||
| AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x.GetNode(node)}, fg); | |||
| AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); | |||
| AnfNodePtr all_reduce = NewCNode({all_reduce_prim, add}, fg); | |||
| AnfNodePtr mul = NewCNode({mul_prim, all_reduce, y.GetNode(node)}, fg); | |||
| ProcessDependEdge(fg, addn_maketuple, all_reduce); | |||
| return mul; | |||
| }; | |||
| MATCH_REPLACE_LAMBDA(node, addn_pat, adjust_lambda); | |||
| return nullptr; | |||
| } | |||
| void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, | |||
| @@ -146,48 +146,6 @@ void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfN | |||
| } | |||
| } | |||
| void AdjustAllReduceMulAdd::Visit(const AnfNodePtr &node) { | |||
| if (level_ == 0) { | |||
| level_ = 1; | |||
| is_reduce_match_ = false; | |||
| // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y} | |||
| AnfVisitor::Match(prim::kPrimMul)(node); | |||
| level_ = 0; | |||
| if (is_reduce_match_) { | |||
| mul_ = node->cast<CNodePtr>()->input(0); | |||
| mul_cnode_ = node->cast<CNodePtr>(); | |||
| y_ = tmp_; | |||
| } else { | |||
| z_ = node; | |||
| } | |||
| } | |||
| if (level_ == 1) { | |||
| // {prim::kPrimAllReduce, X} | |||
| if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode->size() > 1) { | |||
| all_reduce_ = cnode->input(0); | |||
| x_ = cnode->input(1); | |||
| is_reduce_match_ = true; | |||
| all_reduce_fg_ = cnode->func_graph(); | |||
| } | |||
| } else { | |||
| tmp_ = node; | |||
| } | |||
| } | |||
| } | |||
| void AdjustAllReduceMulAdd::Reset() { | |||
| level_ = 0; | |||
| is_reduce_match_ = false; | |||
| x_ = nullptr; | |||
| y_ = nullptr; | |||
| z_ = nullptr; | |||
| tmp_ = nullptr; | |||
| all_reduce_fg_ = nullptr; | |||
| } | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -38,20 +38,14 @@ namespace irpass { | |||
| // {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> | |||
| // {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} | |||
| class AdjustAllReduceMulAdd : public AnfVisitor { | |||
| class AdjustAllReduceMulAdd : public OptimizerCaller { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; | |||
| void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node); | |||
| void Visit(const AnfNodePtr &node) override; | |||
| void Reset(); | |||
| private: | |||
| int level_{0}; | |||
| bool is_reduce_match_{false}; | |||
| AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; | |||
| AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr}; | |||
| FuncGraphPtr all_reduce_fg_{nullptr}; | |||
| AnfNodePtr mul_cnode_{nullptr}; | |||
| }; | |||
| class ArithmeticSimplify : public OptimizerCaller { | |||
| @@ -94,8 +94,8 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > { | |||
| ~PBinOperation() = default; | |||
| AnfNodePtr GetNode(const AnfNodePtr &node) const { | |||
| AnfNodePtr lhs = x_.GetNode(node->func_graph()); | |||
| AnfNodePtr rhs = y_.GetNode(node->func_graph()); | |||
| AnfNodePtr lhs = x_.GetNode(node); | |||
| AnfNodePtr rhs = y_.GetNode(node); | |||
| AnfNodePtrList list = {NewValueNode(prim_), lhs, rhs}; | |||
| return NewCNode(list, node->func_graph()); | |||
| } | |||
| @@ -113,25 +113,42 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > { | |||
| if (!x_.TryCapture(inputs[2]) || !y_.TryCapture(inputs[1])) { | |||
| return false; | |||
| } | |||
| captured_binop_node_ = node; | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| captured_binop_node_ = node; | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| /// Returns the original node captured by this Binary Operation Pattern. | |||
| /// Throws exception if a node was not captured before. | |||
| AnfNodePtr GetOriginalNode() const { | |||
| if (captured_binop_node_ == nullptr) { | |||
| MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it."; | |||
| } | |||
| return captured_binop_node_; | |||
| } | |||
| void Reset() const { | |||
| x_.Reset(); | |||
| y_.Reset(); | |||
| captured_binop_node_ = nullptr; | |||
| } | |||
| using Internal = const PBinOperation<T, T2> &; | |||
| private: | |||
| const PrimitivePtr prim_; | |||
| typename T::Internal x_; | |||
| typename T2::Internal y_; | |||
| bool is_commutative_{false}; | |||
| mutable AnfNodePtr captured_binop_node_{nullptr}; | |||
| }; | |||
| /// | |||
| @@ -265,10 +282,11 @@ class PCNode : public PBase<PCNode<TArgs...> > { | |||
| return *this; | |||
| } | |||
| using Internal = const PCNode<TArgs...> &; | |||
| void Reset() const { | |||
| tuple_utils::PTupleResetCapture reset; | |||
| tuple_utils::apply_func_tuple(&reset, args_); | |||
| has_min_extra_nodes_ = false; | |||
| extra_nodes_.clear(); | |||
| } | |||
| @@ -316,6 +334,9 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > { | |||
| AnfNodePtrList tokens(inputs.begin() + 1, inputs.end()); | |||
| tuple_utils::PTupleCapture capture_func(tokens); | |||
| tuple_utils::apply_func_tuple(&capture_func, args_); | |||
| if (capture_func.captured_) { | |||
| captured_prim_node_ = node; | |||
| } | |||
| return capture_func.captured_; | |||
| } | |||
| return false; | |||
| @@ -329,9 +350,11 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > { | |||
| tuple_utils::apply_func_tuple(&capture_func, args_); | |||
| // If it could capture the initial set of nodes specified in the Pattern | |||
| // and there are enough extra inputs to add | |||
| if (capture_func.captured_ && inputs.size() > pattern_arg_len + 1) { | |||
| extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end()); | |||
| return true; | |||
| if (capture_func.captured_) { | |||
| captured_prim_node_ = node; | |||
| if (inputs.size() > pattern_arg_len + 1) { | |||
| extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end()); | |||
| } | |||
| } | |||
| return capture_func.captured_; | |||
| } | |||
| @@ -349,19 +372,42 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > { | |||
| return *this; | |||
| } | |||
| /// Returns the FuncGraph of the original node captured by this Primitive Pattern. | |||
| /// Throws exception if a node was not captured before. | |||
| FuncGraphPtr GetFuncGraph() const { | |||
| if (captured_prim_node_ == nullptr) { | |||
| MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get its FuncGraph."; | |||
| } | |||
| return captured_prim_node_->func_graph(); | |||
| } | |||
| /// Returns the original node captured by this Primitive Pattern. | |||
| /// Throws exception if a node was not captured before. | |||
| AnfNodePtr GetOriginalNode() const { | |||
| if (captured_prim_node_ == nullptr) { | |||
| MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it."; | |||
| } | |||
| return captured_prim_node_; | |||
| } | |||
| void Reset() const { | |||
| tuple_utils::PTupleResetCapture reset; | |||
| tuple_utils::apply_func_tuple(&reset, args_); | |||
| has_min_extra_nodes_ = false; | |||
| extra_nodes_.clear(); | |||
| captured_prim_node_ = nullptr; | |||
| } | |||
| using Internal = const PPrimitive<TArgs...> &; | |||
| private: | |||
| const PrimitivePtr prim_; | |||
| std::tuple<typename TArgs::Internal...> args_; | |||
| mutable AnfNodePtrList extra_nodes_; | |||
| mutable bool has_min_extra_nodes_{false}; | |||
| mutable size_t min_extra_nodes_{0}; | |||
| mutable AnfNodePtr captured_prim_node_{nullptr}; | |||
| }; | |||
| /// | |||