| @@ -112,7 +112,6 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { | |||||
| ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>()); | ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<AddnFission>()); | ir_fusion_pm->AddPass(std::make_shared<AddnFission>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<DereluFusion>()); | ir_fusion_pm->AddPass(std::make_shared<DereluFusion>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<ConfusionMulGradFusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>()); | ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>()); | ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>()); | ||||
| @@ -20,28 +20,23 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | |||||
| AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &new_node, | |||||
| const AnfNodePtr &add3, const AnfNodePtr &add5, const AnfNodePtr &real_div0, | |||||
| const AnfNodePtr &real_div1) { | |||||
| AnfNodePtr LambNextMVWithDecayRule::GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, | |||||
| const AnfNodePtr &new_node, const AnfNodePtr &add3, | |||||
| const AnfNodePtr &add5, const EquivPtr &equiv) const { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(new_node); | MS_EXCEPTION_IF_NULL(new_node); | ||||
| MS_EXCEPTION_IF_NULL(add3); | MS_EXCEPTION_IF_NULL(add3); | ||||
| MS_EXCEPTION_IF_NULL(real_div0); | |||||
| MS_EXCEPTION_IF_NULL(real_div1); | |||||
| MS_EXCEPTION_IF_NULL(add5); | MS_EXCEPTION_IF_NULL(add5); | ||||
| MS_EXCEPTION_IF_NULL(equiv); | |||||
| auto add0 = GetAnfNodeByVar(equiv, add0_var_); | |||||
| MS_EXCEPTION_IF_NULL(add0); | |||||
| auto add1 = GetAnfNodeByVar(equiv, add1_var_); | |||||
| MS_EXCEPTION_IF_NULL(add1); | |||||
| // Set abstract of new node | // Set abstract of new node | ||||
| AbstractBasePtrList new_node_list; | AbstractBasePtrList new_node_list; | ||||
| new_node_list.push_back(add3->abstract()); | new_node_list.push_back(add3->abstract()); | ||||
| auto real_div0_cnode = real_div0->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(real_div0_cnode); | |||||
| AnfNodePtr add0 = real_div0_cnode->input(1); | |||||
| MS_EXCEPTION_IF_NULL(add0); | |||||
| new_node_list.push_back(add0->abstract()); | new_node_list.push_back(add0->abstract()); | ||||
| auto real_div1_cnode = real_div1->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(real_div1_cnode); | |||||
| AnfNodePtr add1 = real_div1_cnode->input(1); | |||||
| MS_EXCEPTION_IF_NULL(add1); | |||||
| new_node_list.push_back(add1->abstract()); | new_node_list.push_back(add1->abstract()); | ||||
| new_node_list.push_back(add5->abstract()); | new_node_list.push_back(add5->abstract()); | ||||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(new_node_list); | auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(new_node_list); | ||||
| @@ -58,94 +53,8 @@ AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const An | |||||
| return new_node_outputs[3]; | return new_node_outputs[3]; | ||||
| } | } | ||||
| void GetSharedInputNodesByAdd5(const AnfNodePtr &node, AnfNodePtr *mul4, AnfNodePtr *real_div0, AnfNodePtr *real_div1, | |||||
| AnfNodePtr *constant_add2_y_input) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto add5_cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(add5_cnode); | |||||
| if (add5_cnode->inputs().size() < kAddInputNum) { | |||||
| MS_LOG(EXCEPTION) << "The input size of Add5 is less than " << kAddInputNum; | |||||
| } | |||||
| *mul4 = add5_cnode->input(2); | |||||
| AnfNodePtr real_div4 = add5_cnode->input(1); | |||||
| MS_EXCEPTION_IF_NULL(real_div4); | |||||
| auto real_div4_cnode = real_div4->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(real_div4_cnode); | |||||
| if (real_div4_cnode->inputs().size() < kRealDivInputNum) { | |||||
| MS_LOG(EXCEPTION) << "The input size of RealDiv4 is less than " << kRealDivInputNum; | |||||
| } | |||||
| *real_div0 = real_div4_cnode->input(1); | |||||
| AnfNodePtr add4 = real_div4_cnode->input(2); | |||||
| MS_EXCEPTION_IF_NULL(add4); | |||||
| auto add4_cnode = add4->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(add4_cnode); | |||||
| if (add4_cnode->inputs().size() < kAddInputNum) { | |||||
| MS_LOG(EXCEPTION) << "The input size of Add4 is less than " << kAddInputNum; | |||||
| } | |||||
| AnfNodePtr sqrt1 = add4_cnode->input(1); | |||||
| MS_EXCEPTION_IF_NULL(sqrt1); | |||||
| auto sqrt1_cnode = sqrt1->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(sqrt1_cnode); | |||||
| if (sqrt1_cnode->inputs().size() < kSqrtInputNum) { | |||||
| MS_LOG(EXCEPTION) << "The input size of Sqrt1 is less than " << kSqrtInputNum; | |||||
| } | |||||
| *real_div1 = sqrt1_cnode->input(1); | |||||
| *constant_add2_y_input = add4_cnode->input(2); | |||||
| } | |||||
| bool MatchAdd3(const AnfNodePtr &add3, const AnfNodePtr &mul4, const AnfNodePtr &real_div0, const AnfNodePtr &real_div1, | |||||
| const AnfNodePtr &constant_add2_y) { | |||||
| if (add3 == nullptr || !add3->isa<CNode>()) { | |||||
| return false; | |||||
| } | |||||
| auto add3_cnode = add3->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(add3_cnode); | |||||
| if (AnfAlgo::GetCNodeName(add3_cnode) != prim::kPrimTensorAdd->name() || | |||||
| add3_cnode->inputs().size() != kAddInputNum) { | |||||
| return false; | |||||
| } | |||||
| // Check the shared input nodes. | |||||
| if (add3_cnode->input(2) != mul4) { | |||||
| return false; | |||||
| } | |||||
| AnfNodePtr real_div2 = add3_cnode->input(1); | |||||
| MS_EXCEPTION_IF_NULL(real_div2); | |||||
| auto real_div2_cnode = real_div2->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(real_div2_cnode); | |||||
| if (AnfAlgo::GetCNodeName(real_div2_cnode) != prim::kPrimMul->name() || | |||||
| real_div2_cnode->inputs().size() != kMulInputNum) { | |||||
| return false; | |||||
| } | |||||
| if (real_div2_cnode->input(1) != real_div0) { | |||||
| return false; | |||||
| } | |||||
| AnfNodePtr sqrt0 = real_div2_cnode->input(2); | |||||
| MS_EXCEPTION_IF_NULL(sqrt0); | |||||
| auto sqrt0_cnode = sqrt0->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(sqrt0_cnode); | |||||
| if (AnfAlgo::GetCNodeName(sqrt0_cnode) != kRsqrtOpName || sqrt0_cnode->inputs().size() != kRsqrtInputNum) { | |||||
| return false; | |||||
| } | |||||
| AnfNodePtr add2 = sqrt0_cnode->input(1); | |||||
| MS_EXCEPTION_IF_NULL(add2); | |||||
| auto add2_cnode = add2->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(add2_cnode); | |||||
| if (AnfAlgo::GetCNodeName(add2_cnode) != prim::kPrimTensorAdd->name() || | |||||
| add2_cnode->inputs().size() != kAddInputNum) { | |||||
| return false; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(add2_cnode->input(2)); | |||||
| MS_EXCEPTION_IF_NULL(constant_add2_y); | |||||
| return add2_cnode->input(1) == real_div1 && *(add2_cnode->input(2)) == *constant_add2_y; | |||||
| } | |||||
| } // namespace | |||||
| AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, | AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, | ||||
| const AnfNodePtr &add3, const AnfNodePtr &add5, | const AnfNodePtr &add3, const AnfNodePtr &add5, | ||||
| const AnfNodePtr &real_div0, | |||||
| const AnfNodePtr &real_div1, | |||||
| const EquivPtr &equiv) const { | const EquivPtr &equiv) const { | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(add3); | MS_EXCEPTION_IF_NULL(add3); | ||||
| @@ -167,7 +76,7 @@ AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGrap | |||||
| MS_EXCEPTION_IF_NULL(constant_add2_y_node); | MS_EXCEPTION_IF_NULL(constant_add2_y_node); | ||||
| new_node_inputs.push_back(constant_add2_y_node); | new_node_inputs.push_back(constant_add2_y_node); | ||||
| auto new_node = func_graph->NewCNode(new_node_inputs); | auto new_node = func_graph->NewCNode(new_node_inputs); | ||||
| return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, real_div0, real_div1); | |||||
| return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, equiv); | |||||
| } | } | ||||
| const BaseRef LambNextMVWithDecayRule::DefinePattern() const { | const BaseRef LambNextMVWithDecayRule::DefinePattern() const { | ||||
| @@ -175,44 +84,82 @@ const BaseRef LambNextMVWithDecayRule::DefinePattern() const { | |||||
| MS_EXCEPTION_IF_NULL(prim_sqrt); | MS_EXCEPTION_IF_NULL(prim_sqrt); | ||||
| const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName); | const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName); | ||||
| MS_EXCEPTION_IF_NULL(prim_deal_div); | MS_EXCEPTION_IF_NULL(prim_deal_div); | ||||
| VectorRef mul4 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[4], input_vars_[6]}); | |||||
| VectorRef add0 = | |||||
| VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}), | |||||
| VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]})}); | |||||
| VectorRef real_div0 = VectorRef({prim_deal_div, add0, input_vars_[5]}); | |||||
| VectorRef add1 = | |||||
| VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]}), | |||||
| VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]})}); | |||||
| VectorRef real_div1 = VectorRef({prim_deal_div, add1, input_vars_[2]}); | |||||
| VectorRef real_div4 = VectorRef( | |||||
| {prim_deal_div, real_div0, VectorRef({prim::kPrimTensorAdd, VectorRef({prim_sqrt, real_div1}), constant_add2_y_})}); | |||||
| return VectorRef({prim::kPrimTensorAdd, real_div4, mul4}); | |||||
| VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]}); | |||||
| VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]}); | |||||
| VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); | |||||
| VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); | |||||
| VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | |||||
| VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); | |||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}); | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]}); | |||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||||
| VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); | |||||
| VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); | |||||
| VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); | |||||
| VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4}); | |||||
| return add5; | |||||
| } | |||||
| const BaseRef LambNextMVWithDecayRule::DefineAnotherPattern() const { | |||||
| const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName); | |||||
| MS_EXCEPTION_IF_NULL(prim_rsqrt); | |||||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||||
| VarPtr Ys = std::make_shared<SeqVar>(); | |||||
| VarPtr Zs = std::make_shared<SeqVar>(); | |||||
| MS_EXCEPTION_IF_NULL(Xs); | |||||
| MS_EXCEPTION_IF_NULL(Ys); | |||||
| MS_EXCEPTION_IF_NULL(Zs); | |||||
| // Two patterns share: real_div0, real_div1, mul4, constant_add2_y_ | |||||
| VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); | |||||
| VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | |||||
| VectorRef mul4 = VectorRef({mul4_var_, Zs}); | |||||
| VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_}); | |||||
| VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); | |||||
| VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0}); | |||||
| VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4}); | |||||
| return add3; | |||||
| } | |||||
| bool LambNextMVWithDecayRule::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| MS_EXCEPTION_IF_NULL(equiv); | |||||
| VarPtr fg = std::make_shared<Var>("RootG"); | |||||
| auto empty_equiv = std::make_shared<Equiv>(); | |||||
| MS_EXCEPTION_IF_NULL(child_primitive_vars_); | |||||
| EquivPtr another_equiv = | |||||
| child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node, | |||||
| *child_primitive_vars_, empty_equiv); | |||||
| if (another_equiv != nullptr && !another_equiv->empty()) { | |||||
| return IsShareNodes(equiv, another_equiv); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool LambNextMVWithDecayRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const { | |||||
| return IsSameNode(equiv1, equiv2, mul4_var_) && IsSameNode(equiv1, equiv2, real_div0_var_) && | |||||
| IsSameNode(equiv1, equiv2, real_div1_var_) && IsSameNode(equiv1, equiv2, constant_add2_y_); | |||||
| } | } | ||||
| const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | ||||
| const EquivPtr &equiv) const { | const EquivPtr &equiv) const { | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| // Get the shared input nodes in patterns of add5 and add3 | |||||
| AnfNodePtr mul4 = nullptr; | |||||
| AnfNodePtr real_div0 = nullptr; | |||||
| AnfNodePtr real_div1 = nullptr; | |||||
| AnfNodePtr constant_add2_y_input = nullptr; | |||||
| GetSharedInputNodesByAdd5(node, &mul4, &real_div0, &real_div1, &constant_add2_y_input); | |||||
| // Get add3 and try to match the add3 pattern | |||||
| AnfNodePtr mul4 = GetAnfNodeByVar(equiv, mul4_var_); | |||||
| MS_EXCEPTION_IF_NULL(mul4); | |||||
| // Get add3 and match the add3 pattern | |||||
| auto manager = func_graph->manager(); | auto manager = func_graph->manager(); | ||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| if (manager->node_users().find(mul4) == manager->node_users().end()) { | if (manager->node_users().find(mul4) == manager->node_users().end()) { | ||||
| MS_LOG(EXCEPTION) << "The Mul4 should be used by at least another node input"; | MS_LOG(EXCEPTION) << "The Mul4 should be used by at least another node input"; | ||||
| } | } | ||||
| AnfNodeIndexSet mul4_output_node_index_set = manager->node_users()[mul4]; | |||||
| auto iter = std::find_if( | |||||
| mul4_output_node_index_set.begin(), mul4_output_node_index_set.end(), | |||||
| [&node, &mul4, &real_div0, &real_div1, &constant_add2_y_input](const std::pair<AnfNodePtr, int> &node_index) { | |||||
| return node_index.first != node && MatchAdd3(node_index.first, mul4, real_div0, real_div1, constant_add2_y_input); | |||||
| }); | |||||
| if (iter != mul4_output_node_index_set.end()) { | |||||
| return CreateLambNextMVWithDecayNode(func_graph, iter->first, node, real_div0, real_div1, equiv); | |||||
| AnfNodeIndexSet mul4_outputs = manager->node_users()[mul4]; | |||||
| auto iter = std::find_if(mul4_outputs.begin(), mul4_outputs.end(), | |||||
| [&node, &equiv, this](const std::pair<AnfNodePtr, int> &node_index) { | |||||
| return node_index.first != node && MatchAnotherPattern(node_index.first, equiv); | |||||
| }); | |||||
| if (iter != mul4_outputs.end()) { | |||||
| return CreateLambNextMVWithDecayNode(func_graph, iter->first, node, equiv); | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | |||||
| #include "pre_activate/common/optimizer.h" | #include "pre_activate/common/optimizer.h" | ||||
| #include "pre_activate/common/helper.h" | #include "pre_activate/common/helper.h" | ||||
| @@ -25,8 +26,13 @@ namespace mindspore { | |||||
| namespace opt { | namespace opt { | ||||
| class LambNextMVWithDecayRule : public PatternProcessPass { | class LambNextMVWithDecayRule : public PatternProcessPass { | ||||
| public: | public: | ||||
| explicit LambNextMVWithDecayRule(bool multigraph = true) | |||||
| : PatternProcessPass("lamb_next_mv_with_decay_rule", multigraph) { | |||||
| explicit LambNextMVWithDecayRule(const std::string &name = "lamb_next_mv_with_decay_rule_cond4", | |||||
| bool multigraph = true) | |||||
| : PatternProcessPass(name, multigraph), | |||||
| child_pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(), | |||||
| std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual), | |||||
| std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))), | |||||
| child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) { | |||||
| for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) { | for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) { | ||||
| input_vars_.push_back(std::make_shared<Var>()); | input_vars_.push_back(std::make_shared<Var>()); | ||||
| } | } | ||||
| @@ -34,20 +40,39 @@ class LambNextMVWithDecayRule : public PatternProcessPass { | |||||
| constant_mul_input_vars_.push_back(std::make_shared<Var>()); | constant_mul_input_vars_.push_back(std::make_shared<Var>()); | ||||
| } | } | ||||
| constant_add2_y_ = std::make_shared<Var>(); | constant_add2_y_ = std::make_shared<Var>(); | ||||
| mul4_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name())); | |||||
| real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName)); | |||||
| real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName)); | |||||
| add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | |||||
| add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | |||||
| } | } | ||||
| ~LambNextMVWithDecayRule() override = default; | ~LambNextMVWithDecayRule() override = default; | ||||
| const BaseRef DefinePattern() const override; | const BaseRef DefinePattern() const override; | ||||
| virtual const BaseRef DefineAnotherPattern() const; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | ||||
| private: | |||||
| AnfNodePtr CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, const AnfNodePtr &add3, | |||||
| const AnfNodePtr &add5, const AnfNodePtr &real_div0, | |||||
| const AnfNodePtr &real_div1, const EquivPtr &equiv) const; | |||||
| protected: | |||||
| bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const; | |||||
| // check two patterns whether share the same nodes or not | |||||
| bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const; | |||||
| AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &new_node, | |||||
| const AnfNodePtr &add3, const AnfNodePtr &add5, const EquivPtr &equiv) const; | |||||
| AnfNodePtr CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, const AnfNodePtr &add3, | |||||
| const AnfNodePtr &add5, const EquivPtr &equiv) const; | |||||
| PatternEngine child_pattern_engine_; | |||||
| PrimitiveVarMapPtr child_primitive_vars_; | |||||
| std::vector<VarPtr> input_vars_; | std::vector<VarPtr> input_vars_; | ||||
| std::vector<VarPtr> constant_mul_input_vars_; | std::vector<VarPtr> constant_mul_input_vars_; | ||||
| // nodes which two patterns share | |||||
| VarPtr constant_add2_y_; | VarPtr constant_add2_y_; | ||||
| VarPtr mul4_var_; | |||||
| VarPtr real_div0_var_; | |||||
| VarPtr real_div1_var_; | |||||
| // part of output nodes | |||||
| VarPtr add0_var_; | |||||
| VarPtr add1_var_; | |||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -64,6 +64,8 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph, | |||||
| AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node); | AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node); | ||||
| AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node); | AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node); | ||||
| AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(false), new_node); | AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(false), new_node); | ||||
| auto reshape_output_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0); | |||||
| AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(Convert2Int(reshape_output_shape)), new_node); | |||||
| return new_node; | return new_node; | ||||
| } | } | ||||
| @@ -64,6 +64,8 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph, | |||||
| AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node); | AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node); | ||||
| AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node); | AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node); | ||||
| AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(true), new_node); | AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(true), new_node); | ||||
| auto reshape_output_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0); | |||||
| AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(Convert2Int(reshape_output_shape)), new_node); | |||||
| return new_node; | return new_node; | ||||
| } | } | ||||
| @@ -539,5 +539,169 @@ void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &i | |||||
| primitive->set_attr(kAttrInputNames, MakeValue(new_input_names)); | primitive->set_attr(kAttrInputNames, MakeValue(new_input_names)); | ||||
| } | } | ||||
| } | } | ||||
| bool AnfEqual(const BaseRef &a, const BaseRef &b) { | |||||
| if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) { | |||||
| auto a_node = utils::cast<AnfNodePtr>(a); | |||||
| auto b_node = utils::cast<AnfNodePtr>(b); | |||||
| if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) { | |||||
| auto a_value_node = a_node->cast<ValueNodePtr>(); | |||||
| auto a_value = a_value_node->value(); | |||||
| auto a_prim = a_value->cast<PrimitivePtr>(); | |||||
| auto b_value_node = b_node->cast<ValueNodePtr>(); | |||||
| auto b_value = b_value_node->value(); | |||||
| auto b_prim = b_value->cast<PrimitivePtr>(); | |||||
| return a_prim->name() == b_prim->name(); | |||||
| } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) { | |||||
| auto a_value_node_ptr = a_node->cast<ValueNodePtr>(); | |||||
| if (a_value_node_ptr == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "cast value node ptr fail"; | |||||
| } | |||||
| auto a_value_ptr = a_value_node_ptr->value(); | |||||
| if (a_value_ptr == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "value ptr is nullptr"; | |||||
| } | |||||
| auto b_value_node_ptr = b_node->cast<ValueNodePtr>(); | |||||
| if (b_value_node_ptr == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "cast value node ptr fail"; | |||||
| } | |||||
| auto b_value_ptr = b_value_node_ptr->value(); | |||||
| if (b_value_ptr == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "value ptr is nullptr"; | |||||
| } | |||||
| return (*a_value_ptr) == (*b_value_ptr); | |||||
| } | |||||
| MS_LOG(DEBUG) << "check AnfNodePtr equal"; | |||||
| } | |||||
| if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) { | |||||
| MS_LOG(DEBUG) << "check GraphPtr equal"; | |||||
| } | |||||
| return a == b; | |||||
| } | |||||
| bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) { | |||||
| // To matchCNode and Kernel's type | |||||
| if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) { | |||||
| return true; | |||||
| } | |||||
| return a.type() == b.type(); | |||||
| } | |||||
| namespace { | |||||
| ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) { | |||||
| if (utils::isa<int>(sexp)) { | |||||
| return NewValueNode(utils::cast<int>(sexp)); | |||||
| } | |||||
| if (utils::isa<float>(sexp)) { | |||||
| return NewValueNode(utils::cast<float>(sexp)); | |||||
| } | |||||
| if (utils::isa<bool>(sexp)) { | |||||
| return NewValueNode(utils::cast<bool>(sexp)); | |||||
| } | |||||
| if (utils::isa<ValuePtr>(sexp)) { | |||||
| return NewValueNode(utils::cast<ValuePtr>(sexp)); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) { | |||||
| if (utils::isa<FuncGraphPtr>(graph)) { | |||||
| return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph)); | |||||
| } | |||||
| if (utils::isa<VarPtr>(graph)) { | |||||
| return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph)); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) { | |||||
| if (utils::isa<VarPtr>(graph)) { | |||||
| MS_LOG(DEBUG) << "make VarPtr " + graph.ToString(); | |||||
| return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr); | |||||
| } | |||||
| if (utils::isa<FuncGraphPtr>(graph)) { | |||||
| MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString(); | |||||
| return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph)); | |||||
| } | |||||
| MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString(); | |||||
| return nullptr; | |||||
| } | |||||
| AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, | |||||
| bool multigraph) { | |||||
| MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); | |||||
| std::vector<AnfNodePtr> input_nodes; | |||||
| const auto &tuple = utils::cast<VectorRef>(sexp); | |||||
| if (multigraph && utils::isa<VarPtr>(graph)) { | |||||
| for (auto &x : tuple) { | |||||
| AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true); | |||||
| input_nodes.push_back(node); | |||||
| } | |||||
| VarPtr var_ptr = utils::cast<VarPtr>(graph); | |||||
| return std::make_shared<CNode>(input_nodes, var_ptr); | |||||
| } | |||||
| for (auto &x : tuple) { | |||||
| AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph); | |||||
| input_nodes.push_back(node); | |||||
| } | |||||
| return CreateCNodeWithGraph(input_nodes, graph); | |||||
| } | |||||
| } // namespace | |||||
| AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) { | |||||
| MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); | |||||
| MS_EXCEPTION_IF_NULL(primitive_vars); | |||||
| if (utils::isa<VectorRef>(sexp)) { | |||||
| return HandleSexpVector(sexp, graph, primitive_vars, multigraph); | |||||
| } | |||||
| if (utils::isa<VarPtr>(sexp)) { | |||||
| auto var_ptr = utils::cast<VarPtr>(sexp); | |||||
| MS_EXCEPTION_IF_NULL(var_ptr); | |||||
| if (var_ptr->primitive()) { | |||||
| (*primitive_vars)[var_ptr->primitive()] = var_ptr; | |||||
| return NewValueNode(var_ptr->primitive()); | |||||
| } | |||||
| return CreateVarNodeWithSexp(sexp, graph); | |||||
| } | |||||
| if (utils::isa<AnfNodePtr>(sexp)) { | |||||
| return utils::cast<AnfNodePtr>(sexp); | |||||
| } | |||||
| auto value_node = CreateValueNodeWithSexp(sexp); | |||||
| if (value_node == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString(); | |||||
| } | |||||
| return value_node; | |||||
| } | |||||
| bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) { | |||||
| MS_EXCEPTION_IF_NULL(equiv1); | |||||
| MS_EXCEPTION_IF_NULL(equiv2); | |||||
| MS_EXCEPTION_IF_NULL(var_node); | |||||
| auto equiv1_node = GetAnfNodeByVar(equiv1, var_node); | |||||
| MS_EXCEPTION_IF_NULL(equiv1_node); | |||||
| auto equiv2_node = GetAnfNodeByVar(equiv2, var_node); | |||||
| MS_EXCEPTION_IF_NULL(equiv2_node); | |||||
| return equiv1_node == equiv2_node; | |||||
| } | |||||
| AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) { | |||||
| MS_EXCEPTION_IF_NULL(equiv); | |||||
| MS_EXCEPTION_IF_NULL(var_node); | |||||
| auto iter = (*equiv).find(var_node); | |||||
| if (iter == (*equiv).end()) { | |||||
| MS_LOG(INFO) << "The equiv map doesn't contain the var_node after matched."; | |||||
| return nullptr; | |||||
| } | |||||
| auto res = utils::cast<AnfNodePtr>(iter->second); | |||||
| if (res == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Cast fail! Maybe var is not a anf node"; | |||||
| } | |||||
| return res; | |||||
| } | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "session/kernel_graph.h" | #include "session/kernel_graph.h" | ||||
| #include "common/utils.h" | #include "common/utils.h" | ||||
| #include "pre_activate/common/pattern_engine.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -162,6 +163,19 @@ AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePt | |||||
| bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); | bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); | ||||
| void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs); | void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs); | ||||
| bool AnfEqual(const BaseRef &a, const BaseRef &b); | |||||
| bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b); | |||||
| AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, | |||||
| bool multigraph = false); | |||||
| // Check var_node in two equivs is the same node | |||||
| bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node); | |||||
| // Get anf_node from equiv by var_node | |||||
| AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node); | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ | #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ | ||||
| @@ -29,148 +29,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | |||||
| AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, | |||||
| bool multigraph); | |||||
| ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) { | |||||
| if (utils::isa<int>(sexp)) { | |||||
| return NewValueNode(utils::cast<int>(sexp)); | |||||
| } | |||||
| if (utils::isa<float>(sexp)) { | |||||
| return NewValueNode(utils::cast<float>(sexp)); | |||||
| } | |||||
| if (utils::isa<bool>(sexp)) { | |||||
| return NewValueNode(utils::cast<bool>(sexp)); | |||||
| } | |||||
| if (utils::isa<ValuePtr>(sexp)) { | |||||
| return NewValueNode(utils::cast<ValuePtr>(sexp)); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) { | |||||
| if (utils::isa<FuncGraphPtr>(graph)) { | |||||
| return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph)); | |||||
| } | |||||
| if (utils::isa<VarPtr>(graph)) { | |||||
| return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph)); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) { | |||||
| if (utils::isa<VarPtr>(graph)) { | |||||
| MS_LOG(DEBUG) << "make VarPtr " + graph.ToString(); | |||||
| return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr); | |||||
| } | |||||
| if (utils::isa<FuncGraphPtr>(graph)) { | |||||
| MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString(); | |||||
| return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph)); | |||||
| } | |||||
| MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString(); | |||||
| return nullptr; | |||||
| } | |||||
| AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, | |||||
| bool multigraph = false) { | |||||
| MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); | |||||
| MS_EXCEPTION_IF_NULL(primitive_vars); | |||||
| if (utils::isa<VectorRef>(sexp)) { | |||||
| return HandleSexpVector(sexp, graph, primitive_vars, multigraph); | |||||
| } | |||||
| if (utils::isa<VarPtr>(sexp)) { | |||||
| auto var_ptr = utils::cast<VarPtr>(sexp); | |||||
| MS_EXCEPTION_IF_NULL(var_ptr); | |||||
| if (var_ptr->primitive()) { | |||||
| (*primitive_vars)[var_ptr->primitive()] = var_ptr; | |||||
| return NewValueNode(var_ptr->primitive()); | |||||
| } | |||||
| return CreateVarNodeWithSexp(sexp, graph); | |||||
| } | |||||
| if (utils::isa<AnfNodePtr>(sexp)) { | |||||
| return utils::cast<AnfNodePtr>(sexp); | |||||
| } | |||||
| auto value_node = CreateValueNodeWithSexp(sexp); | |||||
| if (value_node == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString(); | |||||
| } | |||||
| return value_node; | |||||
| } | |||||
| AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, | |||||
| bool multigraph) { | |||||
| MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); | |||||
| std::vector<AnfNodePtr> input_nodes; | |||||
| const auto &tuple = utils::cast<VectorRef>(sexp); | |||||
| if (multigraph && utils::isa<VarPtr>(graph)) { | |||||
| for (auto &x : tuple) { | |||||
| AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true); | |||||
| input_nodes.push_back(node); | |||||
| } | |||||
| VarPtr var_ptr = utils::cast<VarPtr>(graph); | |||||
| return std::make_shared<CNode>(input_nodes, var_ptr); | |||||
| } | |||||
| for (auto &x : tuple) { | |||||
| AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph); | |||||
| input_nodes.push_back(node); | |||||
| } | |||||
| return CreateCNodeWithGraph(input_nodes, graph); | |||||
| } | |||||
| } // namespace | |||||
| static bool AnfEqual(const BaseRef &a, const BaseRef &b) { | |||||
| if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) { | |||||
| auto a_node = utils::cast<AnfNodePtr>(a); | |||||
| auto b_node = utils::cast<AnfNodePtr>(b); | |||||
| if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) { | |||||
| auto a_value_node = a_node->cast<ValueNodePtr>(); | |||||
| auto a_value = a_value_node->value(); | |||||
| auto a_prim = a_value->cast<PrimitivePtr>(); | |||||
| auto b_value_node = b_node->cast<ValueNodePtr>(); | |||||
| auto b_value = b_value_node->value(); | |||||
| auto b_prim = b_value->cast<PrimitivePtr>(); | |||||
| return a_prim->name() == b_prim->name(); | |||||
| } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) { | |||||
| auto a_value_node_ptr = a_node->cast<ValueNodePtr>(); | |||||
| if (a_value_node_ptr == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "cast value node ptr fail"; | |||||
| } | |||||
| auto a_value_ptr = a_value_node_ptr->value(); | |||||
| if (a_value_ptr == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "value ptr is nullptr"; | |||||
| } | |||||
| auto b_value_node_ptr = b_node->cast<ValueNodePtr>(); | |||||
| if (b_value_node_ptr == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "cast value node ptr fail"; | |||||
| } | |||||
| auto b_value_ptr = b_value_node_ptr->value(); | |||||
| if (b_value_ptr == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "value ptr is nullptr"; | |||||
| } | |||||
| return (*a_value_ptr) == (*b_value_ptr); | |||||
| } | |||||
| MS_LOG(DEBUG) << "check AnfNodePtr equal"; | |||||
| } | |||||
| if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) { | |||||
| MS_LOG(DEBUG) << "check GraphPtr equal"; | |||||
| } | |||||
| return a == b; | |||||
| } | |||||
| static bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) { | |||||
| // To matchCNode and Kernel's type | |||||
| if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) { | |||||
| return true; | |||||
| } | |||||
| return a.type() == b.type(); | |||||
| } | |||||
| PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph) | PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph) | ||||
| : NodePass(name), | : NodePass(name), | ||||
| multigraph_(multigraph), | multigraph_(multigraph), | ||||
| @@ -28,6 +28,7 @@ | |||||
| #include "pre_activate/common/pattern_engine.h" | #include "pre_activate/common/pattern_engine.h" | ||||
| #include "utils/graph_utils.h" | #include "utils/graph_utils.h" | ||||
| #include "common/utils.h" | #include "common/utils.h" | ||||
| #include "pre_activate/common/helper.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include "common/backend_common_test.h" | #include "common/backend_common_test.h" | ||||
| #include "common/py_func_graph_fetcher.h" | #include "common/py_func_graph_fetcher.h" | ||||
| #include "pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h" | #include "pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h" | ||||
| #include "debug/anf_ir_dump.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -14,7 +14,6 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| from mindspore.ops import Primitive | from mindspore.ops import Primitive | ||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops.operations import _grad_ops as G | from mindspore.ops.operations import _grad_ops as G | ||||
| batch_norm_grad = G.BatchNormGrad(is_training=False) | batch_norm_grad = G.BatchNormGrad(is_training=False) | ||||
| @@ -24,7 +24,6 @@ make_tuple = Primitive('make_tuple') | |||||
| tuple_getitem = Primitive('tuple_getitem') | tuple_getitem = Primitive('tuple_getitem') | ||||
| LambNextMVWithDecay = Primitive('LambNextMVWithDecay') | LambNextMVWithDecay = Primitive('LambNextMVWithDecay') | ||||
| class FnDict: | class FnDict: | ||||
| def __init__(self): | def __init__(self): | ||||
| self.fnDict = {} | self.fnDict = {} | ||||
| @@ -35,7 +34,6 @@ class FnDict: | |||||
| def __getitem__(self, name): | def __getitem__(self, name): | ||||
| return self.fnDict[name] | return self.fnDict[name] | ||||
| def test_lamb_next_mv_with_decay_rule(tag): | def test_lamb_next_mv_with_decay_rule(tag): | ||||
| fns = FnDict() | fns = FnDict() | ||||