| @@ -104,6 +104,9 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { | |||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond2>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond3>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond4>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond1>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond2>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond3>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond4>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>()); | |||
| @@ -116,9 +116,116 @@ const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const A | |||
| return CreateLambNextMVNode(func_graph, old_pattern_outputs, equiv); | |||
| } | |||
| const BaseRef LambNextMVRuleCond1::DefinePattern() const { | |||
| const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName); | |||
| auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_}); | |||
| auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_}); | |||
| auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_}); | |||
| auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_}); | |||
| auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_}); | |||
| auto add0 = VectorRef({add0_var_, mul0, mul1}); | |||
| auto add1 = VectorRef({add1_var_, mul2, mul3}); | |||
| auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); | |||
| auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); | |||
| auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1}); | |||
| auto sqrt0 = VectorRef({prim_rsqrt, add2}); | |||
| auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); | |||
| return VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); | |||
| } | |||
| BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const { | |||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| VarPtr Ys = std::make_shared<SeqVar>(); | |||
| // Two patterns share: real_div0, real_div1, add2_y_ | |||
| VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); | |||
| VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | |||
| VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | |||
| VectorRef add4 = VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt1}); | |||
| VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); | |||
| return real_div4; | |||
| } | |||
| const BaseRef LambNextMVRuleCond2::DefinePattern() const { | |||
| const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName); | |||
| auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_}); | |||
| auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_}); | |||
| auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_}); | |||
| auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_}); | |||
| auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_}); | |||
| auto add0 = VectorRef({add0_var_, mul0, mul1}); | |||
| auto add1 = VectorRef({add1_var_, mul2, mul3}); | |||
| auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); | |||
| auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); | |||
| auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1}); | |||
| auto sqrt0 = VectorRef({prim_rsqrt, add2}); | |||
| auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); | |||
| return VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); | |||
| } | |||
| BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const { | |||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| VarPtr Ys = std::make_shared<SeqVar>(); | |||
| // Two patterns share: real_div0, real_div1, add2_y_ | |||
| VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); | |||
| VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | |||
| VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | |||
| VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_}); | |||
| VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); | |||
| return real_div4; | |||
| } | |||
| const BaseRef LambNextMVRuleCond3::DefinePattern() const { | |||
| const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName); | |||
| auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_}); | |||
| auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_}); | |||
| auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_}); | |||
| auto mul3 = VectorRef({prim::kPrimMul, input0_, mul3_sub1_}); | |||
| auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_}); | |||
| auto add0 = VectorRef({add0_var_, mul0, mul1}); | |||
| auto add1 = VectorRef({add1_var_, mul2, mul3}); | |||
| auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); | |||
| auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); | |||
| auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_}); | |||
| auto sqrt0 = VectorRef({prim_rsqrt, add2}); | |||
| auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); | |||
| return VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); | |||
| } | |||
| BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const { | |||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| VarPtr Ys = std::make_shared<SeqVar>(); | |||
| // Two patterns share: real_div0, real_div1, add2_y_ | |||
| VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); | |||
| VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | |||
| VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); | |||
| VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_}); | |||
| VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); | |||
| return real_div4; | |||
| } | |||
| const BaseRef LambNextMVRuleCond4::DefinePattern() const { | |||
| const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName); | |||
| MS_EXCEPTION_IF_NULL(prim_rsqrt); | |||
| auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_}); | |||
| auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_}); | |||
| @@ -140,13 +247,9 @@ const BaseRef LambNextMVRuleCond4::DefinePattern() const { | |||
| BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const { | |||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||
| MS_EXCEPTION_IF_NULL(prim_sqrt); | |||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| MS_EXCEPTION_IF_NULL(prim_real_div); | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| VarPtr Ys = std::make_shared<SeqVar>(); | |||
| MS_EXCEPTION_IF_NULL(Xs); | |||
| MS_EXCEPTION_IF_NULL(Ys); | |||
| // Two patterns share: real_div0, real_div1, add2_y_ | |||
| VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); | |||
| VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); | |||
| @@ -87,6 +87,33 @@ class LambNextMVRule : public MultipleOutputPatternProcessPass { | |||
| VarPtr real_div2_var_; | |||
| }; | |||
| class LambNextMVRuleCond1 : public LambNextMVRule { | |||
| public: | |||
| explicit LambNextMVRuleCond1(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond1", multigraph) {} | |||
| ~LambNextMVRuleCond1() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| BaseRef DefineAnotherPattern() const override; | |||
| }; | |||
| class LambNextMVRuleCond2 : public LambNextMVRule { | |||
| public: | |||
| explicit LambNextMVRuleCond2(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond2", multigraph) {} | |||
| ~LambNextMVRuleCond2() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| BaseRef DefineAnotherPattern() const override; | |||
| }; | |||
| class LambNextMVRuleCond3 : public LambNextMVRule { | |||
| public: | |||
| explicit LambNextMVRuleCond3(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond3", multigraph) {} | |||
| ~LambNextMVRuleCond3() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| BaseRef DefineAnotherPattern() const override; | |||
| }; | |||
| class LambNextMVRuleCond4 : public LambNextMVRule { | |||
| public: | |||
| explicit LambNextMVRuleCond4(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond4", multigraph) {} | |||
| @@ -244,5 +244,125 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div1) { | |||
| EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph)); | |||
| } | |||
| TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond1_fusion) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond1", "before"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 13; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::LambNextMVRuleCond1>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond1", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond1_unmatched) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond1", "un_match"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 13; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto origin_graph = std::make_shared<session::KernelGraph>(*fg); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::LambNextMVRuleCond1>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph)); | |||
| } | |||
| TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond2_fusion) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond2", "before"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 13; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::LambNextMVRuleCond2>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond2", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond2_unmatched) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond2", "un_match"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 13; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto origin_graph = std::make_shared<session::KernelGraph>(*fg); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::LambNextMVRuleCond2>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph)); | |||
| } | |||
| TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond3_fusion) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond3", "before"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 13; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::LambNextMVRuleCond3>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond3", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond3_unmatched) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond3", "un_match"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 13; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto origin_graph = std::make_shared<session::KernelGraph>(*fg); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::LambNextMVRuleCond3>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -24,7 +24,6 @@ make_tuple = Primitive('make_tuple') | |||
| tuple_getitem = Primitive('tuple_getitem') | |||
| LambNextMV = Primitive('LambNextMV') | |||
| class FnDict: | |||
| def __init__(self): | |||
| self.fnDict = {} | |||
| @@ -35,7 +34,6 @@ class FnDict: | |||
| def __getitem__(self, name): | |||
| return self.fnDict[name] | |||
| def test_lamb_next_mv_rule_cond4(tag): | |||
| fns = FnDict() | |||
| @@ -170,3 +168,192 @@ def test_lamb_next_mv_rule_cond4(tag): | |||
| return output | |||
| return fns[tag] | |||
def test_lamb_next_mv_rule_cond1(tag):
    """Return the cond1 graph function selected by ``tag``.

    Three functions are decorated with ``fns``: ``before``, ``after`` and
    ``un_match``. The lookup ``fns[tag]`` presumably keys on the function
    name (the registering side of ``FnDict`` is not shown here).
    """
    fns = FnDict()

    @fns
    def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
               constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        # Every Mul takes its constant first; both adds with constant_add2_y take the constant first.
        mul0 = Mul(constant_mul0_x, input4)
        mul1 = Mul(constant_mul1_sub, input3)
        add0 = Add(mul0, mul1)
        mul2 = Mul(constant_mul2_x, input1)
        mul3 = Mul(constant_mul3_sub1, input0)
        add1 = Add(mul2, mul3)
        real_div1 = RealDiv(add1, input2)
        add2 = Add(constant_add2_y, real_div1)
        sqrt0 = Rsqrt(add2)
        sqrt1 = Sqrt(real_div1)
        add4 = Add(constant_add2_y, sqrt1)
        real_div0 = RealDiv(add0, input5)
        real_div4 = RealDiv(real_div0, add4)
        real_div2 = Mul(sqrt0, real_div0)
        mul4 = Mul(constant_mul4_x, input6)
        add3 = Add(mul4, real_div2)
        outputs = make_tuple(add3, add0, add1, real_div4)
        return tuple_getitem(outputs, 0)

    @fns
    def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
              constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        # A single LambNextMV node whose four outputs stand in for the outputs of `before`.
        lamb_next_mv = LambNextMV(input0, input1, input2, input3, input4, input5, input6,
                                  constant_mul0_x, constant_mul1_sub, constant_mul2_x, constant_mul3_sub1,
                                  constant_mul4_x, constant_add2_y)
        outputs = make_tuple(tuple_getitem(lamb_next_mv, 0), tuple_getitem(lamb_next_mv, 1),
                             tuple_getitem(lamb_next_mv, 2), tuple_getitem(lamb_next_mv, 3))
        return make_tuple(tuple_getitem(outputs, 0))

    @fns
    def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
                 constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        # Same as `before` except for the operand order of add4.
        mul0 = Mul(constant_mul0_x, input4)
        mul1 = Mul(constant_mul1_sub, input3)
        add0 = Add(mul0, mul1)
        mul2 = Mul(constant_mul2_x, input1)
        mul3 = Mul(constant_mul3_sub1, input0)
        add1 = Add(mul2, mul3)
        real_div1 = RealDiv(add1, input2)
        add2 = Add(constant_add2_y, real_div1)
        sqrt0 = Rsqrt(add2)
        sqrt1 = Sqrt(real_div1)
        # un match: operands swapped relative to `before`
        add4 = Add(sqrt1, constant_add2_y)
        real_div0 = RealDiv(add0, input5)
        real_div4 = RealDiv(real_div0, add4)
        real_div2 = Mul(sqrt0, real_div0)
        mul4 = Mul(constant_mul4_x, input6)
        add3 = Add(mul4, real_div2)
        outputs = make_tuple(add3, add0, add1, real_div4)
        return tuple_getitem(outputs, 0)

    return fns[tag]
