Merge pull request !441 from YuJianfeng/mastertags/v0.2.0-alpha
| @@ -42,17 +42,69 @@ AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_g | |||||
| const BaseRef AdamApplyOneFusion::DefinePattern() const { | const BaseRef AdamApplyOneFusion::DefinePattern() const { | ||||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | ||||
| const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName); | |||||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||||
| VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); | VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); | ||||
| VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); | VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); | ||||
| VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); | VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); | ||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | ||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | ||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | ||||
| VectorRef true_div0 = VectorRef({prim_deal_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | |||||
| return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); | return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); | ||||
| } | } | ||||
| const BaseRef AdamApplyOneCond1Fusion::DefinePattern() const { | |||||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||||
| VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); | |||||
| VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); | |||||
| VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | |||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | |||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); | |||||
| return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); | |||||
| } | |||||
| const BaseRef AdamApplyOneCond2Fusion::DefinePattern() const { | |||||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||||
| VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); | |||||
| VectorRef mul3 = VectorRef({prim::kPrimMul, VectorRef({prim::kPrimSquare, input_vars_[0]}), mul_x_input_vars_[3]}); | |||||
| VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | |||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | |||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | |||||
| return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | |||||
| } | |||||
| const BaseRef AdamApplyOneCond3Fusion::DefinePattern() const { | |||||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||||
| VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); | |||||
| VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); | |||||
| VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | |||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | |||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | |||||
| return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | |||||
| } | |||||
| const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const { | |||||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||||
| VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); | |||||
| VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); | |||||
| VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | |||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | |||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); | |||||
| return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | |||||
| } | |||||
| const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | ||||
| const EquivPtr &equiv) const { | const EquivPtr &equiv) const { | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| @@ -18,21 +18,23 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | |||||
| #include "pre_activate/common/optimizer.h" | #include "pre_activate/common/optimizer.h" | ||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| constexpr size_t kAdamApplyOneInputNum = 5; | |||||
| constexpr size_t kAdamApplyOneMulInputNum = 4; | |||||
| constexpr size_t kAdamApplyOneInputVarNum = 5; | |||||
| constexpr size_t kAdamApplyOneMulInputVarNum = 4; | |||||
| class AdamApplyOneFusion : public PatternProcessPass { | class AdamApplyOneFusion : public PatternProcessPass { | ||||
| public: | public: | ||||
| explicit AdamApplyOneFusion(bool multigraph = true) : PatternProcessPass("adam_apply_one_fusion", multigraph) { | |||||
| for (size_t i = 0; i < kAdamApplyOneInputNum; ++i) { | |||||
| explicit AdamApplyOneFusion(const std::string &name = "adam_apply_one_fusion", bool multigraph = true) | |||||
| : PatternProcessPass(name, multigraph) { | |||||
| for (size_t i = 0; i < kAdamApplyOneInputVarNum; ++i) { | |||||
| input_vars_.push_back(std::make_shared<Var>()); | input_vars_.push_back(std::make_shared<Var>()); | ||||
| } | } | ||||
| for (size_t i = 0; i < kAdamApplyOneMulInputNum; ++i) { | |||||
| for (size_t i = 0; i < kAdamApplyOneMulInputVarNum; ++i) { | |||||
| mul_x_input_vars_.push_back(std::make_shared<Var>()); | mul_x_input_vars_.push_back(std::make_shared<Var>()); | ||||
| } | } | ||||
| add2_y_ = std::make_shared<Var>(); | add2_y_ = std::make_shared<Var>(); | ||||
| @@ -44,7 +46,7 @@ class AdamApplyOneFusion : public PatternProcessPass { | |||||
| const BaseRef DefinePattern() const override; | const BaseRef DefinePattern() const override; | ||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | ||||
| private: | |||||
| protected: | |||||
| AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const; | AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const; | ||||
| std::vector<VarPtr> input_vars_; | std::vector<VarPtr> input_vars_; | ||||
| std::vector<VarPtr> mul_x_input_vars_; | std::vector<VarPtr> mul_x_input_vars_; | ||||
| @@ -52,6 +54,42 @@ class AdamApplyOneFusion : public PatternProcessPass { | |||||
| VarPtr add0_var_; | VarPtr add0_var_; | ||||
| VarPtr add1_var_; | VarPtr add1_var_; | ||||
| }; | }; | ||||
| class AdamApplyOneCond1Fusion : public AdamApplyOneFusion { | |||||
| public: | |||||
| explicit AdamApplyOneCond1Fusion(bool multigraph = true) | |||||
| : AdamApplyOneFusion("adam_apply_one_cond1_fusion", multigraph) {} | |||||
| ~AdamApplyOneCond1Fusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| }; | |||||
| class AdamApplyOneCond2Fusion : public AdamApplyOneFusion { | |||||
| public: | |||||
| explicit AdamApplyOneCond2Fusion(bool multigraph = true) | |||||
| : AdamApplyOneFusion("adam_apply_one_cond2_fusion", multigraph) {} | |||||
| ~AdamApplyOneCond2Fusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| }; | |||||
| class AdamApplyOneCond3Fusion : public AdamApplyOneFusion { | |||||
| public: | |||||
| explicit AdamApplyOneCond3Fusion(bool multigraph = true) | |||||
| : AdamApplyOneFusion("adam_apply_one_cond3_fusion", multigraph) {} | |||||
| ~AdamApplyOneCond3Fusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| }; | |||||
| class AdamApplyOneCond4Fusion : public AdamApplyOneFusion { | |||||
| public: | |||||
| explicit AdamApplyOneCond4Fusion(bool multigraph = true) | |||||
| : AdamApplyOneFusion("adam_apply_one_cond4_fusion", multigraph) {} | |||||
| ~AdamApplyOneCond4Fusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| }; | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ | #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ | ||||
| @@ -66,5 +66,156 @@ TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_fusion) { | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | ||||
| } | } | ||||
| TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond1_fusion) { | |||||
| /* | |||||
| * def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||||
| * square0 = Square(input0) | |||||
| * mul1 = Mul(mul1_x, input0) | |||||
| * mul0 = Mul(mul0_x, input2) | |||||
| * mul2 = Mul(mul2_x, input1) | |||||
| * mul3 = Mul(mul3_x, square0) | |||||
| * add0 = Add(mul0, mul1) | |||||
| * add1 = Add(mul2, mul3) | |||||
| * sqrt0 = Sqrt(add1) | |||||
| * add2 = Add(add2_y, sqrt0) | |||||
| * true_div0 = RealDiv(add0, add2) | |||||
| * mul4 = Mul(input4, true_div0) | |||||
| * sub0 = Sub(input3, mul4) | |||||
| * outputs = make_tuple(add1, add0, sub0) | |||||
| * output = tuple_getitem(outputs, 0) | |||||
| * return output | |||||
| */ | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond1"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||||
| AbstractBasePtrList args_spec_list; | |||||
| for (size_t i = 0; i < 10; ++i) { | |||||
| args_spec_list.push_back(x_abstract); | |||||
| } | |||||
| auto fg = GetKernelGraph(g, args_spec_list); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| pm->AddPass(std::make_shared<opt::AdamApplyOneCond1Fusion>()); | |||||
| optimizer->AddPassManager(pm); | |||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond2_fusion) { | |||||
| /* | |||||
| * def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||||
| * square0 = Square(input0) | |||||
| * mul1 = Mul(mul1_x, input0) | |||||
| * mul0 = Mul(mul0_x, input2) | |||||
| * mul2 = Mul(mul2_x, input1) | |||||
| * mul3 = Mul(square0, mul3_x) | |||||
| * add0 = Add(mul0, mul1) | |||||
| * add1 = Add(mul2, mul3) | |||||
| * sqrt0 = Sqrt(add1) | |||||
| * add2 = Add(sqrt0, add2_y) | |||||
| * true_div0 = RealDiv(add0, add2) | |||||
| * mul4 = Mul(true_div0, input4) | |||||
| * sub0 = Sub(input3, mul4) | |||||
| * outputs = make_tuple(add1, add0, sub0) | |||||
| * output = tuple_getitem(outputs, 0) | |||||
| * return output | |||||
| */ | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond2"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||||
| AbstractBasePtrList args_spec_list; | |||||
| for (size_t i = 0; i < 10; ++i) { | |||||
| args_spec_list.push_back(x_abstract); | |||||
| } | |||||
| auto fg = GetKernelGraph(g, args_spec_list); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| pm->AddPass(std::make_shared<opt::AdamApplyOneCond2Fusion>()); | |||||
| optimizer->AddPassManager(pm); | |||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond3_fusion) { | |||||
| /* | |||||
| * def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||||
| * square0 = Square(input0) | |||||
| * mul1 = Mul(mul1_x, input0) | |||||
| * mul0 = Mul(mul0_x, input2) | |||||
| * mul2 = Mul(mul2_x, input1) | |||||
| * mul3 = Mul(mul3_x, square0) | |||||
| * add0 = Add(mul0, mul1) | |||||
| * add1 = Add(mul2, mul3) | |||||
| * sqrt0 = Sqrt(add1) | |||||
| * add2 = Add(sqrt0, add2_y) | |||||
| * true_div0 = RealDiv(add0, add2) | |||||
| * mul4 = Mul(true_div0, input4) | |||||
| * sub0 = Sub(input3, mul4) | |||||
| * outputs = make_tuple(add1, add0, sub0) | |||||
| * output = tuple_getitem(outputs, 0) | |||||
| * return output | |||||
| */ | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond3"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||||
| AbstractBasePtrList args_spec_list; | |||||
| for (size_t i = 0; i < 10; ++i) { | |||||
| args_spec_list.push_back(x_abstract); | |||||
| } | |||||
| auto fg = GetKernelGraph(g, args_spec_list); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| pm->AddPass(std::make_shared<opt::AdamApplyOneCond3Fusion>()); | |||||
| optimizer->AddPassManager(pm); | |||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond4_fusion) { | |||||
| /* | |||||
| * def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||||
| * square0 = Square(input0) | |||||
| * mul1 = Mul(mul1_x, input0) | |||||
| * mul0 = Mul(mul0_x, input2) | |||||
| * mul2 = Mul(mul2_x, input1) | |||||
| * mul3 = Mul(mul3_x, square0) | |||||
| * add0 = Add(mul0, mul1) | |||||
| * add1 = Add(mul2, mul3) | |||||
| * sqrt0 = Sqrt(add1) | |||||
| * add2 = Add(add2_y, sqrt0) | |||||
| * true_div0 = RealDiv(add0, add2) | |||||
| * mul4 = Mul(true_div0, input4) | |||||
| * sub0 = Sub(input3, mul4) | |||||
| * outputs = make_tuple(add1, add0, sub0) | |||||
| * output = tuple_getitem(outputs, 0) | |||||
| * return output | |||||
| */ | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond4"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||||
| AbstractBasePtrList args_spec_list; | |||||
| for (size_t i = 0; i < 10; ++i) { | |||||
| args_spec_list.push_back(x_abstract); | |||||
| } | |||||
| auto fg = GetKernelGraph(g, args_spec_list); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| pm->AddPass(std::make_shared<opt::AdamApplyOneCond4Fusion>()); | |||||
| optimizer->AddPassManager(pm); | |||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -58,6 +58,78 @@ def test_adam_apply_one_fusion(tag): | |||||
| output = tuple_getitem(outputs, 0) | output = tuple_getitem(outputs, 0) | ||||
| return output | return output | ||||
| @fns | |||||
| def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||||
| square0 = Square(input0) | |||||
| mul1 = Mul(mul1_x, input0) | |||||
| mul0 = Mul(mul0_x, input2) | |||||
| mul2 = Mul(mul2_x, input1) | |||||
| mul3 = Mul(mul3_x, square0) | |||||
| add0 = Add(mul0, mul1) | |||||
| add1 = Add(mul2, mul3) | |||||
| sqrt0 = Sqrt(add1) | |||||
| add2 = Add(add2_y, sqrt0) | |||||
| true_div0 = RealDiv(add0, add2) | |||||
| mul4 = Mul(input4, true_div0) | |||||
| sub0 = Sub(input3, mul4) | |||||
| outputs = make_tuple(add1, add0, sub0) | |||||
| output = tuple_getitem(outputs, 0) | |||||
| return output | |||||
| @fns | |||||
| def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||||
| square0 = Square(input0) | |||||
| mul1 = Mul(mul1_x, input0) | |||||
| mul0 = Mul(mul0_x, input2) | |||||
| mul2 = Mul(mul2_x, input1) | |||||
| mul3 = Mul(square0, mul3_x) | |||||
| add0 = Add(mul0, mul1) | |||||
| add1 = Add(mul2, mul3) | |||||
| sqrt0 = Sqrt(add1) | |||||
| add2 = Add(sqrt0, add2_y) | |||||
| true_div0 = RealDiv(add0, add2) | |||||
| mul4 = Mul(true_div0, input4) | |||||
| sub0 = Sub(input3, mul4) | |||||
| outputs = make_tuple(add1, add0, sub0) | |||||
| output = tuple_getitem(outputs, 0) | |||||
| return output | |||||
| @fns | |||||
| def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||||
| square0 = Square(input0) | |||||
| mul1 = Mul(mul1_x, input0) | |||||
| mul0 = Mul(mul0_x, input2) | |||||
| mul2 = Mul(mul2_x, input1) | |||||
| mul3 = Mul(mul3_x, square0) | |||||
| add0 = Add(mul0, mul1) | |||||
| add1 = Add(mul2, mul3) | |||||
| sqrt0 = Sqrt(add1) | |||||
| add2 = Add(sqrt0, add2_y) | |||||
| true_div0 = RealDiv(add0, add2) | |||||
| mul4 = Mul(true_div0, input4) | |||||
| sub0 = Sub(input3, mul4) | |||||
| outputs = make_tuple(add1, add0, sub0) | |||||
| output = tuple_getitem(outputs, 0) | |||||
| return output | |||||
| @fns | |||||
| def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||||
| square0 = Square(input0) | |||||
| mul1 = Mul(mul1_x, input0) | |||||
| mul0 = Mul(mul0_x, input2) | |||||
| mul2 = Mul(mul2_x, input1) | |||||
| mul3 = Mul(mul3_x, square0) | |||||
| add0 = Add(mul0, mul1) | |||||
| add1 = Add(mul2, mul3) | |||||
| sqrt0 = Sqrt(add1) | |||||
| add2 = Add(add2_y, sqrt0) | |||||
| true_div0 = RealDiv(add0, add2) | |||||
| mul4 = Mul(true_div0, input4) | |||||
| sub0 = Sub(input3, mul4) | |||||
| outputs = make_tuple(add1, add0, sub0) | |||||
| output = tuple_getitem(outputs, 0) | |||||
| return output | |||||
| @fns | @fns | ||||
| def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | ||||
| adam_apply_one = AdamApplyOne(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y) | adam_apply_one = AdamApplyOne(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y) | ||||