Merge pull request !1328 from huanghui/LambNextMvWithDecayRuleConds-fusion-passtags/v0.3.0-alpha
| @@ -98,6 +98,9 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { | |||
| ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRule>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond1>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond2>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond3>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVRule>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>()); | |||
| @@ -163,5 +163,128 @@ const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph | |||
| } | |||
| return nullptr; | |||
| } | |||
| const BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const { | |||
| const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName); | |||
| MS_EXCEPTION_IF_NULL(prim_rsqrt); | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| VarPtr Ys = std::make_shared<SeqVar>(); | |||
| VarPtr Zs = std::make_shared<SeqVar>(); | |||
| MS_EXCEPTION_IF_NULL(Xs); | |||
| MS_EXCEPTION_IF_NULL(Ys); | |||
| MS_EXCEPTION_IF_NULL(Zs); | |||
| VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); | |||
| VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | |||
| VectorRef mul4 = VectorRef({mul4_var_, Zs}); | |||
| VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1}); | |||
| VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); | |||
| VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); | |||
| VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); | |||
| return add3; | |||
| } | |||
| const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const { | |||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||
| MS_EXCEPTION_IF_NULL(prim_sqrt); | |||
| const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| MS_EXCEPTION_IF_NULL(prim_deal_div); | |||
| VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[1], constant_mul_input_vars_[2]}); | |||
| VectorRef mul3 = VectorRef({prim::kPrimMul, input_vars_[0], constant_mul_input_vars_[3]}); | |||
| VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); | |||
| VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); | |||
| VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | |||
| VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); | |||
| VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]}); | |||
| VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]}); | |||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||
| VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); | |||
| VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); | |||
| VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); | |||
| VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); | |||
| return add5; | |||
| } | |||
| const BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const { | |||
| const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName); | |||
| MS_EXCEPTION_IF_NULL(prim_rsqrt); | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| VarPtr Ys = std::make_shared<SeqVar>(); | |||
| VarPtr Zs = std::make_shared<SeqVar>(); | |||
| MS_EXCEPTION_IF_NULL(Xs); | |||
| MS_EXCEPTION_IF_NULL(Ys); | |||
| MS_EXCEPTION_IF_NULL(Zs); | |||
| VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); | |||
| VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | |||
| VectorRef mul4 = VectorRef({mul4_var_, Zs}); | |||
| VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1}); | |||
| VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); | |||
| VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); | |||
| VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); | |||
| return add3; | |||
| } | |||
| const BaseRef LambNextMVWithDecayRuleCond2::DefinePattern() const { | |||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||
| MS_EXCEPTION_IF_NULL(prim_sqrt); | |||
| const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| MS_EXCEPTION_IF_NULL(prim_deal_div); | |||
| VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]}); | |||
| VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]}); | |||
| VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); | |||
| VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); | |||
| VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | |||
| VectorRef add4 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, sqrt1}); | |||
| VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}); | |||
| VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]}); | |||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||
| VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); | |||
| VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); | |||
| VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); | |||
| VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); | |||
| return add5; | |||
| } | |||
| const BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const { | |||
| const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName); | |||
| MS_EXCEPTION_IF_NULL(prim_rsqrt); | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| VarPtr Ys = std::make_shared<SeqVar>(); | |||
| VarPtr Zs = std::make_shared<SeqVar>(); | |||
| MS_EXCEPTION_IF_NULL(Xs); | |||
| MS_EXCEPTION_IF_NULL(Ys); | |||
| MS_EXCEPTION_IF_NULL(Zs); | |||
| VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); | |||
| VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | |||
| VectorRef mul4 = VectorRef({mul4_var_, Zs}); | |||
| VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_}); | |||
| VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); | |||
| VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); | |||
| VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); | |||
| return add3; | |||
| } | |||
| const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const { | |||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||
| MS_EXCEPTION_IF_NULL(prim_sqrt); | |||
| const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| MS_EXCEPTION_IF_NULL(prim_deal_div); | |||
| VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[1], constant_mul_input_vars_[2]}); | |||
| VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]}); | |||
| VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); | |||
| VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); | |||
| VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | |||
| VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); | |||
| VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]}); | |||
| VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]}); | |||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||
| VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); | |||
| VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); | |||
| VectorRef mul4 = VectorRef({mul4_var_, input_vars_[6], constant_mul_input_vars_[4]}); | |||
| VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); | |||
| return add5; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -74,6 +74,36 @@ class LambNextMVWithDecayRule : public PatternProcessPass { | |||
| VarPtr add0_var_; | |||
| VarPtr add1_var_; | |||
| }; | |||
| class LambNextMVWithDecayRuleCond1 : public LambNextMVWithDecayRule { | |||
| public: | |||
| explicit LambNextMVWithDecayRuleCond1(bool multigraph = true) | |||
| : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond1", multigraph) {} | |||
| ~LambNextMVWithDecayRuleCond1() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const BaseRef DefineAnotherPattern() const override; | |||
| }; | |||
| class LambNextMVWithDecayRuleCond2 : public LambNextMVWithDecayRule { | |||
| public: | |||
| explicit LambNextMVWithDecayRuleCond2(bool multigraph = true) | |||
| : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond2", multigraph) {} | |||
| ~LambNextMVWithDecayRuleCond2() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const BaseRef DefineAnotherPattern() const override; | |||
| }; | |||
| class LambNextMVWithDecayRuleCond3 : public LambNextMVWithDecayRule { | |||
| public: | |||
| explicit LambNextMVWithDecayRuleCond3(bool multigraph = true) | |||
| : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond3", multigraph) {} | |||
| ~LambNextMVWithDecayRuleCond3() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const BaseRef DefineAnotherPattern() const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -24,6 +24,8 @@ | |||
| #include "pre_activate/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_t *mul_index) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(add); | |||
| @@ -36,6 +38,14 @@ bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_ | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimMul->name()) { | |||
| if (!opt::IsUsedByOthers(graph, cnode)) { | |||
| auto full_name = cnode->fullname_with_scope(); | |||
| // exclude lamb and adam, and only work in bert | |||
| if (std::string::npos != full_name.find("adam") || std::string::npos != full_name.find("lamb") || | |||
| std::string::npos == full_name.find("bert")) { | |||
| MS_LOG(INFO) << "Mul is in adam or lamb or not a bert network, quit fusion"; | |||
| return false; | |||
| } | |||
| *mul = cnode; | |||
| *mul_index = index; | |||
| return true; | |||
| @@ -45,8 +55,7 @@ bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_ | |||
| } | |||
| return false; | |||
| } | |||
| namespace opt { | |||
| } // namespace | |||
| const BaseRef MulAddFusion::DefinePattern() const { | |||
| VarPtr x = std::make_shared<Var>(); | |||
| VarPtr y = std::make_shared<Var>(); | |||
| @@ -74,7 +83,12 @@ const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodeP | |||
| for (size_t index = 1; index < mul->size(); ++index) { | |||
| inputs.push_back(mul->input(index)); | |||
| } | |||
| inputs.push_back(add->input(add->size() - mul_index)); | |||
| auto another_input_node = add->input(add->size() - mul_index); | |||
| if (IsUsedByOthers(graph, another_input_node)) { | |||
| MS_LOG(INFO) << "Add's another input node is used by others, do not fuse"; | |||
| return nullptr; | |||
| } | |||
| inputs.push_back(another_input_node); | |||
| auto fusion_node = graph->NewCNode(inputs); | |||
| fusion_node->set_scope(add->scope()); | |||
| fusion_node->set_abstract(add->abstract()); | |||
| @@ -253,5 +253,134 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after"); | |||
| EXPECT_FALSE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond1) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "before"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 13; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond1>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond1_un_match) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "un_match"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 13; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto origin_graph = std::make_shared<session::KernelGraph>(*fg); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond1>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph)); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "un_match"); | |||
| EXPECT_FALSE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond2) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "before"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 13; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| DumpIR("fg.ir", fg, true); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond2>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "after"); | |||
| DumpIR("g_after.ir", g_after, true); | |||
| DumpIR("new_graph.ir", new_graph, true); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond2_un_match) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "un_match"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 13; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto origin_graph = std::make_shared<session::KernelGraph>(*fg); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond2>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph)); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "un_match"); | |||
| EXPECT_FALSE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond3) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "before"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 13; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| DumpIR("fg.ir", fg, true); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond3>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "after"); | |||
| DumpIR("g_after.ir", g_after, true); | |||
| DumpIR("new_graph.ir", new_graph, true); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond3_un_match) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "un_match"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 13; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto origin_graph = std::make_shared<session::KernelGraph>(*fg); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond3>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph)); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "un_match"); | |||
| EXPECT_FALSE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -37,6 +37,10 @@ TEST_F(TestHWMulAddFusion, test_mul_add_fusion1) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto scope = std::make_shared<Scope>("bert"); | |||
| for (auto nd : fg->execution_order()) { | |||
| nd->set_scope(scope); | |||
| } | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| @@ -57,6 +61,10 @@ TEST_F(TestHWMulAddFusion, test_mul_add_fusion2) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto scope = std::make_shared<Scope>("bert"); | |||
| for (auto nd : fg->execution_order()) { | |||
| nd->set_scope(scope); | |||
| } | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| @@ -174,3 +174,201 @@ def test_lamb_next_mv_with_decay_rule(tag): | |||
| return output | |||
| return fns[tag] | |||
| def test_lamb_next_mv_with_decay_rule_cond1(tag): | |||
| fns = FnDict() | |||
| @fns | |||
| def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, | |||
| constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): | |||
| mul1 = Mul(input3, constant_mul1_sub) | |||
| mul0 = Mul(input4, constant_mul0_x) | |||
| add0 = Add(mul0, mul1) | |||
| mul2 = Mul(input1, constant_mul2_x) | |||
| mul3 = Mul(input0, constant_mul3_sub1) | |||
| add1 = Add(mul2, mul3) | |||
| real_div1 = RealDiv(add1, input2) | |||
| add2 = Add(constant_add2_y, real_div1) | |||
| sqrt1 = Sqrt(real_div1) | |||
| real_div0 = RealDiv(add0, input5) | |||
| add4 = Add(sqrt1, constant_add2_y) | |||
| sqrt0 = Rsqrt(add2) | |||
| mul4 = Mul(constant_mul4_x, input6) | |||
| real_div4 = RealDiv(real_div0, add4) | |||
| real_div2 = Mul(sqrt0, real_div0) | |||
| add5 = Add(mul4, real_div4) | |||
| add3 = Add(mul4, real_div2) | |||
| outputs = make_tuple(add3, add0, add1, add5) | |||
| output = tuple_getitem(outputs, 0) | |||
| return output | |||
| @fns | |||
| def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, | |||
| constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): | |||
| lamb_next_mv_with_decay = LambNextMVWithDecay(input0, input1, input2, input3, input4, input5, input6, | |||
| constant_mul0_x, constant_mul1_sub, | |||
| constant_mul2_x, constant_mul3_sub1, constant_mul4_x, | |||
| constant_add2_y) | |||
| outputs = make_tuple(tuple_getitem(lamb_next_mv_with_decay, 0), tuple_getitem(lamb_next_mv_with_decay, 1), | |||
| tuple_getitem(lamb_next_mv_with_decay, 2), tuple_getitem(lamb_next_mv_with_decay, 3)) | |||
| output = tuple_getitem(outputs, 0) | |||
| return make_tuple(output) | |||
| @fns | |||
| def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, | |||
| constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): | |||
| mul1 = Mul(input3, constant_mul1_sub) | |||
| mul0 = Mul(input4, constant_mul0_x) | |||
| add0 = Add(mul0, mul1) | |||
| mul2 = Mul(input1, constant_mul2_x) | |||
| mul3 = Mul(input0, constant_mul3_sub1) | |||
| add1 = Add(mul2, mul3) | |||
| real_div1 = RealDiv(add1, input2) | |||
| add2 = Add(constant_add2_y, real_div1) | |||
| sqrt1 = Sqrt(real_div1) | |||
| real_div0 = RealDiv(add0, input5) | |||
| add4 = Add(sqrt1, constant_add2_y) | |||
| sqrt0 = Rsqrt(add2) | |||
| mul4 = Mul(constant_mul4_x, input6) | |||
| real_div4 = RealDiv(real_div0, add4) | |||
| real_div2 = Mul(sqrt0, real_div0) | |||
| add5 = Add(mul4, real_div4) | |||
| # un match | |||
| add3 = Add(real_div2, mul4) | |||
| outputs = make_tuple(add3, add0, add1, add5) | |||
| output = tuple_getitem(outputs, 0) | |||
| return output | |||
| return fns[tag] | |||
| def test_lamb_next_mv_with_decay_rule_cond2(tag): | |||
| fns = FnDict() | |||
| @fns | |||
| def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, | |||
| constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): | |||
| mul1 = Mul(constant_mul1_sub, input3) | |||
| mul0 = Mul(constant_mul0_x, input4) | |||
| add0 = Add(mul0, mul1) | |||
| mul2 = Mul(constant_mul2_x, input1) | |||
| mul3 = Mul(constant_mul3_sub1, input0) | |||
| add1 = Add(mul2, mul3) | |||
| real_div1 = RealDiv(add1, input2) | |||
| add2 = Add(constant_add2_y, real_div1) | |||
| sqrt1 = Sqrt(real_div1) | |||
| real_div0 = RealDiv(add0, input5) | |||
| add4 = Add(constant_add2_y, sqrt1) | |||
| sqrt0 = Rsqrt(add2) | |||
| mul4 = Mul(constant_mul4_x, input6) | |||
| real_div4 = RealDiv(real_div0, add4) | |||
| real_div2 = Mul(sqrt0, real_div0) | |||
| add5 = Add(mul4, real_div4) | |||
| add3 = Add(mul4, real_div2) | |||
| outputs = make_tuple(add3, add0, add1, add5) | |||
| output = tuple_getitem(outputs, 0) | |||
| return output | |||
| @fns | |||
| def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, | |||
| constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): | |||
| lamb_next_mv_with_decay = LambNextMVWithDecay(input0, input1, input2, input3, input4, input5, input6, | |||
| constant_mul0_x, constant_mul1_sub, | |||
| constant_mul2_x, constant_mul3_sub1, constant_mul4_x, | |||
| constant_add2_y) | |||
| outputs = make_tuple(tuple_getitem(lamb_next_mv_with_decay, 0), tuple_getitem(lamb_next_mv_with_decay, 1), | |||
| tuple_getitem(lamb_next_mv_with_decay, 2), tuple_getitem(lamb_next_mv_with_decay, 3)) | |||
| output = tuple_getitem(outputs, 0) | |||
| return make_tuple(output) | |||
| @fns | |||
| def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, | |||
| constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): | |||
| mul1 = Mul(constant_mul1_sub, input3) | |||
| mul0 = Mul(constant_mul0_x, input4) | |||
| add0 = Add(mul0, mul1) | |||
| mul2 = Mul(constant_mul2_x, input1) | |||
| mul3 = Mul(constant_mul3_sub1, input0) | |||
| add1 = Add(mul2, mul3) | |||
| real_div1 = RealDiv(add1, input2) | |||
| add2 = Add(constant_add2_y, real_div1) | |||
| sqrt1 = Sqrt(real_div1) | |||
| real_div0 = RealDiv(add0, input5) | |||
| add4 = Add(constant_add2_y, sqrt1) | |||
| sqrt0 = Rsqrt(add2) | |||
| mul4 = Mul(constant_mul4_x, input6) | |||
| real_div4 = RealDiv(real_div0, add4) | |||
| real_div2 = Mul(sqrt0, real_div0) | |||
| add5 = Add(mul4, real_div4) | |||
| # un_match | |||
| add3 = Add(real_div2, mul4) | |||
| outputs = make_tuple(add3, add0, add1, add5) | |||
| output = tuple_getitem(outputs, 0) | |||
| return output | |||
| return fns[tag] | |||
| def test_lamb_next_mv_with_decay_rule_cond3(tag): | |||
| fns = FnDict() | |||
| @fns | |||
| def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, | |||
| constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): | |||
| mul1 = Mul(input3, constant_mul1_sub) | |||
| mul0 = Mul(input4, constant_mul0_x) | |||
| add0 = Add(mul0, mul1) | |||
| mul2 = Mul(input1, constant_mul2_x) | |||
| mul3 = Mul(constant_mul3_sub1, input0) | |||
| add1 = Add(mul2, mul3) | |||
| real_div1 = RealDiv(add1, input2) | |||
| add2 = Add(real_div1, constant_add2_y) | |||
| sqrt1 = Sqrt(real_div1) | |||
| real_div0 = RealDiv(add0, input5) | |||
| add4 = Add(sqrt1, constant_add2_y) | |||
| sqrt0 = Rsqrt(add2) | |||
| mul4 = Mul(input6, constant_mul4_x) | |||
| real_div4 = RealDiv(real_div0, add4) | |||
| real_div2 = Mul(sqrt0, real_div0) | |||
| add5 = Add(mul4, real_div4) | |||
| add3 = Add(mul4, real_div2) | |||
| outputs = make_tuple(add3, add0, add1, add5) | |||
| output = tuple_getitem(outputs, 0) | |||
| return output | |||
| @fns | |||
| def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, | |||
| constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): | |||
| lamb_next_mv_with_decay = LambNextMVWithDecay(input0, input1, input2, input3, input4, input5, input6, | |||
| constant_mul0_x, constant_mul1_sub, | |||
| constant_mul2_x, constant_mul3_sub1, constant_mul4_x, | |||
| constant_add2_y) | |||
| outputs = make_tuple(tuple_getitem(lamb_next_mv_with_decay, 0), tuple_getitem(lamb_next_mv_with_decay, 1), | |||
| tuple_getitem(lamb_next_mv_with_decay, 2), tuple_getitem(lamb_next_mv_with_decay, 3)) | |||
| output = tuple_getitem(outputs, 0) | |||
| return make_tuple(output) | |||
| @fns | |||
| def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub, | |||
| constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y): | |||
| mul1 = Mul(input3, constant_mul1_sub) | |||
| mul0 = Mul(input4, constant_mul0_x) | |||
| add0 = Add(mul0, mul1) | |||
| mul2 = Mul(input1, constant_mul2_x) | |||
| mul3 = Mul(constant_mul3_sub1, input0) | |||
| add1 = Add(mul2, mul3) | |||
| real_div1 = RealDiv(add1, input2) | |||
| add2 = Add(real_div1, constant_add2_y) | |||
| sqrt1 = Sqrt(real_div1) | |||
| real_div0 = RealDiv(add0, input5) | |||
| add4 = Add(sqrt1, constant_add2_y) | |||
| sqrt0 = Rsqrt(add2) | |||
| mul4 = Mul(input6, constant_mul4_x) | |||
| real_div4 = RealDiv(real_div0, add4) | |||
| real_div2 = Mul(sqrt0, real_div0) | |||
| add5 = Add(mul4, real_div4) | |||
| # un match | |||
| add3 = Add(real_div2, mul4) | |||
| outputs = make_tuple(add3, add0, add1, add5) | |||
| output = tuple_getitem(outputs, 0) | |||
| return output | |||
| return fns[tag] | |||