def test_lamb_next_mv_rule_cond2(tag):
    """Return the cond2 graph function selected by ``tag``.

    Three functions are decorated with ``fns``: ``before``, ``after`` and
    ``un_match``. The lookup ``fns[tag]`` presumably keys on the function
    name (the registering side of ``FnDict`` is not shown here).
    """
    fns = FnDict()

    @fns
    def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
               constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        # mul0/mul1/mul2/mul4 take the input first, mul3 takes the constant first;
        # add2 takes constant_add2_y first, add4 takes it second.
        mul0 = Mul(input4, constant_mul0_x)
        mul1 = Mul(input3, constant_mul1_sub)
        add0 = Add(mul0, mul1)
        mul2 = Mul(input1, constant_mul2_x)
        mul3 = Mul(constant_mul3_sub1, input0)
        add1 = Add(mul2, mul3)
        real_div1 = RealDiv(add1, input2)
        add2 = Add(constant_add2_y, real_div1)
        sqrt0 = Rsqrt(add2)
        sqrt1 = Sqrt(real_div1)
        add4 = Add(sqrt1, constant_add2_y)
        real_div0 = RealDiv(add0, input5)
        real_div4 = RealDiv(real_div0, add4)
        real_div2 = Mul(sqrt0, real_div0)
        mul4 = Mul(input6, constant_mul4_x)
        add3 = Add(mul4, real_div2)
        outputs = make_tuple(add3, add0, add1, real_div4)
        return tuple_getitem(outputs, 0)

    @fns
    def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
              constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        # A single LambNextMV node whose four outputs stand in for the outputs of `before`.
        lamb_next_mv = LambNextMV(input0, input1, input2, input3, input4, input5, input6,
                                  constant_mul0_x, constant_mul1_sub, constant_mul2_x, constant_mul3_sub1,
                                  constant_mul4_x, constant_add2_y)
        outputs = make_tuple(tuple_getitem(lamb_next_mv, 0), tuple_getitem(lamb_next_mv, 1),
                             tuple_getitem(lamb_next_mv, 2), tuple_getitem(lamb_next_mv, 3))
        return make_tuple(tuple_getitem(outputs, 0))

    @fns
    def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
                 constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        # Same as `before` except for the operand order of add4.
        mul0 = Mul(input4, constant_mul0_x)
        mul1 = Mul(input3, constant_mul1_sub)
        add0 = Add(mul0, mul1)
        mul2 = Mul(input1, constant_mul2_x)
        mul3 = Mul(constant_mul3_sub1, input0)
        add1 = Add(mul2, mul3)
        real_div1 = RealDiv(add1, input2)
        add2 = Add(constant_add2_y, real_div1)
        sqrt0 = Rsqrt(add2)
        sqrt1 = Sqrt(real_div1)
        # un match: operands swapped relative to `before`
        add4 = Add(constant_add2_y, sqrt1)
        real_div0 = RealDiv(add0, input5)
        real_div4 = RealDiv(real_div0, add4)
        real_div2 = Mul(sqrt0, real_div0)
        mul4 = Mul(input6, constant_mul4_x)
        add3 = Add(mul4, real_div2)
        outputs = make_tuple(add3, add0, add1, real_div4)
        return tuple_getitem(outputs, 0)

    return fns[tag]
