| @@ -99,6 +99,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { | |||
| ir_fusion_pm->AddPass(std::make_shared<SquareSumFusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond1>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond2>()); | |||
| @@ -114,11 +115,15 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { | |||
| ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<ClipByValueFusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<TopKSplit>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRule>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond1Fusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond2Fusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond3Fusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond4Fusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond1>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond2>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond3>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond4>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond5>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<MomentumLossscaleFusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<MulAddFusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>()); | |||
| @@ -41,24 +41,104 @@ std::vector<AnfNodePtr> AdamApplyOneWithDecayRule::GetFusionNodeInputs(const Equ | |||
| return {NewValueNode(prim), input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y}; | |||
| } | |||
| const BaseRef AdamApplyOneWithDecayRule::DefinePattern() const { | |||
| const BaseRef AdamApplyOneWithDecayRuleCond1::DefinePattern() const { | |||
| auto sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||
| auto real_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| VectorRef mul0_pattern({prim::kPrimMul, mul0_x_, input2_}); | |||
| VectorRef mul1_pattern({prim::kPrimMul, mul1_x_, input0_}); | |||
| VectorRef square0_pattern({prim::kPrimSquare, input0_}); | |||
| VectorRef add0_pattern({add0_var_, mul0_pattern, mul1_pattern}); | |||
| VectorRef mul2_pattern({prim::kPrimMul, mul2_x_, input1_}); | |||
| VectorRef mul3_pattern({prim::kPrimMul, mul3_x_, square0_pattern}); | |||
| VectorRef add1_pattern({add1_var_, mul2_pattern, mul3_pattern}); | |||
| VectorRef sqrt0_pattern({sqrt, add1_pattern}); | |||
| VectorRef add2_pattern({prim::kPrimTensorAdd, sqrt0_pattern, add2_y_}); | |||
| VectorRef mul4_pattern({prim::kPrimMul, mul4_x_, input3_}); | |||
| VectorRef real_div_pattern({real_div, add0_pattern, add2_pattern}); | |||
| VectorRef add3_pattern({prim::kPrimTensorAdd, real_div_pattern, mul4_pattern}); | |||
| VectorRef mul5_pattern({prim::kPrimMul, input4_, add3_pattern}); | |||
| VectorRef sub0_pattern({prim::kPrimSub, input3_, mul5_pattern}); | |||
| return sub0_pattern; | |||
| VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); | |||
| VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); | |||
| VectorRef square0({prim::kPrimSquare, input0_}); | |||
| VectorRef add0({add0_var_, mul0, mul1}); | |||
| VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); | |||
| VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | |||
| VectorRef add1({add1_var_, mul2, mul3}); | |||
| VectorRef sqrt0({sqrt, add1}); | |||
| VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); | |||
| VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | |||
| VectorRef real_div0({real_div, add0, add2}); | |||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||
| VectorRef mul5({prim::kPrimMul, input4_, add3}); | |||
| VectorRef sub0({prim::kPrimSub, input3_, mul5}); | |||
| return sub0; | |||
| } | |||
| const BaseRef AdamApplyOneWithDecayRuleCond2::DefinePattern() const { | |||
| auto sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||
| auto real_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| VectorRef mul0({prim::kPrimMul, input2_, mul0_x_}); | |||
| VectorRef mul1({prim::kPrimMul, input0_, mul1_x_}); | |||
| VectorRef square0({prim::kPrimSquare, input0_}); | |||
| VectorRef add0({add0_var_, mul0, mul1}); | |||
| VectorRef mul2({prim::kPrimMul, input1_, mul2_x_}); | |||
| VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | |||
| VectorRef add1({add1_var_, mul2, mul3}); | |||
| VectorRef sqrt0({sqrt, add1}); | |||
| VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); | |||
| VectorRef mul4({prim::kPrimMul, input3_, mul4_x_}); | |||
| VectorRef real_div0({real_div, add0, add2}); | |||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||
| VectorRef mul5({prim::kPrimMul, add3, input4_}); | |||
| VectorRef sub0({prim::kPrimSub, input3_, mul5}); | |||
| return sub0; | |||
| } | |||
| const BaseRef AdamApplyOneWithDecayRuleCond3::DefinePattern() const { | |||
| auto sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||
| auto real_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); | |||
| VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); | |||
| VectorRef square0({prim::kPrimSquare, input0_}); | |||
| VectorRef add0({add0_var_, mul0, mul1}); | |||
| VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); | |||
| VectorRef mul3({prim::kPrimMul, square0, mul3_x_}); | |||
| VectorRef add1({add1_var_, mul2, mul3}); | |||
| VectorRef sqrt0({sqrt, add1}); | |||
| VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); | |||
| VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | |||
| VectorRef real_div0({real_div, add0, add2}); | |||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||
| VectorRef mul5({prim::kPrimMul, add3, input4_}); | |||
| VectorRef sub0({prim::kPrimSub, input3_, mul5}); | |||
| return sub0; | |||
| } | |||
| const BaseRef AdamApplyOneWithDecayRuleCond4::DefinePattern() const { | |||
| auto sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||
| auto real_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); | |||
| VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); | |||
| VectorRef square0({prim::kPrimSquare, input0_}); | |||
| VectorRef add0({add0_var_, mul0, mul1}); | |||
| VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); | |||
| VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | |||
| VectorRef add1({add1_var_, mul2, mul3}); | |||
| VectorRef sqrt0({sqrt, add1}); | |||
| VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); | |||
| VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | |||
| VectorRef real_div0({real_div, add0, add2}); | |||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||
| VectorRef mul5({prim::kPrimMul, add3, input4_}); | |||
| VectorRef sub0({prim::kPrimSub, input3_, mul5}); | |||
| return sub0; | |||
| } | |||
| const BaseRef AdamApplyOneWithDecayRuleCond5::DefinePattern() const { | |||
| auto sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||
| auto real_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); | |||
| VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); | |||
| VectorRef square0({prim::kPrimSquare, input0_}); | |||
| VectorRef add0({add0_var_, mul0, mul1}); | |||
| VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); | |||
| VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | |||
| VectorRef add1({add1_var_, mul2, mul3}); | |||
| VectorRef sqrt0({sqrt, add1}); | |||
| VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); | |||
| VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | |||
| VectorRef real_div0({real_div, add0, add2}); | |||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||
| VectorRef mul5({prim::kPrimMul, add3, input4_}); | |||
| VectorRef sub0({prim::kPrimSub, input3_, mul5}); | |||
| return sub0; | |||
| } | |||
| const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| @@ -18,14 +18,15 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "pre_activate/common/optimizer.h" | |||
| #include "utils/utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class AdamApplyOneWithDecayRule : public PatternProcessPass { | |||
| public: | |||
| explicit AdamApplyOneWithDecayRule(bool multigraph = true) | |||
| : PatternProcessPass("adam_apply_one_with_decay_rule", multigraph) { | |||
| explicit AdamApplyOneWithDecayRule(const std::string &name = "adam_apply_one_with_decay_rule", bool multigraph = true) | |||
| : PatternProcessPass(name, multigraph) { | |||
| input0_ = std::make_shared<Var>(); | |||
| input1_ = std::make_shared<Var>(); | |||
| input2_ = std::make_shared<Var>(); | |||
| @@ -41,10 +42,10 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass { | |||
| add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | |||
| } | |||
| ~AdamApplyOneWithDecayRule() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const BaseRef DefinePattern() const override = 0; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| protected: | |||
| std::vector<AnfNodePtr> GetFusionNodeInputs(const EquivPtr &equiv) const; | |||
| VarPtr input0_; | |||
| VarPtr input1_; | |||
| @@ -60,6 +61,51 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass { | |||
| VarPtr add0_var_; | |||
| VarPtr add1_var_; | |||
| }; | |||
| class AdamApplyOneWithDecayRuleCond1 : public AdamApplyOneWithDecayRule { | |||
| public: | |||
| explicit AdamApplyOneWithDecayRuleCond1(bool multigraph = true) | |||
| : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond1", multigraph) {} | |||
| ~AdamApplyOneWithDecayRuleCond1() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| }; | |||
| class AdamApplyOneWithDecayRuleCond2 : public AdamApplyOneWithDecayRule { | |||
| public: | |||
| explicit AdamApplyOneWithDecayRuleCond2(bool multigraph = true) | |||
| : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond2", multigraph) {} | |||
| ~AdamApplyOneWithDecayRuleCond2() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| }; | |||
| class AdamApplyOneWithDecayRuleCond3 : public AdamApplyOneWithDecayRule { | |||
| public: | |||
| explicit AdamApplyOneWithDecayRuleCond3(bool multigraph = true) | |||
| : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond3", multigraph) {} | |||
| ~AdamApplyOneWithDecayRuleCond3() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| }; | |||
| class AdamApplyOneWithDecayRuleCond4 : public AdamApplyOneWithDecayRule { | |||
| public: | |||
| explicit AdamApplyOneWithDecayRuleCond4(bool multigraph = true) | |||
| : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond4", multigraph) {} | |||
| ~AdamApplyOneWithDecayRuleCond4() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| }; | |||
| class AdamApplyOneWithDecayRuleCond5 : public AdamApplyOneWithDecayRule { | |||
| public: | |||
| explicit AdamApplyOneWithDecayRuleCond5(bool multigraph = true) | |||
| : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond5", multigraph) {} | |||
| ~AdamApplyOneWithDecayRuleCond5() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ | |||
| @@ -30,8 +30,8 @@ class TestHWOptimizeAdamApplyOneWithDecayRule : public BackendCommon { | |||
| UT::PyFuncGraphFetcher get_py_fun_; | |||
| }; | |||
| TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "before"); | |||
| TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond1) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond1", "before"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| @@ -43,16 +43,16 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_r | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayRule>()); | |||
| pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayRuleCond1>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "after"); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond1", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_no_match) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "no_match"); | |||
| TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond2) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond2", "before"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| @@ -61,15 +61,78 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_no_match) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto origin_graph = std::make_shared<session::KernelGraph>(*fg); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayRule>()); | |||
| pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayRuleCond2>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph)); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond2", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond3) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond3", "before"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 11; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayRuleCond3>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond3", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond4) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond4", "before"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 11; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayRuleCond4>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond4", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond5) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond5", "before"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 11; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayRuleCond5>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond5", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -89,3 +89,168 @@ def test_adam_apply_one_with_decay_rule(tag): | |||
| return make_tuple(make_tuple(item0, item1, item2)) | |||
| return fns[tag] | |||
| def test_adam_apply_one_with_decay_rule_cond1(tag): | |||
| fns = FnDict() | |||
| @fns | |||
| def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||
| mul0 = mul(mul0_x, input2) | |||
| mul1 = mul(mul1_x, input0) | |||
| square0 = square(input0) | |||
| add0 = add(mul0, mul1) | |||
| mul2 = mul(mul2_x, input1) | |||
| mul3 = mul(mul3_x, square0) | |||
| add1 = add(mul2, mul3) | |||
| sqrt0 = sqrt(add1) | |||
| add2 = add(add2_y, sqrt0) | |||
| mul4 = mul(mul4_x, input3) | |||
| real_div0 = real_div(add0, add2) | |||
| add3 = add(mul4, real_div0) | |||
| mul5 = mul(input4, add3) | |||
| sub0 = sub(input3, mul5) | |||
| return make_tuple(add1, add0, sub0) | |||
| @fns | |||
| def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||
| res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, | |||
| add2_y) | |||
| item0 = tuple_getitem(res, 0) | |||
| item1 = tuple_getitem(res, 1) | |||
| item2 = tuple_getitem(res, 2) | |||
| return make_tuple(make_tuple(item0, item1, item2)) | |||
| return fns[tag] | |||
| def test_adam_apply_one_with_decay_rule_cond2(tag): | |||
| fns = FnDict() | |||
| @fns | |||
| def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||
| mul0 = mul(input2, mul0_x) | |||
| mul1 = mul(input0, mul1_x) | |||
| square0 = square(input0) | |||
| add0 = add(mul0, mul1) | |||
| mul2 = mul(input1, mul2_x) | |||
| mul3 = mul(mul3_x, square0) | |||
| add1 = add(mul2, mul3) | |||
| sqrt0 = sqrt(add1) | |||
| add2 = add(sqrt0, add2_y) | |||
| mul4 = mul(input3, mul4_x) | |||
| real_div0 = real_div(add0, add2) | |||
| add3 = add(mul4, real_div0) | |||
| mul5 = mul(add3, input4) | |||
| sub0 = sub(input3, mul5) | |||
| return make_tuple(add1, add0, sub0) | |||
| @fns | |||
| def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||
| res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, | |||
| add2_y) | |||
| item0 = tuple_getitem(res, 0) | |||
| item1 = tuple_getitem(res, 1) | |||
| item2 = tuple_getitem(res, 2) | |||
| return make_tuple(make_tuple(item0, item1, item2)) | |||
| return fns[tag] | |||
| def test_adam_apply_one_with_decay_rule_cond3(tag): | |||
| fns = FnDict() | |||
| @fns | |||
| def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||
| mul0 = mul(mul0_x, input2) | |||
| mul1 = mul(mul1_x, input0) | |||
| square0 = square(input0) | |||
| add0 = add(mul0, mul1) | |||
| mul2 = mul(mul2_x, input1) | |||
| mul3 = mul(square0, mul3_x) | |||
| add1 = add(mul2, mul3) | |||
| sqrt0 = sqrt(add1) | |||
| add2 = add(sqrt0, add2_y) | |||
| mul4 = mul(mul4_x, input3) | |||
| real_div0 = real_div(add0, add2) | |||
| add3 = add(mul4, real_div0) | |||
| mul5 = mul(add3, input4) | |||
| sub0 = sub(input3, mul5) | |||
| return make_tuple(add1, add0, sub0) | |||
| @fns | |||
| def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||
| res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, | |||
| add2_y) | |||
| item0 = tuple_getitem(res, 0) | |||
| item1 = tuple_getitem(res, 1) | |||
| item2 = tuple_getitem(res, 2) | |||
| return make_tuple(make_tuple(item0, item1, item2)) | |||
| return fns[tag] | |||
| def test_adam_apply_one_with_decay_rule_cond4(tag): | |||
| fns = FnDict() | |||
| @fns | |||
| def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||
| mul0 = mul(mul0_x, input2) | |||
| mul1 = mul(mul1_x, input0) | |||
| square0 = square(input0) | |||
| add0 = add(mul0, mul1) | |||
| mul2 = mul(mul2_x, input1) | |||
| mul3 = mul(mul3_x, square0) | |||
| add1 = add(mul2, mul3) | |||
| sqrt0 = sqrt(add1) | |||
| add2 = add(add2_y, sqrt0) | |||
| mul4 = mul(mul4_x, input3) | |||
| real_div0 = real_div(add0, add2) | |||
| add3 = add(mul4, real_div0) | |||
| mul5 = mul(add3, input4) | |||
| sub0 = sub(input3, mul5) | |||
| return make_tuple(add1, add0, sub0) | |||
| @fns | |||
| def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||
| res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, | |||
| add2_y) | |||
| item0 = tuple_getitem(res, 0) | |||
| item1 = tuple_getitem(res, 1) | |||
| item2 = tuple_getitem(res, 2) | |||
| return make_tuple(make_tuple(item0, item1, item2)) | |||
| return fns[tag] | |||
| def test_adam_apply_one_with_decay_rule_cond5(tag): | |||
| fns = FnDict() | |||
| @fns | |||
| def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||
| mul0 = mul(mul0_x, input2) | |||
| mul1 = mul(mul1_x, input0) | |||
| square0 = square(input0) | |||
| add0 = add(mul0, mul1) | |||
| mul2 = mul(mul2_x, input1) | |||
| mul3 = mul(mul3_x, square0) | |||
| add1 = add(mul2, mul3) | |||
| sqrt0 = sqrt(add1) | |||
| add2 = add(sqrt0, add2_y) | |||
| mul4 = mul(mul4_x, input3) | |||
| real_div0 = real_div(add0, add2) | |||
| add3 = add(mul4, real_div0) | |||
| mul5 = mul(add3, input4) | |||
| sub0 = sub(input3, mul5) | |||
| return make_tuple(add1, add0, sub0) | |||
| @fns | |||
| def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||
| res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, | |||
| add2_y) | |||
| item0 = tuple_getitem(res, 0) | |||
| item1 = tuple_getitem(res, 1) | |||
| item2 = tuple_getitem(res, 2) | |||
| return make_tuple(make_tuple(item0, item1, item2)) | |||
| return fns[tag] | |||