Merge pull request !1328 from huanghui/LambNextMvWithDecayRuleConds-fusion-pass
@@ -98,6 +98,9 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
   ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>());
   ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRule>());
+  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond1>());
+  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond2>());
+  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond3>());
   ir_fusion_pm->AddPass(std::make_shared<LambNextMVRule>());
   ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>());
   ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>());
@@ -163,5 +163,128 @@ const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph
   }
   return nullptr;
 }
+
+const BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const {
+  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_rsqrt);
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VarPtr Ys = std::make_shared<SeqVar>();
+  VarPtr Zs = std::make_shared<SeqVar>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  MS_EXCEPTION_IF_NULL(Ys);
+  MS_EXCEPTION_IF_NULL(Zs);
+  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
+  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
+  VectorRef mul4 = VectorRef({mul4_var_, Zs});
+  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1});
+  VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
+  VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
+  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
+  return add3;
+}
+
+const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const {
+  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_sqrt);
+  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
+  MS_EXCEPTION_IF_NULL(prim_deal_div);
+  VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[1], constant_mul_input_vars_[2]});
+  VectorRef mul3 = VectorRef({prim::kPrimMul, input_vars_[0], constant_mul_input_vars_[3]});
+  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
+  VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
+  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
+  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]});
+  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
+  VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
+  VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
+  VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
+  VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
+  return add5;
+}
+
+const BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const {
+  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_rsqrt);
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VarPtr Ys = std::make_shared<SeqVar>();
+  VarPtr Zs = std::make_shared<SeqVar>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  MS_EXCEPTION_IF_NULL(Ys);
+  MS_EXCEPTION_IF_NULL(Zs);
+  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
+  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
+  VectorRef mul4 = VectorRef({mul4_var_, Zs});
+  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1});
+  VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
+  VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
+  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
+  return add3;
+}
+
+const BaseRef LambNextMVWithDecayRuleCond2::DefinePattern() const {
+  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_sqrt);
+  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
+  MS_EXCEPTION_IF_NULL(prim_deal_div);
+  VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]});
+  VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]});
+  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
+  VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
+  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
+  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, sqrt1});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
+  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
+  VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
+  VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
+  VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
+  VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
+  return add5;
+}
+
+const BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const {
+  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_rsqrt);
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VarPtr Ys = std::make_shared<SeqVar>();
+  VarPtr Zs = std::make_shared<SeqVar>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  MS_EXCEPTION_IF_NULL(Ys);
+  MS_EXCEPTION_IF_NULL(Zs);
+  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
+  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
+  VectorRef mul4 = VectorRef({mul4_var_, Zs});
+  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
+  VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
+  VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
+  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
+  return add3;
+}
+
+const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const {
+  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_sqrt);
+  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
+  MS_EXCEPTION_IF_NULL(prim_deal_div);
+  VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[1], constant_mul_input_vars_[2]});
+  VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]});
+  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
+  VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
+  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
+  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]});
+  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
+  VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
+  VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
+  VectorRef mul4 = VectorRef({mul4_var_, input_vars_[6], constant_mul_input_vars_[4]});
+  VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
+  return add5;
+}
 } // namespace opt
 } // namespace mindspore
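All three Cond variants describe the same arithmetic as the original LambNextMVWithDecayRule; they differ only in the operand order of the commutative Mul and TensorAdd nodes, so graphs built with either ordering can still be fused. For reference, a minimal NumPy sketch of the expression tree these patterns match (an editorial illustration, not MindSpore code; names follow the Python test graphs later in this diff, and the moment/weight-decay reading in the comments is an assumption about LAMB, not stated by the patch):

import numpy as np


def lamb_next_mv_with_decay_reference(inp, c_mul, add2_y):
    """inp: input0..input6; c_mul: the five constant_mul* inputs; add2_y: constant_add2_y (epsilon-like)."""
    add0 = inp[4] * c_mul[0] + inp[3] * c_mul[1]             # mul0 + mul1, first-moment-style update
    add1 = inp[1] * c_mul[2] + inp[0] * c_mul[3]             # mul2 + mul3, second-moment-style update
    real_div0 = add0 / inp[5]
    real_div1 = add1 / inp[2]
    mul4 = c_mul[4] * inp[6]                                 # weight-decay-style term
    add3 = mul4 + real_div0 / np.sqrt(real_div1 + add2_y)    # Rsqrt branch (add2 -> sqrt0 -> real_div2)
    add5 = mul4 + real_div0 / (np.sqrt(real_div1) + add2_y)  # Sqrt branch (add4 -> real_div4)
    # Per the "after" test graphs, the fused LambNextMVWithDecay op returns these four values.
    return add3, add0, add1, add5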
@@ -74,6 +74,36 @@ class LambNextMVWithDecayRule : public PatternProcessPass {
   VarPtr add0_var_;
   VarPtr add1_var_;
 };
+
+class LambNextMVWithDecayRuleCond1 : public LambNextMVWithDecayRule {
+ public:
+  explicit LambNextMVWithDecayRuleCond1(bool multigraph = true)
+      : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond1", multigraph) {}
+  ~LambNextMVWithDecayRuleCond1() override = default;
+  const BaseRef DefinePattern() const override;
+  const BaseRef DefineAnotherPattern() const override;
+};
+
+class LambNextMVWithDecayRuleCond2 : public LambNextMVWithDecayRule {
+ public:
+  explicit LambNextMVWithDecayRuleCond2(bool multigraph = true)
+      : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond2", multigraph) {}
+  ~LambNextMVWithDecayRuleCond2() override = default;
+  const BaseRef DefinePattern() const override;
+  const BaseRef DefineAnotherPattern() const override;
+};
+
+class LambNextMVWithDecayRuleCond3 : public LambNextMVWithDecayRule {
+ public:
+  explicit LambNextMVWithDecayRuleCond3(bool multigraph = true)
+      : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond3", multigraph) {}
+  ~LambNextMVWithDecayRuleCond3() override = default;
+  const BaseRef DefinePattern() const override;
+  const BaseRef DefineAnotherPattern() const override;
+};
 } // namespace opt
 } // namespace mindspore
@@ -24,6 +24,8 @@
 #include "pre_activate/common/helper.h"

 namespace mindspore {
+namespace opt {
+namespace {
 bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_t *mul_index) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(add);
@@ -36,6 +38,14 @@ bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_
       MS_EXCEPTION_IF_NULL(cnode);
       if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimMul->name()) {
         if (!opt::IsUsedByOthers(graph, cnode)) {
+          auto full_name = cnode->fullname_with_scope();
+          // exclude lamb and adam, and only work in bert
+          if (std::string::npos != full_name.find("adam") || std::string::npos != full_name.find("lamb") ||
+              std::string::npos == full_name.find("bert")) {
+            MS_LOG(INFO) << "Mul is in adam or lamb or not a bert network, quit fusion";
+            return false;
+          }
           *mul = cnode;
           *mul_index = index;
           return true;
@@ -45,8 +55,7 @@ bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_
   }
   return false;
 }
-namespace opt {
+} // namespace
 const BaseRef MulAddFusion::DefinePattern() const {
   VarPtr x = std::make_shared<Var>();
   VarPtr y = std::make_shared<Var>();
@@ -74,7 +83,12 @@ const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodeP
   for (size_t index = 1; index < mul->size(); ++index) {
     inputs.push_back(mul->input(index));
   }
-  inputs.push_back(add->input(add->size() - mul_index));
+  auto another_input_node = add->input(add->size() - mul_index);
+  if (IsUsedByOthers(graph, another_input_node)) {
+    MS_LOG(INFO) << "Add's another input node is used by others, do not fuse";
+    return nullptr;
+  }
+  inputs.push_back(another_input_node);
   auto fusion_node = graph->NewCNode(inputs);
   fusion_node->set_scope(add->scope());
   fusion_node->set_abstract(add->abstract());
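The new guard in GetMul keys off cnode->fullname_with_scope(): Mul+Add fusion is skipped when the Mul node sits in an "adam" or "lamb" scope, or when the name does not mention "bert" at all. A minimal Python restatement of just that string filter (an editorial sketch, not MindSpore API; the example scope names are hypothetical):

def should_fuse_mul_add(full_name: str) -> bool:
    # Mirrors the substring check added to GetMul above: only fuse Mul nodes that
    # belong to a BERT scope and are not part of an Adam/LAMB optimizer subgraph.
    if "adam" in full_name or "lamb" in full_name or "bert" not in full_name:
        return False
    return True


# should_fuse_mul_add("Default/bert/encoder/layer0/Mul")  -> True   (hypothetical scope)
# should_fuse_mul_add("Default/lamb_optimizer/Mul")       -> False  (hypothetical scope)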
@@ -253,5 +253,134 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
   FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
   EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
 }
+
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond1) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "before");
+  std::vector<int> shp{2, 32, 224, 224};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list;
+  for (size_t i = 0; i < 13; ++i) {
+    args_spec_list.push_back(x_abstract);
+  }
+  auto fg = GetKernelGraph(g, args_spec_list);
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond1>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(fg);
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
+
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond1_un_match) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "un_match");
+  std::vector<int> shp{2, 32, 224, 224};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list;
+  for (size_t i = 0; i < 13; ++i) {
+    args_spec_list.push_back(x_abstract);
+  }
+  auto fg = GetKernelGraph(g, args_spec_list);
+  auto origin_graph = std::make_shared<session::KernelGraph>(*fg);
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond1>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(fg);
+  EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "un_match");
+  EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
+}
+
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond2) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "before");
+  std::vector<int> shp{2, 32, 224, 224};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list;
+  for (size_t i = 0; i < 13; ++i) {
+    args_spec_list.push_back(x_abstract);
+  }
+  auto fg = GetKernelGraph(g, args_spec_list);
+  DumpIR("fg.ir", fg, true);
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond2>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(fg);
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "after");
+  DumpIR("g_after.ir", g_after, true);
+  DumpIR("new_graph.ir", new_graph, true);
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
+
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond2_un_match) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "un_match");
+  std::vector<int> shp{2, 32, 224, 224};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list;
+  for (size_t i = 0; i < 13; ++i) {
+    args_spec_list.push_back(x_abstract);
+  }
+  auto fg = GetKernelGraph(g, args_spec_list);
+  auto origin_graph = std::make_shared<session::KernelGraph>(*fg);
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond2>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(fg);
+  EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "un_match");
+  EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
+}
+
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond3) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "before");
+  std::vector<int> shp{2, 32, 224, 224};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list;
+  for (size_t i = 0; i < 13; ++i) {
+    args_spec_list.push_back(x_abstract);
+  }
+  auto fg = GetKernelGraph(g, args_spec_list);
+  DumpIR("fg.ir", fg, true);
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond3>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(fg);
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "after");
+  DumpIR("g_after.ir", g_after, true);
+  DumpIR("new_graph.ir", new_graph, true);
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
+
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond3_un_match) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "un_match");
+  std::vector<int> shp{2, 32, 224, 224};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list;
+  for (size_t i = 0; i < 13; ++i) {
+    args_spec_list.push_back(x_abstract);
+  }
+  auto fg = GetKernelGraph(g, args_spec_list);
+  auto origin_graph = std::make_shared<session::KernelGraph>(*fg);
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond3>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(fg);
+  EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "un_match");
+  EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
+}
 } // namespace opt
 } // namespace mindspore
@@ -37,6 +37,10 @@ TEST_F(TestHWMulAddFusion, test_mul_add_fusion1) {
     args_spec_list.push_back(x_abstract);
   }
   auto fg = GetKernelGraph(g, args_spec_list);
+  auto scope = std::make_shared<Scope>("bert");
+  for (auto nd : fg->execution_order()) {
+    nd->set_scope(scope);
+  }

   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
@@ -57,6 +61,10 @@ TEST_F(TestHWMulAddFusion, test_mul_add_fusion2) {
     args_spec_list.push_back(x_abstract);
   }
   auto fg = GetKernelGraph(g, args_spec_list);
+  auto scope = std::make_shared<Scope>("bert");
+  for (auto nd : fg->execution_order()) {
+    nd->set_scope(scope);
+  }

   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
@@ -174,3 +174,201 @@ def test_lamb_next_mv_with_decay_rule(tag):
         return output

     return fns[tag]
+
+
+def test_lamb_next_mv_with_decay_rule_cond1(tag):
+    fns = FnDict()
+
+    @fns
+    def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
+               constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
+        mul1 = Mul(input3, constant_mul1_sub)
+        mul0 = Mul(input4, constant_mul0_x)
+        add0 = Add(mul0, mul1)
+        mul2 = Mul(input1, constant_mul2_x)
+        mul3 = Mul(input0, constant_mul3_sub1)
+        add1 = Add(mul2, mul3)
+        real_div1 = RealDiv(add1, input2)
+        add2 = Add(constant_add2_y, real_div1)
+        sqrt1 = Sqrt(real_div1)
+        real_div0 = RealDiv(add0, input5)
+        add4 = Add(sqrt1, constant_add2_y)
+        sqrt0 = Rsqrt(add2)
+        mul4 = Mul(constant_mul4_x, input6)
+        real_div4 = RealDiv(real_div0, add4)
+        real_div2 = Mul(sqrt0, real_div0)
+        add5 = Add(mul4, real_div4)
+        add3 = Add(mul4, real_div2)
+        outputs = make_tuple(add3, add0, add1, add5)
+        output = tuple_getitem(outputs, 0)
+        return output
+
+    @fns
+    def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
+              constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
+        lamb_next_mv_with_decay = LambNextMVWithDecay(input0, input1, input2, input3, input4, input5, input6,
+                                                      constant_mul0_x, constant_mul1_sub,
+                                                      constant_mul2_x, constant_mul3_sub1, constant_mul4_x,
+                                                      constant_add2_y)
+        outputs = make_tuple(tuple_getitem(lamb_next_mv_with_decay, 0), tuple_getitem(lamb_next_mv_with_decay, 1),
+                             tuple_getitem(lamb_next_mv_with_decay, 2), tuple_getitem(lamb_next_mv_with_decay, 3))
+        output = tuple_getitem(outputs, 0)
+        return make_tuple(output)
+
+    @fns
+    def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
+                 constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
+        mul1 = Mul(input3, constant_mul1_sub)
+        mul0 = Mul(input4, constant_mul0_x)
+        add0 = Add(mul0, mul1)
+        mul2 = Mul(input1, constant_mul2_x)
+        mul3 = Mul(input0, constant_mul3_sub1)
+        add1 = Add(mul2, mul3)
+        real_div1 = RealDiv(add1, input2)
+        add2 = Add(constant_add2_y, real_div1)
+        sqrt1 = Sqrt(real_div1)
+        real_div0 = RealDiv(add0, input5)
+        add4 = Add(sqrt1, constant_add2_y)
+        sqrt0 = Rsqrt(add2)
+        mul4 = Mul(constant_mul4_x, input6)
+        real_div4 = RealDiv(real_div0, add4)
+        real_div2 = Mul(sqrt0, real_div0)
+        add5 = Add(mul4, real_div4)
+        # un match
+        add3 = Add(real_div2, mul4)
+        outputs = make_tuple(add3, add0, add1, add5)
+        output = tuple_getitem(outputs, 0)
+        return output
+
+    return fns[tag]
+
+
+def test_lamb_next_mv_with_decay_rule_cond2(tag):
+    fns = FnDict()
+
+    @fns
+    def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
+               constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
+        mul1 = Mul(constant_mul1_sub, input3)
+        mul0 = Mul(constant_mul0_x, input4)
+        add0 = Add(mul0, mul1)
+        mul2 = Mul(constant_mul2_x, input1)
+        mul3 = Mul(constant_mul3_sub1, input0)
+        add1 = Add(mul2, mul3)
+        real_div1 = RealDiv(add1, input2)
+        add2 = Add(constant_add2_y, real_div1)
+        sqrt1 = Sqrt(real_div1)
+        real_div0 = RealDiv(add0, input5)
+        add4 = Add(constant_add2_y, sqrt1)
+        sqrt0 = Rsqrt(add2)
+        mul4 = Mul(constant_mul4_x, input6)
+        real_div4 = RealDiv(real_div0, add4)
+        real_div2 = Mul(sqrt0, real_div0)
+        add5 = Add(mul4, real_div4)
+        add3 = Add(mul4, real_div2)
+        outputs = make_tuple(add3, add0, add1, add5)
+        output = tuple_getitem(outputs, 0)
+        return output
+
+    @fns
+    def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
+              constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
+        lamb_next_mv_with_decay = LambNextMVWithDecay(input0, input1, input2, input3, input4, input5, input6,
+                                                      constant_mul0_x, constant_mul1_sub,
+                                                      constant_mul2_x, constant_mul3_sub1, constant_mul4_x,
+                                                      constant_add2_y)
+        outputs = make_tuple(tuple_getitem(lamb_next_mv_with_decay, 0), tuple_getitem(lamb_next_mv_with_decay, 1),
+                             tuple_getitem(lamb_next_mv_with_decay, 2), tuple_getitem(lamb_next_mv_with_decay, 3))
+        output = tuple_getitem(outputs, 0)
+        return make_tuple(output)
+
+    @fns
+    def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
+                 constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
+        mul1 = Mul(constant_mul1_sub, input3)
+        mul0 = Mul(constant_mul0_x, input4)
+        add0 = Add(mul0, mul1)
+        mul2 = Mul(constant_mul2_x, input1)
+        mul3 = Mul(constant_mul3_sub1, input0)
+        add1 = Add(mul2, mul3)
+        real_div1 = RealDiv(add1, input2)
+        add2 = Add(constant_add2_y, real_div1)
+        sqrt1 = Sqrt(real_div1)
+        real_div0 = RealDiv(add0, input5)
+        add4 = Add(constant_add2_y, sqrt1)
+        sqrt0 = Rsqrt(add2)
+        mul4 = Mul(constant_mul4_x, input6)
+        real_div4 = RealDiv(real_div0, add4)
+        real_div2 = Mul(sqrt0, real_div0)
+        add5 = Add(mul4, real_div4)
+        # un_match
+        add3 = Add(real_div2, mul4)
+        outputs = make_tuple(add3, add0, add1, add5)
+        output = tuple_getitem(outputs, 0)
+        return output
+
+    return fns[tag]
+
+
+def test_lamb_next_mv_with_decay_rule_cond3(tag):
+    fns = FnDict()
+
+    @fns
+    def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
+               constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
+        mul1 = Mul(input3, constant_mul1_sub)
+        mul0 = Mul(input4, constant_mul0_x)
+        add0 = Add(mul0, mul1)
+        mul2 = Mul(input1, constant_mul2_x)
+        mul3 = Mul(constant_mul3_sub1, input0)
+        add1 = Add(mul2, mul3)
+        real_div1 = RealDiv(add1, input2)
+        add2 = Add(real_div1, constant_add2_y)
+        sqrt1 = Sqrt(real_div1)
+        real_div0 = RealDiv(add0, input5)
+        add4 = Add(sqrt1, constant_add2_y)
+        sqrt0 = Rsqrt(add2)
+        mul4 = Mul(input6, constant_mul4_x)
+        real_div4 = RealDiv(real_div0, add4)
+        real_div2 = Mul(sqrt0, real_div0)
+        add5 = Add(mul4, real_div4)
+        add3 = Add(mul4, real_div2)
+        outputs = make_tuple(add3, add0, add1, add5)
+        output = tuple_getitem(outputs, 0)
+        return output
+
+    @fns
+    def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
+              constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
+        lamb_next_mv_with_decay = LambNextMVWithDecay(input0, input1, input2, input3, input4, input5, input6,
+                                                      constant_mul0_x, constant_mul1_sub,
+                                                      constant_mul2_x, constant_mul3_sub1, constant_mul4_x,
+                                                      constant_add2_y)
+        outputs = make_tuple(tuple_getitem(lamb_next_mv_with_decay, 0), tuple_getitem(lamb_next_mv_with_decay, 1),
+                             tuple_getitem(lamb_next_mv_with_decay, 2), tuple_getitem(lamb_next_mv_with_decay, 3))
+        output = tuple_getitem(outputs, 0)
+        return make_tuple(output)
+
+    @fns
+    def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
+                 constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
+        mul1 = Mul(input3, constant_mul1_sub)
+        mul0 = Mul(input4, constant_mul0_x)
+        add0 = Add(mul0, mul1)
+        mul2 = Mul(input1, constant_mul2_x)
+        mul3 = Mul(constant_mul3_sub1, input0)
+        add1 = Add(mul2, mul3)
+        real_div1 = RealDiv(add1, input2)
+        add2 = Add(real_div1, constant_add2_y)
+        sqrt1 = Sqrt(real_div1)
+        real_div0 = RealDiv(add0, input5)
+        add4 = Add(sqrt1, constant_add2_y)
+        sqrt0 = Rsqrt(add2)
+        mul4 = Mul(input6, constant_mul4_x)
+        real_div4 = RealDiv(real_div0, add4)
+        real_div2 = Mul(sqrt0, real_div0)
+        add5 = Add(mul4, real_div4)
+        # un match
+        add3 = Add(real_div2, mul4)
+        outputs = make_tuple(add3, add0, add1, add5)
+        output = tuple_getitem(outputs, 0)
+        return output
+
+    return fns[tag]