def test_lamb_next_mv_rule_cond3(tag):
    """Return the cond3 graph function selected by ``tag``.

    Three functions are decorated with ``fns``: ``before``, ``after`` and
    ``un_match``. The lookup ``fns[tag]`` presumably keys on the function
    name (the registering side of ``FnDict`` is not shown here).
    """
    fns = FnDict()

    @fns
    def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
               constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        # Every Mul takes the input first; both adds with constant_add2_y take the constant second.
        mul0 = Mul(input4, constant_mul0_x)
        mul1 = Mul(input3, constant_mul1_sub)
        add0 = Add(mul0, mul1)
        mul2 = Mul(input1, constant_mul2_x)
        mul3 = Mul(input0, constant_mul3_sub1)
        add1 = Add(mul2, mul3)
        real_div1 = RealDiv(add1, input2)
        add2 = Add(real_div1, constant_add2_y)
        sqrt0 = Rsqrt(add2)
        sqrt1 = Sqrt(real_div1)
        add4 = Add(sqrt1, constant_add2_y)
        real_div0 = RealDiv(add0, input5)
        real_div4 = RealDiv(real_div0, add4)
        real_div2 = Mul(sqrt0, real_div0)
        mul4 = Mul(input6, constant_mul4_x)
        add3 = Add(mul4, real_div2)
        outputs = make_tuple(add3, add0, add1, real_div4)
        return tuple_getitem(outputs, 0)

    @fns
    def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
              constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        # A single LambNextMV node whose four outputs stand in for the outputs of `before`.
        lamb_next_mv = LambNextMV(input0, input1, input2, input3, input4, input5, input6,
                                  constant_mul0_x, constant_mul1_sub, constant_mul2_x, constant_mul3_sub1,
                                  constant_mul4_x, constant_add2_y)
        outputs = make_tuple(tuple_getitem(lamb_next_mv, 0), tuple_getitem(lamb_next_mv, 1),
                             tuple_getitem(lamb_next_mv, 2), tuple_getitem(lamb_next_mv, 3))
        return make_tuple(tuple_getitem(outputs, 0))

    @fns
    def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
                 constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        # Same as `before` except for the operand order of add4.
        mul0 = Mul(input4, constant_mul0_x)
        mul1 = Mul(input3, constant_mul1_sub)
        add0 = Add(mul0, mul1)
        mul2 = Mul(input1, constant_mul2_x)
        mul3 = Mul(input0, constant_mul3_sub1)
        add1 = Add(mul2, mul3)
        real_div1 = RealDiv(add1, input2)
        add2 = Add(real_div1, constant_add2_y)
        sqrt0 = Rsqrt(add2)
        sqrt1 = Sqrt(real_div1)
        # un match: operands swapped relative to `before`
        add4 = Add(constant_add2_y, sqrt1)
        real_div0 = RealDiv(add0, input5)
        real_div4 = RealDiv(real_div0, add4)
        real_div2 = Mul(sqrt0, real_div0)
        mul4 = Mul(input6, constant_mul4_x)
        add3 = Add(mul4, real_div2)
        outputs = make_tuple(add3, add0, add1, real_div4)
        return tuple_getitem(outputs, 0)

    return fns[tag]