| @@ -112,7 +112,6 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { | |||
| ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<AddnFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<DereluFusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<ConfusionMulGradFusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>()); | |||
| @@ -20,28 +20,23 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &new_node, | |||
| const AnfNodePtr &add3, const AnfNodePtr &add5, const AnfNodePtr &real_div0, | |||
| const AnfNodePtr &real_div1) { | |||
| AnfNodePtr LambNextMVWithDecayRule::GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, | |||
| const AnfNodePtr &new_node, const AnfNodePtr &add3, | |||
| const AnfNodePtr &add5, const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(new_node); | |||
| MS_EXCEPTION_IF_NULL(add3); | |||
| MS_EXCEPTION_IF_NULL(real_div0); | |||
| MS_EXCEPTION_IF_NULL(real_div1); | |||
| MS_EXCEPTION_IF_NULL(add5); | |||
| MS_EXCEPTION_IF_NULL(equiv); | |||
| auto add0 = GetAnfNodeByVar(equiv, add0_var_); | |||
| MS_EXCEPTION_IF_NULL(add0); | |||
| auto add1 = GetAnfNodeByVar(equiv, add1_var_); | |||
| MS_EXCEPTION_IF_NULL(add1); | |||
| // Set abstract of new node | |||
| AbstractBasePtrList new_node_list; | |||
| new_node_list.push_back(add3->abstract()); | |||
| auto real_div0_cnode = real_div0->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(real_div0_cnode); | |||
| AnfNodePtr add0 = real_div0_cnode->input(1); | |||
| MS_EXCEPTION_IF_NULL(add0); | |||
| new_node_list.push_back(add0->abstract()); | |||
| auto real_div1_cnode = real_div1->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(real_div1_cnode); | |||
| AnfNodePtr add1 = real_div1_cnode->input(1); | |||
| MS_EXCEPTION_IF_NULL(add1); | |||
| new_node_list.push_back(add1->abstract()); | |||
| new_node_list.push_back(add5->abstract()); | |||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(new_node_list); | |||
| @@ -58,94 +53,8 @@ AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const An | |||
| return new_node_outputs[3]; | |||
| } | |||
| void GetSharedInputNodesByAdd5(const AnfNodePtr &node, AnfNodePtr *mul4, AnfNodePtr *real_div0, AnfNodePtr *real_div1, | |||
| AnfNodePtr *constant_add2_y_input) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto add5_cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(add5_cnode); | |||
| if (add5_cnode->inputs().size() < kAddInputNum) { | |||
| MS_LOG(EXCEPTION) << "The input size of Add5 is less than " << kAddInputNum; | |||
| } | |||
| *mul4 = add5_cnode->input(2); | |||
| AnfNodePtr real_div4 = add5_cnode->input(1); | |||
| MS_EXCEPTION_IF_NULL(real_div4); | |||
| auto real_div4_cnode = real_div4->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(real_div4_cnode); | |||
| if (real_div4_cnode->inputs().size() < kRealDivInputNum) { | |||
| MS_LOG(EXCEPTION) << "The input size of RealDiv4 is less than " << kRealDivInputNum; | |||
| } | |||
| *real_div0 = real_div4_cnode->input(1); | |||
| AnfNodePtr add4 = real_div4_cnode->input(2); | |||
| MS_EXCEPTION_IF_NULL(add4); | |||
| auto add4_cnode = add4->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(add4_cnode); | |||
| if (add4_cnode->inputs().size() < kAddInputNum) { | |||
| MS_LOG(EXCEPTION) << "The input size of Add4 is less than " << kAddInputNum; | |||
| } | |||
| AnfNodePtr sqrt1 = add4_cnode->input(1); | |||
| MS_EXCEPTION_IF_NULL(sqrt1); | |||
| auto sqrt1_cnode = sqrt1->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(sqrt1_cnode); | |||
| if (sqrt1_cnode->inputs().size() < kSqrtInputNum) { | |||
| MS_LOG(EXCEPTION) << "The input size of Sqrt1 is less than " << kSqrtInputNum; | |||
| } | |||
| *real_div1 = sqrt1_cnode->input(1); | |||
| *constant_add2_y_input = add4_cnode->input(2); | |||
| } | |||
| bool MatchAdd3(const AnfNodePtr &add3, const AnfNodePtr &mul4, const AnfNodePtr &real_div0, const AnfNodePtr &real_div1, | |||
| const AnfNodePtr &constant_add2_y) { | |||
| if (add3 == nullptr || !add3->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| auto add3_cnode = add3->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(add3_cnode); | |||
| if (AnfAlgo::GetCNodeName(add3_cnode) != prim::kPrimTensorAdd->name() || | |||
| add3_cnode->inputs().size() != kAddInputNum) { | |||
| return false; | |||
| } | |||
| // Check the shared input nodes. | |||
| if (add3_cnode->input(2) != mul4) { | |||
| return false; | |||
| } | |||
| AnfNodePtr real_div2 = add3_cnode->input(1); | |||
| MS_EXCEPTION_IF_NULL(real_div2); | |||
| auto real_div2_cnode = real_div2->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(real_div2_cnode); | |||
| if (AnfAlgo::GetCNodeName(real_div2_cnode) != prim::kPrimMul->name() || | |||
| real_div2_cnode->inputs().size() != kMulInputNum) { | |||
| return false; | |||
| } | |||
| if (real_div2_cnode->input(1) != real_div0) { | |||
| return false; | |||
| } | |||
| AnfNodePtr sqrt0 = real_div2_cnode->input(2); | |||
| MS_EXCEPTION_IF_NULL(sqrt0); | |||
| auto sqrt0_cnode = sqrt0->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(sqrt0_cnode); | |||
| if (AnfAlgo::GetCNodeName(sqrt0_cnode) != kRsqrtOpName || sqrt0_cnode->inputs().size() != kRsqrtInputNum) { | |||
| return false; | |||
| } | |||
| AnfNodePtr add2 = sqrt0_cnode->input(1); | |||
| MS_EXCEPTION_IF_NULL(add2); | |||
| auto add2_cnode = add2->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(add2_cnode); | |||
| if (AnfAlgo::GetCNodeName(add2_cnode) != prim::kPrimTensorAdd->name() || | |||
| add2_cnode->inputs().size() != kAddInputNum) { | |||
| return false; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(add2_cnode->input(2)); | |||
| MS_EXCEPTION_IF_NULL(constant_add2_y); | |||
| return add2_cnode->input(1) == real_div1 && *(add2_cnode->input(2)) == *constant_add2_y; | |||
| } | |||
| } // namespace | |||
| AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, | |||
| const AnfNodePtr &add3, const AnfNodePtr &add5, | |||
| const AnfNodePtr &real_div0, | |||
| const AnfNodePtr &real_div1, | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(add3); | |||
| @@ -167,7 +76,7 @@ AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGrap | |||
| MS_EXCEPTION_IF_NULL(constant_add2_y_node); | |||
| new_node_inputs.push_back(constant_add2_y_node); | |||
| auto new_node = func_graph->NewCNode(new_node_inputs); | |||
| return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, real_div0, real_div1); | |||
| return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, equiv); | |||
| } | |||
| const BaseRef LambNextMVWithDecayRule::DefinePattern() const { | |||
| @@ -175,44 +84,82 @@ const BaseRef LambNextMVWithDecayRule::DefinePattern() const { | |||
| MS_EXCEPTION_IF_NULL(prim_sqrt); | |||
| const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| MS_EXCEPTION_IF_NULL(prim_deal_div); | |||
| VectorRef mul4 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[4], input_vars_[6]}); | |||
| VectorRef add0 = | |||
| VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}), | |||
| VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]})}); | |||
| VectorRef real_div0 = VectorRef({prim_deal_div, add0, input_vars_[5]}); | |||
| VectorRef add1 = | |||
| VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]}), | |||
| VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]})}); | |||
| VectorRef real_div1 = VectorRef({prim_deal_div, add1, input_vars_[2]}); | |||
| VectorRef real_div4 = VectorRef( | |||
| {prim_deal_div, real_div0, VectorRef({prim::kPrimTensorAdd, VectorRef({prim_sqrt, real_div1}), constant_add2_y_})}); | |||
| return VectorRef({prim::kPrimTensorAdd, real_div4, mul4}); | |||
| VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]}); | |||
| VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]}); | |||
| VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); | |||
| VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); | |||
| VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | |||
| VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); | |||
| VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}); | |||
| VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]}); | |||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||
| VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); | |||
| VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); | |||
| VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); | |||
| VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4}); | |||
| return add5; | |||
| } | |||
| const BaseRef LambNextMVWithDecayRule::DefineAnotherPattern() const { | |||
| const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName); | |||
| MS_EXCEPTION_IF_NULL(prim_rsqrt); | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| VarPtr Ys = std::make_shared<SeqVar>(); | |||
| VarPtr Zs = std::make_shared<SeqVar>(); | |||
| MS_EXCEPTION_IF_NULL(Xs); | |||
| MS_EXCEPTION_IF_NULL(Ys); | |||
| MS_EXCEPTION_IF_NULL(Zs); | |||
| // Two patterns share: real_div0, real_div1, mul4, constant_add2_y_ | |||
| VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); | |||
| VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | |||
| VectorRef mul4 = VectorRef({mul4_var_, Zs}); | |||
| VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_}); | |||
| VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); | |||
| VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0}); | |||
| VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4}); | |||
| return add3; | |||
| } | |||
| bool LambNextMVWithDecayRule::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(equiv); | |||
| VarPtr fg = std::make_shared<Var>("RootG"); | |||
| auto empty_equiv = std::make_shared<Equiv>(); | |||
| MS_EXCEPTION_IF_NULL(child_primitive_vars_); | |||
| EquivPtr another_equiv = | |||
| child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node, | |||
| *child_primitive_vars_, empty_equiv); | |||
| if (another_equiv != nullptr && !another_equiv->empty()) { | |||
| return IsShareNodes(equiv, another_equiv); | |||
| } | |||
| return false; | |||
| } | |||
| bool LambNextMVWithDecayRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const { | |||
| return IsSameNode(equiv1, equiv2, mul4_var_) && IsSameNode(equiv1, equiv2, real_div0_var_) && | |||
| IsSameNode(equiv1, equiv2, real_div1_var_) && IsSameNode(equiv1, equiv2, constant_add2_y_); | |||
| } | |||
| const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| // Get the shared input nodes in patterns of add5 and add3 | |||
| AnfNodePtr mul4 = nullptr; | |||
| AnfNodePtr real_div0 = nullptr; | |||
| AnfNodePtr real_div1 = nullptr; | |||
| AnfNodePtr constant_add2_y_input = nullptr; | |||
| GetSharedInputNodesByAdd5(node, &mul4, &real_div0, &real_div1, &constant_add2_y_input); | |||
| // Get add3 and try to match the add3 pattern | |||
| AnfNodePtr mul4 = GetAnfNodeByVar(equiv, mul4_var_); | |||
| MS_EXCEPTION_IF_NULL(mul4); | |||
| // Get add3 and match the add3 pattern | |||
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| if (manager->node_users().find(mul4) == manager->node_users().end()) { | |||
| MS_LOG(EXCEPTION) << "The Mul4 should be used by at least another node input"; | |||
| } | |||
| AnfNodeIndexSet mul4_output_node_index_set = manager->node_users()[mul4]; | |||
| auto iter = std::find_if( | |||
| mul4_output_node_index_set.begin(), mul4_output_node_index_set.end(), | |||
| [&node, &mul4, &real_div0, &real_div1, &constant_add2_y_input](const std::pair<AnfNodePtr, int> &node_index) { | |||
| return node_index.first != node && MatchAdd3(node_index.first, mul4, real_div0, real_div1, constant_add2_y_input); | |||
| }); | |||
| if (iter != mul4_output_node_index_set.end()) { | |||
| return CreateLambNextMVWithDecayNode(func_graph, iter->first, node, real_div0, real_div1, equiv); | |||
| AnfNodeIndexSet mul4_outputs = manager->node_users()[mul4]; | |||
| auto iter = std::find_if(mul4_outputs.begin(), mul4_outputs.end(), | |||
| [&node, &equiv, this](const std::pair<AnfNodePtr, int> &node_index) { | |||
| return node_index.first != node && MatchAnotherPattern(node_index.first, equiv); | |||
| }); | |||
| if (iter != mul4_outputs.end()) { | |||
| return CreateLambNextMVWithDecayNode(func_graph, iter->first, node, equiv); | |||
| } | |||
| return nullptr; | |||
| } | |||
| @@ -18,6 +18,7 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "pre_activate/common/optimizer.h" | |||
| #include "pre_activate/common/helper.h" | |||
| @@ -25,8 +26,13 @@ namespace mindspore { | |||
| namespace opt { | |||
| class LambNextMVWithDecayRule : public PatternProcessPass { | |||
| public: | |||
| explicit LambNextMVWithDecayRule(bool multigraph = true) | |||
| : PatternProcessPass("lamb_next_mv_with_decay_rule", multigraph) { | |||
| explicit LambNextMVWithDecayRule(const std::string &name = "lamb_next_mv_with_decay_rule_cond4", | |||
| bool multigraph = true) | |||
| : PatternProcessPass(name, multigraph), | |||
| child_pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(), | |||
| std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual), | |||
| std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))), | |||
| child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) { | |||
| for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) { | |||
| input_vars_.push_back(std::make_shared<Var>()); | |||
| } | |||
| @@ -34,20 +40,39 @@ class LambNextMVWithDecayRule : public PatternProcessPass { | |||
| constant_mul_input_vars_.push_back(std::make_shared<Var>()); | |||
| } | |||
| constant_add2_y_ = std::make_shared<Var>(); | |||
| mul4_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name())); | |||
| real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName)); | |||
| real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName)); | |||
| add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | |||
| add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | |||
| } | |||
| ~LambNextMVWithDecayRule() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| virtual const BaseRef DefineAnotherPattern() const; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| AnfNodePtr CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, const AnfNodePtr &add3, | |||
| const AnfNodePtr &add5, const AnfNodePtr &real_div0, | |||
| const AnfNodePtr &real_div1, const EquivPtr &equiv) const; | |||
| protected: | |||
| bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const; | |||
| // check two patterns whether share the same nodes or not | |||
| bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const; | |||
| AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &new_node, | |||
| const AnfNodePtr &add3, const AnfNodePtr &add5, const EquivPtr &equiv) const; | |||
| AnfNodePtr CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, const AnfNodePtr &add3, | |||
| const AnfNodePtr &add5, const EquivPtr &equiv) const; | |||
| PatternEngine child_pattern_engine_; | |||
| PrimitiveVarMapPtr child_primitive_vars_; | |||
| std::vector<VarPtr> input_vars_; | |||
| std::vector<VarPtr> constant_mul_input_vars_; | |||
| // nodes which two patterns share | |||
| VarPtr constant_add2_y_; | |||
| VarPtr mul4_var_; | |||
| VarPtr real_div0_var_; | |||
| VarPtr real_div1_var_; | |||
| // part of output nodes | |||
| VarPtr add0_var_; | |||
| VarPtr add1_var_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -64,6 +64,8 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph, | |||
| AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node); | |||
| AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node); | |||
| AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(false), new_node); | |||
| auto reshape_output_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0); | |||
| AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(Convert2Int(reshape_output_shape)), new_node); | |||
| return new_node; | |||
| } | |||
| @@ -64,6 +64,8 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph, | |||
| AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node); | |||
| AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node); | |||
| AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(true), new_node); | |||
| auto reshape_output_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0); | |||
| AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(Convert2Int(reshape_output_shape)), new_node); | |||
| return new_node; | |||
| } | |||
| @@ -539,5 +539,169 @@ void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &i | |||
| primitive->set_attr(kAttrInputNames, MakeValue(new_input_names)); | |||
| } | |||
| } | |||
| bool AnfEqual(const BaseRef &a, const BaseRef &b) { | |||
| if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) { | |||
| auto a_node = utils::cast<AnfNodePtr>(a); | |||
| auto b_node = utils::cast<AnfNodePtr>(b); | |||
| if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) { | |||
| auto a_value_node = a_node->cast<ValueNodePtr>(); | |||
| auto a_value = a_value_node->value(); | |||
| auto a_prim = a_value->cast<PrimitivePtr>(); | |||
| auto b_value_node = b_node->cast<ValueNodePtr>(); | |||
| auto b_value = b_value_node->value(); | |||
| auto b_prim = b_value->cast<PrimitivePtr>(); | |||
| return a_prim->name() == b_prim->name(); | |||
| } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) { | |||
| auto a_value_node_ptr = a_node->cast<ValueNodePtr>(); | |||
| if (a_value_node_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "cast value node ptr fail"; | |||
| } | |||
| auto a_value_ptr = a_value_node_ptr->value(); | |||
| if (a_value_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "value ptr is nullptr"; | |||
| } | |||
| auto b_value_node_ptr = b_node->cast<ValueNodePtr>(); | |||
| if (b_value_node_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "cast value node ptr fail"; | |||
| } | |||
| auto b_value_ptr = b_value_node_ptr->value(); | |||
| if (b_value_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "value ptr is nullptr"; | |||
| } | |||
| return (*a_value_ptr) == (*b_value_ptr); | |||
| } | |||
| MS_LOG(DEBUG) << "check AnfNodePtr equal"; | |||
| } | |||
| if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) { | |||
| MS_LOG(DEBUG) << "check GraphPtr equal"; | |||
| } | |||
| return a == b; | |||
| } | |||
| bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) { | |||
| // To matchCNode and Kernel's type | |||
| if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) { | |||
| return true; | |||
| } | |||
| return a.type() == b.type(); | |||
| } | |||
| namespace { | |||
| ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) { | |||
| if (utils::isa<int>(sexp)) { | |||
| return NewValueNode(utils::cast<int>(sexp)); | |||
| } | |||
| if (utils::isa<float>(sexp)) { | |||
| return NewValueNode(utils::cast<float>(sexp)); | |||
| } | |||
| if (utils::isa<bool>(sexp)) { | |||
| return NewValueNode(utils::cast<bool>(sexp)); | |||
| } | |||
| if (utils::isa<ValuePtr>(sexp)) { | |||
| return NewValueNode(utils::cast<ValuePtr>(sexp)); | |||
| } | |||
| return nullptr; | |||
| } | |||
| CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) { | |||
| if (utils::isa<FuncGraphPtr>(graph)) { | |||
| return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph)); | |||
| } | |||
| if (utils::isa<VarPtr>(graph)) { | |||
| return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph)); | |||
| } | |||
| return nullptr; | |||
| } | |||
| VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) { | |||
| if (utils::isa<VarPtr>(graph)) { | |||
| MS_LOG(DEBUG) << "make VarPtr " + graph.ToString(); | |||
| return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr); | |||
| } | |||
| if (utils::isa<FuncGraphPtr>(graph)) { | |||
| MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString(); | |||
| return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph)); | |||
| } | |||
| MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString(); | |||
| return nullptr; | |||
| } | |||
| AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, | |||
| bool multigraph) { | |||
| MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); | |||
| std::vector<AnfNodePtr> input_nodes; | |||
| const auto &tuple = utils::cast<VectorRef>(sexp); | |||
| if (multigraph && utils::isa<VarPtr>(graph)) { | |||
| for (auto &x : tuple) { | |||
| AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true); | |||
| input_nodes.push_back(node); | |||
| } | |||
| VarPtr var_ptr = utils::cast<VarPtr>(graph); | |||
| return std::make_shared<CNode>(input_nodes, var_ptr); | |||
| } | |||
| for (auto &x : tuple) { | |||
| AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph); | |||
| input_nodes.push_back(node); | |||
| } | |||
| return CreateCNodeWithGraph(input_nodes, graph); | |||
| } | |||
| } // namespace | |||
| AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) { | |||
| MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); | |||
| MS_EXCEPTION_IF_NULL(primitive_vars); | |||
| if (utils::isa<VectorRef>(sexp)) { | |||
| return HandleSexpVector(sexp, graph, primitive_vars, multigraph); | |||
| } | |||
| if (utils::isa<VarPtr>(sexp)) { | |||
| auto var_ptr = utils::cast<VarPtr>(sexp); | |||
| MS_EXCEPTION_IF_NULL(var_ptr); | |||
| if (var_ptr->primitive()) { | |||
| (*primitive_vars)[var_ptr->primitive()] = var_ptr; | |||
| return NewValueNode(var_ptr->primitive()); | |||
| } | |||
| return CreateVarNodeWithSexp(sexp, graph); | |||
| } | |||
| if (utils::isa<AnfNodePtr>(sexp)) { | |||
| return utils::cast<AnfNodePtr>(sexp); | |||
| } | |||
| auto value_node = CreateValueNodeWithSexp(sexp); | |||
| if (value_node == nullptr) { | |||
| MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString(); | |||
| } | |||
| return value_node; | |||
| } | |||
| bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) { | |||
| MS_EXCEPTION_IF_NULL(equiv1); | |||
| MS_EXCEPTION_IF_NULL(equiv2); | |||
| MS_EXCEPTION_IF_NULL(var_node); | |||
| auto equiv1_node = GetAnfNodeByVar(equiv1, var_node); | |||
| MS_EXCEPTION_IF_NULL(equiv1_node); | |||
| auto equiv2_node = GetAnfNodeByVar(equiv2, var_node); | |||
| MS_EXCEPTION_IF_NULL(equiv2_node); | |||
| return equiv1_node == equiv2_node; | |||
| } | |||
| AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) { | |||
| MS_EXCEPTION_IF_NULL(equiv); | |||
| MS_EXCEPTION_IF_NULL(var_node); | |||
| auto iter = (*equiv).find(var_node); | |||
| if (iter == (*equiv).end()) { | |||
| MS_LOG(INFO) << "The equiv map doesn't contain the var_node after matched."; | |||
| return nullptr; | |||
| } | |||
| auto res = utils::cast<AnfNodePtr>(iter->second); | |||
| if (res == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Cast fail! Maybe var is not a anf node"; | |||
| } | |||
| return res; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -23,6 +23,7 @@ | |||
| #include "ir/func_graph.h" | |||
| #include "session/kernel_graph.h" | |||
| #include "common/utils.h" | |||
| #include "pre_activate/common/pattern_engine.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -162,6 +163,19 @@ AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePt | |||
| bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); | |||
| void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs); | |||
| bool AnfEqual(const BaseRef &a, const BaseRef &b); | |||
| bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b); | |||
| AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, | |||
| bool multigraph = false); | |||
| // Check var_node in two equivs is the same node | |||
| bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node); | |||
| // Get anf_node from equiv by var_node | |||
| AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node); | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ | |||
| @@ -29,148 +29,6 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, | |||
| bool multigraph); | |||
| ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) { | |||
| if (utils::isa<int>(sexp)) { | |||
| return NewValueNode(utils::cast<int>(sexp)); | |||
| } | |||
| if (utils::isa<float>(sexp)) { | |||
| return NewValueNode(utils::cast<float>(sexp)); | |||
| } | |||
| if (utils::isa<bool>(sexp)) { | |||
| return NewValueNode(utils::cast<bool>(sexp)); | |||
| } | |||
| if (utils::isa<ValuePtr>(sexp)) { | |||
| return NewValueNode(utils::cast<ValuePtr>(sexp)); | |||
| } | |||
| return nullptr; | |||
| } | |||
| CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) { | |||
| if (utils::isa<FuncGraphPtr>(graph)) { | |||
| return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph)); | |||
| } | |||
| if (utils::isa<VarPtr>(graph)) { | |||
| return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph)); | |||
| } | |||
| return nullptr; | |||
| } | |||
| VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) { | |||
| if (utils::isa<VarPtr>(graph)) { | |||
| MS_LOG(DEBUG) << "make VarPtr " + graph.ToString(); | |||
| return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr); | |||
| } | |||
| if (utils::isa<FuncGraphPtr>(graph)) { | |||
| MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString(); | |||
| return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph)); | |||
| } | |||
| MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString(); | |||
| return nullptr; | |||
| } | |||
| AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, | |||
| bool multigraph = false) { | |||
| MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); | |||
| MS_EXCEPTION_IF_NULL(primitive_vars); | |||
| if (utils::isa<VectorRef>(sexp)) { | |||
| return HandleSexpVector(sexp, graph, primitive_vars, multigraph); | |||
| } | |||
| if (utils::isa<VarPtr>(sexp)) { | |||
| auto var_ptr = utils::cast<VarPtr>(sexp); | |||
| MS_EXCEPTION_IF_NULL(var_ptr); | |||
| if (var_ptr->primitive()) { | |||
| (*primitive_vars)[var_ptr->primitive()] = var_ptr; | |||
| return NewValueNode(var_ptr->primitive()); | |||
| } | |||
| return CreateVarNodeWithSexp(sexp, graph); | |||
| } | |||
| if (utils::isa<AnfNodePtr>(sexp)) { | |||
| return utils::cast<AnfNodePtr>(sexp); | |||
| } | |||
| auto value_node = CreateValueNodeWithSexp(sexp); | |||
| if (value_node == nullptr) { | |||
| MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString(); | |||
| } | |||
| return value_node; | |||
| } | |||
| AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, | |||
| bool multigraph) { | |||
| MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); | |||
| std::vector<AnfNodePtr> input_nodes; | |||
| const auto &tuple = utils::cast<VectorRef>(sexp); | |||
| if (multigraph && utils::isa<VarPtr>(graph)) { | |||
| for (auto &x : tuple) { | |||
| AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true); | |||
| input_nodes.push_back(node); | |||
| } | |||
| VarPtr var_ptr = utils::cast<VarPtr>(graph); | |||
| return std::make_shared<CNode>(input_nodes, var_ptr); | |||
| } | |||
| for (auto &x : tuple) { | |||
| AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph); | |||
| input_nodes.push_back(node); | |||
| } | |||
| return CreateCNodeWithGraph(input_nodes, graph); | |||
| } | |||
| } // namespace | |||
| static bool AnfEqual(const BaseRef &a, const BaseRef &b) { | |||
| if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) { | |||
| auto a_node = utils::cast<AnfNodePtr>(a); | |||
| auto b_node = utils::cast<AnfNodePtr>(b); | |||
| if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) { | |||
| auto a_value_node = a_node->cast<ValueNodePtr>(); | |||
| auto a_value = a_value_node->value(); | |||
| auto a_prim = a_value->cast<PrimitivePtr>(); | |||
| auto b_value_node = b_node->cast<ValueNodePtr>(); | |||
| auto b_value = b_value_node->value(); | |||
| auto b_prim = b_value->cast<PrimitivePtr>(); | |||
| return a_prim->name() == b_prim->name(); | |||
| } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) { | |||
| auto a_value_node_ptr = a_node->cast<ValueNodePtr>(); | |||
| if (a_value_node_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "cast value node ptr fail"; | |||
| } | |||
| auto a_value_ptr = a_value_node_ptr->value(); | |||
| if (a_value_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "value ptr is nullptr"; | |||
| } | |||
| auto b_value_node_ptr = b_node->cast<ValueNodePtr>(); | |||
| if (b_value_node_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "cast value node ptr fail"; | |||
| } | |||
| auto b_value_ptr = b_value_node_ptr->value(); | |||
| if (b_value_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "value ptr is nullptr"; | |||
| } | |||
| return (*a_value_ptr) == (*b_value_ptr); | |||
| } | |||
| MS_LOG(DEBUG) << "check AnfNodePtr equal"; | |||
| } | |||
| if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) { | |||
| MS_LOG(DEBUG) << "check GraphPtr equal"; | |||
| } | |||
| return a == b; | |||
| } | |||
| static bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) { | |||
| // To matchCNode and Kernel's type | |||
| if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) { | |||
| return true; | |||
| } | |||
| return a.type() == b.type(); | |||
| } | |||
| PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph) | |||
| : NodePass(name), | |||
| multigraph_(multigraph), | |||
| @@ -28,6 +28,7 @@ | |||
| #include "pre_activate/common/pattern_engine.h" | |||
| #include "utils/graph_utils.h" | |||
| #include "common/utils.h" | |||
| #include "pre_activate/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -17,6 +17,7 @@ | |||
| #include "common/backend_common_test.h" | |||
| #include "common/py_func_graph_fetcher.h" | |||
| #include "pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -14,7 +14,6 @@ | |||
| # ============================================================================ | |||
| from mindspore.ops import Primitive | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| batch_norm_grad = G.BatchNormGrad(is_training=False) | |||
| @@ -24,7 +24,6 @@ make_tuple = Primitive('make_tuple') | |||
| tuple_getitem = Primitive('tuple_getitem') | |||
| LambNextMVWithDecay = Primitive('LambNextMVWithDecay') | |||
| class FnDict: | |||
| def __init__(self): | |||
| self.fnDict = {} | |||
| @@ -35,7 +34,6 @@ class FnDict: | |||
| def __getitem__(self, name): | |||
| return self.fnDict[name] | |||
| def test_lamb_next_mv_with_decay_rule(tag): | |||
| fns = FnDict() | |||