diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
index fcdf7efbac..4011e0fe14 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
@@ -112,7 +112,6 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRule>());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc
index 58efbaf710..21bbf1a5d7 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc
@@ -20,28 +20,23 @@
 namespace mindspore {
 namespace opt {
-namespace {
-AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &new_node,
-                                        const AnfNodePtr &add3, const AnfNodePtr &add5, const AnfNodePtr &real_div0,
-                                        const AnfNodePtr &real_div1) {
+AnfNodePtr LambNextMVWithDecayRule::GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph,
+                                                                 const AnfNodePtr &new_node, const AnfNodePtr &add3,
+                                                                 const AnfNodePtr &add5, const EquivPtr &equiv) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(new_node);
   MS_EXCEPTION_IF_NULL(add3);
-  MS_EXCEPTION_IF_NULL(real_div0);
-  MS_EXCEPTION_IF_NULL(real_div1);
   MS_EXCEPTION_IF_NULL(add5);
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto add0 = GetAnfNodeByVar(equiv, add0_var_);
+  MS_EXCEPTION_IF_NULL(add0);
+  auto add1 = GetAnfNodeByVar(equiv, add1_var_);
+  MS_EXCEPTION_IF_NULL(add1);
+  // Set abstract of new node
   AbstractBasePtrList new_node_list;
   new_node_list.push_back(add3->abstract());
-  auto real_div0_cnode = real_div0->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(real_div0_cnode);
-  AnfNodePtr add0 = real_div0_cnode->input(1);
-  MS_EXCEPTION_IF_NULL(add0);
   new_node_list.push_back(add0->abstract());
-  auto real_div1_cnode = real_div1->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(real_div1_cnode);
-  AnfNodePtr add1 = real_div1_cnode->input(1);
-  MS_EXCEPTION_IF_NULL(add1);
   new_node_list.push_back(add1->abstract());
   new_node_list.push_back(add5->abstract());
   auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(new_node_list);
@@ -58,94 +53,8 @@ AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const An
   return new_node_outputs[3];
 }
 
-void GetSharedInputNodesByAdd5(const AnfNodePtr &node, AnfNodePtr *mul4, AnfNodePtr *real_div0, AnfNodePtr *real_div1,
-                               AnfNodePtr *constant_add2_y_input) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto add5_cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(add5_cnode);
-  if (add5_cnode->inputs().size() < kAddInputNum) {
-    MS_LOG(EXCEPTION) << "The input size of Add5 is less than " << kAddInputNum;
-  }
-  *mul4 = add5_cnode->input(2);
-
-  AnfNodePtr real_div4 = add5_cnode->input(1);
-  MS_EXCEPTION_IF_NULL(real_div4);
-  auto real_div4_cnode = real_div4->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(real_div4_cnode);
-  if (real_div4_cnode->inputs().size() < kRealDivInputNum) {
-    MS_LOG(EXCEPTION) << "The input size of RealDiv4 is less than " << kRealDivInputNum;
-  }
-  *real_div0 = real_div4_cnode->input(1);
-
-  AnfNodePtr add4 = real_div4_cnode->input(2);
-  MS_EXCEPTION_IF_NULL(add4);
-  auto add4_cnode = add4->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(add4_cnode);
-  if (add4_cnode->inputs().size() < kAddInputNum) {
-    MS_LOG(EXCEPTION) << "The input size of Add4 is less than " << kAddInputNum;
-  }
-  AnfNodePtr sqrt1 = add4_cnode->input(1);
-  MS_EXCEPTION_IF_NULL(sqrt1);
-  auto sqrt1_cnode = sqrt1->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(sqrt1_cnode);
-  if (sqrt1_cnode->inputs().size() < kSqrtInputNum) {
-    MS_LOG(EXCEPTION) << "The input size of Sqrt1 is less than " << kSqrtInputNum;
-  }
-  *real_div1 = sqrt1_cnode->input(1);
-  *constant_add2_y_input = add4_cnode->input(2);
-}
-
-bool MatchAdd3(const AnfNodePtr &add3, const AnfNodePtr &mul4, const AnfNodePtr &real_div0, const AnfNodePtr &real_div1,
-               const AnfNodePtr &constant_add2_y) {
-  if (add3 == nullptr || !add3->isa<CNode>()) {
-    return false;
-  }
-  auto add3_cnode = add3->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(add3_cnode);
-  if (AnfAlgo::GetCNodeName(add3_cnode) != prim::kPrimTensorAdd->name() ||
-      add3_cnode->inputs().size() != kAddInputNum) {
-    return false;
-  }
-  // Check the shared input nodes.
-  if (add3_cnode->input(2) != mul4) {
-    return false;
-  }
-  AnfNodePtr real_div2 = add3_cnode->input(1);
-  MS_EXCEPTION_IF_NULL(real_div2);
-  auto real_div2_cnode = real_div2->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(real_div2_cnode);
-  if (AnfAlgo::GetCNodeName(real_div2_cnode) != prim::kPrimMul->name() ||
-      real_div2_cnode->inputs().size() != kMulInputNum) {
-    return false;
-  }
-  if (real_div2_cnode->input(1) != real_div0) {
-    return false;
-  }
-  AnfNodePtr sqrt0 = real_div2_cnode->input(2);
-  MS_EXCEPTION_IF_NULL(sqrt0);
-  auto sqrt0_cnode = sqrt0->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(sqrt0_cnode);
-  if (AnfAlgo::GetCNodeName(sqrt0_cnode) != kRsqrtOpName || sqrt0_cnode->inputs().size() != kRsqrtInputNum) {
-    return false;
-  }
-  AnfNodePtr add2 = sqrt0_cnode->input(1);
-  MS_EXCEPTION_IF_NULL(add2);
-  auto add2_cnode = add2->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(add2_cnode);
-  if (AnfAlgo::GetCNodeName(add2_cnode) != prim::kPrimTensorAdd->name() ||
-      add2_cnode->inputs().size() != kAddInputNum) {
-    return false;
-  }
-  MS_EXCEPTION_IF_NULL(add2_cnode->input(2));
-  MS_EXCEPTION_IF_NULL(constant_add2_y);
-  return add2_cnode->input(1) == real_div1 && *(add2_cnode->input(2)) == *constant_add2_y;
-}
-}  // namespace
-
 AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph,
                                                                   const AnfNodePtr &add3, const AnfNodePtr &add5,
-                                                                  const AnfNodePtr &real_div0,
-                                                                  const AnfNodePtr &real_div1,
                                                                   const EquivPtr &equiv) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(add3);
@@ -167,7 +76,7 @@ AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGrap
   MS_EXCEPTION_IF_NULL(constant_add2_y_node);
   new_node_inputs.push_back(constant_add2_y_node);
   auto new_node = func_graph->NewCNode(new_node_inputs);
-  return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, real_div0, real_div1);
+  return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, equiv);
 }
 
 const BaseRef LambNextMVWithDecayRule::DefinePattern() const {
@@ -175,44 +84,82 @@
   MS_EXCEPTION_IF_NULL(prim_sqrt);
   const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
   MS_EXCEPTION_IF_NULL(prim_deal_div);
-  VectorRef mul4 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[4], input_vars_[6]});
-  VectorRef add0 =
-    VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}),
-               VectorRef({prim::kPrimMul,
-                          constant_mul_input_vars_[1], input_vars_[3]})});
-  VectorRef real_div0 = VectorRef({prim_deal_div, add0, input_vars_[5]});
-  VectorRef add1 =
-    VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]}),
-               VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]})});
-  VectorRef real_div1 = VectorRef({prim_deal_div, add1, input_vars_[2]});
-  VectorRef real_div4 = VectorRef(
-    {prim_deal_div, real_div0, VectorRef({prim::kPrimTensorAdd, VectorRef({prim_sqrt, real_div1}), constant_add2_y_})});
-  return VectorRef({prim::kPrimTensorAdd, real_div4, mul4});
+  VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]});
+  VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]});
+  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
+  VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
+  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
+  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
+  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
+  VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
+  VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
+  VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
+  VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4});
+  return add5;
+}
+
+const BaseRef LambNextMVWithDecayRule::DefineAnotherPattern() const {
+  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_rsqrt);
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VarPtr Ys = std::make_shared<SeqVar>();
+  VarPtr Zs = std::make_shared<SeqVar>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  MS_EXCEPTION_IF_NULL(Ys);
+  MS_EXCEPTION_IF_NULL(Zs);
+  // Two patterns share: real_div0, real_div1, mul4, constant_add2_y_
+  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
+  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
+  VectorRef mul4 = VectorRef({mul4_var_, Zs});
+
+  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
+  VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
+  VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
+  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
+  return add3;
+}
+
+bool LambNextMVWithDecayRule::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  VarPtr fg = std::make_shared<Var>("RootG");
+  auto empty_equiv = std::make_shared<Equiv>();
+  MS_EXCEPTION_IF_NULL(child_primitive_vars_);
+  EquivPtr another_equiv =
+    child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node,
+                                *child_primitive_vars_, empty_equiv);
+  if (another_equiv != nullptr && !another_equiv->empty()) {
+    return IsShareNodes(equiv, another_equiv);
+  }
+  return false;
+}
+
+bool LambNextMVWithDecayRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const {
+  return IsSameNode(equiv1, equiv2, mul4_var_) && IsSameNode(equiv1, equiv2, real_div0_var_) &&
+         IsSameNode(equiv1, equiv2, real_div1_var_) && IsSameNode(equiv1, equiv2, constant_add2_y_);
 }
 
 const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                   const EquivPtr &equiv) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);
-  // Get the shared input nodes in patterns of add5 and add3
-  AnfNodePtr mul4 = nullptr;
-  AnfNodePtr real_div0 = nullptr;
-  AnfNodePtr real_div1 = nullptr;
-  AnfNodePtr constant_add2_y_input = nullptr;
-  GetSharedInputNodesByAdd5(node, &mul4, &real_div0, &real_div1, &constant_add2_y_input);
-  // Get add3 and try to match the add3 pattern
+  AnfNodePtr mul4 = GetAnfNodeByVar(equiv, mul4_var_);
+  MS_EXCEPTION_IF_NULL(mul4);
+  // Get add3 and match the add3 pattern
   auto manager = func_graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
   if (manager->node_users().find(mul4) == manager->node_users().end()) {
     MS_LOG(EXCEPTION) << "The Mul4 should be used by at least another node input";
   }
-  AnfNodeIndexSet mul4_output_node_index_set = manager->node_users()[mul4];
-  auto iter = std::find_if(
-    mul4_output_node_index_set.begin(), mul4_output_node_index_set.end(),
-    [&node, &mul4, &real_div0, &real_div1, &constant_add2_y_input](const std::pair<AnfNodePtr, int> &node_index) {
-      return node_index.first != node && MatchAdd3(node_index.first, mul4, real_div0, real_div1, constant_add2_y_input);
-    });
-  if (iter != mul4_output_node_index_set.end()) {
-    return CreateLambNextMVWithDecayNode(func_graph, iter->first, node, real_div0, real_div1, equiv);
+  AnfNodeIndexSet mul4_outputs = manager->node_users()[mul4];
+  auto iter = std::find_if(mul4_outputs.begin(), mul4_outputs.end(),
+                           [&node, &equiv, this](const std::pair<AnfNodePtr, int> &node_index) {
+                             return node_index.first != node && MatchAnotherPattern(node_index.first, equiv);
+                           });
+  if (iter != mul4_outputs.end()) {
+    return CreateLambNextMVWithDecayNode(func_graph, iter->first, node, equiv);
   }
   return nullptr;
 }
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h
index 161ce4e956..afc48bb8c6 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h
@@ -18,6 +18,7 @@
 #include <vector>
 #include <memory>
+#include <string>
 #include "pre_activate/common/optimizer.h"
 #include "pre_activate/common/helper.h"
@@ -25,8 +26,13 @@ namespace mindspore {
 namespace opt {
 class LambNextMVWithDecayRule : public PatternProcessPass {
  public:
-  explicit LambNextMVWithDecayRule(bool multigraph = true)
-      : PatternProcessPass("lamb_next_mv_with_decay_rule", multigraph) {
+  explicit LambNextMVWithDecayRule(const std::string &name = "lamb_next_mv_with_decay_rule_cond4",
+                                   bool multigraph = true)
+      : PatternProcessPass(name, multigraph),
+        child_pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
+                                            std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
+                                            std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
+        child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) {
     for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) {
       input_vars_.push_back(std::make_shared<Var>());
     }
@@ -34,20 +40,39 @@ class LambNextMVWithDecayRule : public PatternProcessPass {
       constant_mul_input_vars_.push_back(std::make_shared<Var>());
     }
     constant_add2_y_ = std::make_shared<Var>();
+    mul4_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name()));
+    real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
+    real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
+    add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
+    add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
   }
   ~LambNextMVWithDecayRule() override = default;
   const BaseRef DefinePattern() const override;
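+  // Secondary pattern rooted at add3; a match is only accepted when it shares
+  // real_div0, real_div1, mul4 and constant_add2_y_ with DefinePattern() (see IsShareNodes).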
+  virtual const BaseRef DefineAnotherPattern() const;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 
- private:
-  AnfNodePtr CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, const AnfNodePtr &add3,
-                                           const AnfNodePtr &add5, const AnfNodePtr &real_div0,
-                                           const AnfNodePtr &real_div1, const EquivPtr &equiv) const;
+ protected:
+  bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const;
+  // Check whether the two patterns share the same nodes
+  bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const;
+  AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &new_node,
+                                          const AnfNodePtr &add3, const AnfNodePtr &add5, const EquivPtr &equiv) const;
+  AnfNodePtr CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, const AnfNodePtr &add3,
+                                           const AnfNodePtr &add5, const EquivPtr &equiv) const;
+  PatternEngine child_pattern_engine_;
+  PrimitiveVarMapPtr child_primitive_vars_;
   std::vector<VarPtr> input_vars_;
   std::vector<VarPtr> constant_mul_input_vars_;
+  // Nodes which the two patterns share
   VarPtr constant_add2_y_;
+  VarPtr mul4_var_;
+  VarPtr real_div0_var_;
+  VarPtr real_div1_var_;
+  // Part of the output nodes
+  VarPtr add0_var_;
+  VarPtr add1_var_;
 };
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc
index d41d3e3c4a..45bf640abe 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc
@@ -64,6 +64,8 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph,
   AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node);
   AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node);
   AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(false), new_node);
+  auto reshape_output_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0);
+  AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(Convert2Int(reshape_output_shape)), new_node);
   return new_node;
 }
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc
index 138fa33128..2f982e0413 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc
@@ -64,6 +64,8 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph,
   AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node);
   AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node);
   AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(true), new_node);
+  auto reshape_output_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0);
+  AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(Convert2Int(reshape_output_shape)), new_node);
   return new_node;
 }
diff --git a/mindspore/ccsrc/pre_activate/common/helper.cc b/mindspore/ccsrc/pre_activate/common/helper.cc
index 2f5cfe1423..decaaaca62 100644
--- a/mindspore/ccsrc/pre_activate/common/helper.cc
+++ b/mindspore/ccsrc/pre_activate/common/helper.cc
@@ -539,5 +539,169 @@ void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &i
     primitive->set_attr(kAttrInputNames, MakeValue(new_input_names));
   }
 }
+
+bool AnfEqual(const BaseRef &a, const BaseRef &b) {
+  if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
+    auto a_node = utils::cast<AnfNodePtr>(a);
+    auto b_node = utils::cast<AnfNodePtr>(b);
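+    // Primitive value nodes compare equal by primitive name; other value nodes compare by value.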
+    if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
+      auto a_value_node = a_node->cast<ValueNodePtr>();
+      auto a_value = a_value_node->value();
+      auto a_prim = a_value->cast<PrimitivePtr>();
+
+      auto b_value_node = b_node->cast<ValueNodePtr>();
+      auto b_value = b_value_node->value();
+      auto b_prim = b_value->cast<PrimitivePtr>();
+
+      return a_prim->name() == b_prim->name();
+    } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
+      auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
+      if (a_value_node_ptr == nullptr) {
+        MS_LOG(EXCEPTION) << "cast value node ptr fail";
+      }
+      auto a_value_ptr = a_value_node_ptr->value();
+      if (a_value_ptr == nullptr) {
+        MS_LOG(EXCEPTION) << "value ptr is nullptr";
+      }
+
+      auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
+      if (b_value_node_ptr == nullptr) {
+        MS_LOG(EXCEPTION) << "cast value node ptr fail";
+      }
+      auto b_value_ptr = b_value_node_ptr->value();
+      if (b_value_ptr == nullptr) {
+        MS_LOG(EXCEPTION) << "value ptr is nullptr";
+      }
+
+      return (*a_value_ptr) == (*b_value_ptr);
+    }
+    MS_LOG(DEBUG) << "check AnfNodePtr equal";
+  }
+  if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
+    MS_LOG(DEBUG) << "check GraphPtr equal";
+  }
+  return a == b;
+}
+
+bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
+  // To match CNode and Kernel's type
+  if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
+    return true;
+  }
+  return a.type() == b.type();
+}
+
+namespace {
+ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
+  if (utils::isa<int>(sexp)) {
+    return NewValueNode(utils::cast<int>(sexp));
+  }
+  if (utils::isa<float>(sexp)) {
+    return NewValueNode(utils::cast<float>(sexp));
+  }
+  if (utils::isa<bool>(sexp)) {
+    return NewValueNode(utils::cast<bool>(sexp));
+  }
+  if (utils::isa<ValuePtr>(sexp)) {
+    return NewValueNode(utils::cast<ValuePtr>(sexp));
+  }
+  return nullptr;
+}
+
+CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
+  if (utils::isa<FuncGraphPtr>(graph)) {
+    return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
+  }
+  if (utils::isa<VarPtr>(graph)) {
+    return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
+  }
+  return nullptr;
+}
+
+VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
+  if (utils::isa<VarPtr>(graph)) {
+    MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
+    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
+  }
+  if (utils::isa<FuncGraphPtr>(graph)) {
+    MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
+    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
+  }
+  MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
It's " + graph.ToString(); + return nullptr; +} + +AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, + bool multigraph) { + MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); + std::vector input_nodes; + const auto &tuple = utils::cast(sexp); + if (multigraph && utils::isa(graph)) { + for (auto &x : tuple) { + AnfNodePtr node = SexpToNode(x, std::make_shared("G"), primitive_vars, true); + input_nodes.push_back(node); + } + VarPtr var_ptr = utils::cast(graph); + return std::make_shared(input_nodes, var_ptr); + } + + for (auto &x : tuple) { + AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph); + input_nodes.push_back(node); + } + return CreateCNodeWithGraph(input_nodes, graph); +} +} // namespace + +AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) { + MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); + MS_EXCEPTION_IF_NULL(primitive_vars); + if (utils::isa(sexp)) { + return HandleSexpVector(sexp, graph, primitive_vars, multigraph); + } + if (utils::isa(sexp)) { + auto var_ptr = utils::cast(sexp); + MS_EXCEPTION_IF_NULL(var_ptr); + if (var_ptr->primitive()) { + (*primitive_vars)[var_ptr->primitive()] = var_ptr; + return NewValueNode(var_ptr->primitive()); + } + return CreateVarNodeWithSexp(sexp, graph); + } + if (utils::isa(sexp)) { + return utils::cast(sexp); + } + auto value_node = CreateValueNodeWithSexp(sexp); + if (value_node == nullptr) { + MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString(); + } + return value_node; +} + +bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) { + MS_EXCEPTION_IF_NULL(equiv1); + MS_EXCEPTION_IF_NULL(equiv2); + MS_EXCEPTION_IF_NULL(var_node); + auto equiv1_node = GetAnfNodeByVar(equiv1, var_node); + MS_EXCEPTION_IF_NULL(equiv1_node); + auto equiv2_node = GetAnfNodeByVar(equiv2, var_node); + MS_EXCEPTION_IF_NULL(equiv2_node); + return equiv1_node == equiv2_node; +} + +AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) { + MS_EXCEPTION_IF_NULL(equiv); + MS_EXCEPTION_IF_NULL(var_node); + auto iter = (*equiv).find(var_node); + if (iter == (*equiv).end()) { + MS_LOG(INFO) << "The equiv map doesn't contain the var_node after matched."; + return nullptr; + } + auto res = utils::cast(iter->second); + if (res == nullptr) { + MS_LOG(EXCEPTION) << "Cast fail! 
+  }
+  return res;
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/common/helper.h b/mindspore/ccsrc/pre_activate/common/helper.h
index 9a162d66f4..d315f6b5d9 100644
--- a/mindspore/ccsrc/pre_activate/common/helper.h
+++ b/mindspore/ccsrc/pre_activate/common/helper.h
@@ -23,6 +23,7 @@
 #include "ir/func_graph.h"
 #include "session/kernel_graph.h"
 #include "common/utils.h"
+#include "pre_activate/common/pattern_engine.h"
 
 namespace mindspore {
 namespace opt {
@@ -162,6 +163,19 @@ AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePt
 bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
 
 void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs);
+
+bool AnfEqual(const BaseRef &a, const BaseRef &b);
+
+bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);
+
+AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
+                      bool multigraph = false);
+
+// Check whether var_node in the two equivs is the same node
+bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node);
+
+// Get anf_node from equiv by var_node
+AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node);
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_
diff --git a/mindspore/ccsrc/pre_activate/common/optimizer.cc b/mindspore/ccsrc/pre_activate/common/optimizer.cc
index 0e74da3fe8..2711d87721 100644
--- a/mindspore/ccsrc/pre_activate/common/optimizer.cc
+++ b/mindspore/ccsrc/pre_activate/common/optimizer.cc
@@ -29,148 +29,6 @@
 
 namespace mindspore {
 namespace opt {
-namespace {
-AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
-                            bool multigraph);
-
-ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
-  if (utils::isa<int>(sexp)) {
-    return NewValueNode(utils::cast<int>(sexp));
-  }
-  if (utils::isa<float>(sexp)) {
-    return NewValueNode(utils::cast<float>(sexp));
-  }
-  if (utils::isa<bool>(sexp)) {
-    return NewValueNode(utils::cast<bool>(sexp));
-  }
-  if (utils::isa<ValuePtr>(sexp)) {
-    return NewValueNode(utils::cast<ValuePtr>(sexp));
-  }
-  return nullptr;
-}
-
-CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
-  if (utils::isa<FuncGraphPtr>(graph)) {
-    return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
-  }
-  if (utils::isa<VarPtr>(graph)) {
-    return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
-  }
-  return nullptr;
-}
-
-VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
-  if (utils::isa<VarPtr>(graph)) {
-    MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
-    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
-  }
-  if (utils::isa<FuncGraphPtr>(graph)) {
-    MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
-    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
-  }
-  MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
It's " + graph.ToString(); - return nullptr; -} - -AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, - bool multigraph = false) { - MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); - MS_EXCEPTION_IF_NULL(primitive_vars); - if (utils::isa(sexp)) { - return HandleSexpVector(sexp, graph, primitive_vars, multigraph); - } - if (utils::isa(sexp)) { - auto var_ptr = utils::cast(sexp); - MS_EXCEPTION_IF_NULL(var_ptr); - if (var_ptr->primitive()) { - (*primitive_vars)[var_ptr->primitive()] = var_ptr; - return NewValueNode(var_ptr->primitive()); - } - return CreateVarNodeWithSexp(sexp, graph); - } - if (utils::isa(sexp)) { - return utils::cast(sexp); - } - auto value_node = CreateValueNodeWithSexp(sexp); - if (value_node == nullptr) { - MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString(); - } - return value_node; -} - -AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, - bool multigraph) { - MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); - std::vector input_nodes; - const auto &tuple = utils::cast(sexp); - if (multigraph && utils::isa(graph)) { - for (auto &x : tuple) { - AnfNodePtr node = SexpToNode(x, std::make_shared("G"), primitive_vars, true); - input_nodes.push_back(node); - } - VarPtr var_ptr = utils::cast(graph); - return std::make_shared(input_nodes, var_ptr); - } - - for (auto &x : tuple) { - AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph); - input_nodes.push_back(node); - } - return CreateCNodeWithGraph(input_nodes, graph); -} -} // namespace - -static bool AnfEqual(const BaseRef &a, const BaseRef &b) { - if (utils::isa(a) && utils::isa(b)) { - auto a_node = utils::cast(a); - auto b_node = utils::cast(b); - if (IsValueNode(a_node) && IsValueNode(b_node)) { - auto a_value_node = a_node->cast(); - auto a_value = a_value_node->value(); - auto a_prim = a_value->cast(); - - auto b_value_node = b_node->cast(); - auto b_value = b_value_node->value(); - auto b_prim = b_value->cast(); - - return a_prim->name() == b_prim->name(); - } else if (a_node->isa() && b_node->isa()) { - auto a_value_node_ptr = a_node->cast(); - if (a_value_node_ptr == nullptr) { - MS_LOG(EXCEPTION) << "cast value node ptr fail"; - } - auto a_value_ptr = a_value_node_ptr->value(); - if (a_value_ptr == nullptr) { - MS_LOG(EXCEPTION) << "value ptr is nullptr"; - } - - auto b_value_node_ptr = b_node->cast(); - if (b_value_node_ptr == nullptr) { - MS_LOG(EXCEPTION) << "cast value node ptr fail"; - } - auto b_value_ptr = b_value_node_ptr->value(); - if (b_value_ptr == nullptr) { - MS_LOG(EXCEPTION) << "value ptr is nullptr"; - } - - return (*a_value_ptr) == (*b_value_ptr); - } - MS_LOG(DEBUG) << "check AnfNodePtr equal"; - } - if (utils::isa(a) && utils::isa(b)) { - MS_LOG(DEBUG) << "check GraphPtr equal"; - } - return a == b; -} - -static bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) { - // To matchCNode and Kernel's type - if (utils::isa(a) && utils::isa(b)) { - return true; - } - return a.type() == b.type(); -} - PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph) : NodePass(name), multigraph_(multigraph), diff --git a/mindspore/ccsrc/pre_activate/common/optimizer.h b/mindspore/ccsrc/pre_activate/common/optimizer.h index eade7f7789..cec23ae178 100644 --- a/mindspore/ccsrc/pre_activate/common/optimizer.h +++ b/mindspore/ccsrc/pre_activate/common/optimizer.h @@ -28,6 
 #include "pre_activate/common/pattern_engine.h"
 #include "utils/graph_utils.h"
 #include "common/utils.h"
+#include "pre_activate/common/helper.h"
 
 namespace mindspore {
 namespace opt {
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc
index c0adeafe7b..f114c77216 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc
@@ -17,6 +17,7 @@
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
 #include "pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h"
+#include "debug/anf_ir_dump.h"
 
 namespace mindspore {
 namespace opt {
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/batchnormgrad_to_bninfergrad.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/batchnormgrad_to_bninfergrad.py
index b71f638342..48cf28b325 100644
--- a/tests/ut/cpp/python_input/gtest_input/pre_activate/batchnormgrad_to_bninfergrad.py
+++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/batchnormgrad_to_bninfergrad.py
@@ -14,7 +14,6 @@
 # ============================================================================
 
 from mindspore.ops import Primitive
-from mindspore.ops import operations as P
 from mindspore.ops.operations import _grad_ops as G
 
 batch_norm_grad = G.BatchNormGrad(is_training=False)
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_with_decay_rule_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_with_decay_rule_test.py
index 2c8a0d03ed..d9d1ab5c39 100644
--- a/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_with_decay_rule_test.py
+++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_with_decay_rule_test.py
@@ -24,7 +24,6 @@ make_tuple = Primitive('make_tuple')
 tuple_getitem = Primitive('tuple_getitem')
 LambNextMVWithDecay = Primitive('LambNextMVWithDecay')
 
-
 class FnDict:
     def __init__(self):
         self.fnDict = {}
@@ -35,7 +34,6 @@ class FnDict:
     def __getitem__(self, name):
         return self.fnDict[name]
 
-
 def test_lamb_next_mv_with_decay_rule(tag):
     fns = FnDict()