Merge pull request !441 from YuJianfeng/mastertags/v0.2.0-alpha
| @@ -42,17 +42,69 @@ AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_g | |||
| const BaseRef AdamApplyOneFusion::DefinePattern() const { | |||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||
| const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); | |||
| VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); | |||
| VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); | |||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | |||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | |||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||
| VectorRef true_div0 = VectorRef({prim_deal_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | |||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | |||
| return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); | |||
| } | |||
| const BaseRef AdamApplyOneCond1Fusion::DefinePattern() const { | |||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); | |||
| VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); | |||
| VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); | |||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | |||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | |||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); | |||
| return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); | |||
| } | |||
| const BaseRef AdamApplyOneCond2Fusion::DefinePattern() const { | |||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); | |||
| VectorRef mul3 = VectorRef({prim::kPrimMul, VectorRef({prim::kPrimSquare, input_vars_[0]}), mul_x_input_vars_[3]}); | |||
| VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); | |||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | |||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | |||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | |||
| return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | |||
| } | |||
| const BaseRef AdamApplyOneCond3Fusion::DefinePattern() const { | |||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); | |||
| VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); | |||
| VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); | |||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | |||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | |||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | |||
| return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | |||
| } | |||
| const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const { | |||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||
| VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); | |||
| VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); | |||
| VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); | |||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | |||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | |||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); | |||
| return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | |||
| } | |||
| const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -18,21 +18,23 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "pre_activate/common/optimizer.h" | |||
| #include "utils/utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| constexpr size_t kAdamApplyOneInputNum = 5; | |||
| constexpr size_t kAdamApplyOneMulInputNum = 4; | |||
| constexpr size_t kAdamApplyOneInputVarNum = 5; | |||
| constexpr size_t kAdamApplyOneMulInputVarNum = 4; | |||
| class AdamApplyOneFusion : public PatternProcessPass { | |||
| public: | |||
| explicit AdamApplyOneFusion(bool multigraph = true) : PatternProcessPass("adam_apply_one_fusion", multigraph) { | |||
| for (size_t i = 0; i < kAdamApplyOneInputNum; ++i) { | |||
| explicit AdamApplyOneFusion(const std::string &name = "adam_apply_one_fusion", bool multigraph = true) | |||
| : PatternProcessPass(name, multigraph) { | |||
| for (size_t i = 0; i < kAdamApplyOneInputVarNum; ++i) { | |||
| input_vars_.push_back(std::make_shared<Var>()); | |||
| } | |||
| for (size_t i = 0; i < kAdamApplyOneMulInputNum; ++i) { | |||
| for (size_t i = 0; i < kAdamApplyOneMulInputVarNum; ++i) { | |||
| mul_x_input_vars_.push_back(std::make_shared<Var>()); | |||
| } | |||
| add2_y_ = std::make_shared<Var>(); | |||
| @@ -44,7 +46,7 @@ class AdamApplyOneFusion : public PatternProcessPass { | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| protected: | |||
| AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const; | |||
| std::vector<VarPtr> input_vars_; | |||
| std::vector<VarPtr> mul_x_input_vars_; | |||
| @@ -52,6 +54,42 @@ class AdamApplyOneFusion : public PatternProcessPass { | |||
| VarPtr add0_var_; | |||
| VarPtr add1_var_; | |||
| }; | |||
| class AdamApplyOneCond1Fusion : public AdamApplyOneFusion { | |||
| public: | |||
| explicit AdamApplyOneCond1Fusion(bool multigraph = true) | |||
| : AdamApplyOneFusion("adam_apply_one_cond1_fusion", multigraph) {} | |||
| ~AdamApplyOneCond1Fusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| }; | |||
| class AdamApplyOneCond2Fusion : public AdamApplyOneFusion { | |||
| public: | |||
| explicit AdamApplyOneCond2Fusion(bool multigraph = true) | |||
| : AdamApplyOneFusion("adam_apply_one_cond2_fusion", multigraph) {} | |||
| ~AdamApplyOneCond2Fusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| }; | |||
| class AdamApplyOneCond3Fusion : public AdamApplyOneFusion { | |||
| public: | |||
| explicit AdamApplyOneCond3Fusion(bool multigraph = true) | |||
| : AdamApplyOneFusion("adam_apply_one_cond3_fusion", multigraph) {} | |||
| ~AdamApplyOneCond3Fusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| }; | |||
| class AdamApplyOneCond4Fusion : public AdamApplyOneFusion { | |||
| public: | |||
| explicit AdamApplyOneCond4Fusion(bool multigraph = true) | |||
| : AdamApplyOneFusion("adam_apply_one_cond4_fusion", multigraph) {} | |||
| ~AdamApplyOneCond4Fusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ | |||
| @@ -66,5 +66,156 @@ TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_fusion) { | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond1_fusion) { | |||
| /* | |||
| * def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||
| * square0 = Square(input0) | |||
| * mul1 = Mul(mul1_x, input0) | |||
| * mul0 = Mul(mul0_x, input2) | |||
| * mul2 = Mul(mul2_x, input1) | |||
| * mul3 = Mul(mul3_x, square0) | |||
| * add0 = Add(mul0, mul1) | |||
| * add1 = Add(mul2, mul3) | |||
| * sqrt0 = Sqrt(add1) | |||
| * add2 = Add(add2_y, sqrt0) | |||
| * true_div0 = RealDiv(add0, add2) | |||
| * mul4 = Mul(input4, true_div0) | |||
| * sub0 = Sub(input3, mul4) | |||
| * outputs = make_tuple(add1, add0, sub0) | |||
| * output = tuple_getitem(outputs, 0) | |||
| * return output | |||
| */ | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond1"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 10; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::AdamApplyOneCond1Fusion>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond2_fusion) { | |||
| /* | |||
| * def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||
| * square0 = Square(input0) | |||
| * mul1 = Mul(mul1_x, input0) | |||
| * mul0 = Mul(mul0_x, input2) | |||
| * mul2 = Mul(mul2_x, input1) | |||
| * mul3 = Mul(square0, mul3_x) | |||
| * add0 = Add(mul0, mul1) | |||
| * add1 = Add(mul2, mul3) | |||
| * sqrt0 = Sqrt(add1) | |||
| * add2 = Add(sqrt0, add2_y) | |||
| * true_div0 = RealDiv(add0, add2) | |||
| * mul4 = Mul(true_div0, input4) | |||
| * sub0 = Sub(input3, mul4) | |||
| * outputs = make_tuple(add1, add0, sub0) | |||
| * output = tuple_getitem(outputs, 0) | |||
| * return output | |||
| */ | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond2"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 10; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::AdamApplyOneCond2Fusion>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond3_fusion) { | |||
| /* | |||
| * def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||
| * square0 = Square(input0) | |||
| * mul1 = Mul(mul1_x, input0) | |||
| * mul0 = Mul(mul0_x, input2) | |||
| * mul2 = Mul(mul2_x, input1) | |||
| * mul3 = Mul(mul3_x, square0) | |||
| * add0 = Add(mul0, mul1) | |||
| * add1 = Add(mul2, mul3) | |||
| * sqrt0 = Sqrt(add1) | |||
| * add2 = Add(sqrt0, add2_y) | |||
| * true_div0 = RealDiv(add0, add2) | |||
| * mul4 = Mul(true_div0, input4) | |||
| * sub0 = Sub(input3, mul4) | |||
| * outputs = make_tuple(add1, add0, sub0) | |||
| * output = tuple_getitem(outputs, 0) | |||
| * return output | |||
| */ | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond3"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 10; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::AdamApplyOneCond3Fusion>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond4_fusion) { | |||
| /* | |||
| * def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||
| * square0 = Square(input0) | |||
| * mul1 = Mul(mul1_x, input0) | |||
| * mul0 = Mul(mul0_x, input2) | |||
| * mul2 = Mul(mul2_x, input1) | |||
| * mul3 = Mul(mul3_x, square0) | |||
| * add0 = Add(mul0, mul1) | |||
| * add1 = Add(mul2, mul3) | |||
| * sqrt0 = Sqrt(add1) | |||
| * add2 = Add(add2_y, sqrt0) | |||
| * true_div0 = RealDiv(add0, add2) | |||
| * mul4 = Mul(true_div0, input4) | |||
| * sub0 = Sub(input3, mul4) | |||
| * outputs = make_tuple(add1, add0, sub0) | |||
| * output = tuple_getitem(outputs, 0) | |||
| * return output | |||
| */ | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond4"); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 10; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::AdamApplyOneCond4Fusion>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -58,6 +58,78 @@ def test_adam_apply_one_fusion(tag): | |||
| output = tuple_getitem(outputs, 0) | |||
| return output | |||
| @fns | |||
| def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||
| square0 = Square(input0) | |||
| mul1 = Mul(mul1_x, input0) | |||
| mul0 = Mul(mul0_x, input2) | |||
| mul2 = Mul(mul2_x, input1) | |||
| mul3 = Mul(mul3_x, square0) | |||
| add0 = Add(mul0, mul1) | |||
| add1 = Add(mul2, mul3) | |||
| sqrt0 = Sqrt(add1) | |||
| add2 = Add(add2_y, sqrt0) | |||
| true_div0 = RealDiv(add0, add2) | |||
| mul4 = Mul(input4, true_div0) | |||
| sub0 = Sub(input3, mul4) | |||
| outputs = make_tuple(add1, add0, sub0) | |||
| output = tuple_getitem(outputs, 0) | |||
| return output | |||
| @fns | |||
| def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||
| square0 = Square(input0) | |||
| mul1 = Mul(mul1_x, input0) | |||
| mul0 = Mul(mul0_x, input2) | |||
| mul2 = Mul(mul2_x, input1) | |||
| mul3 = Mul(square0, mul3_x) | |||
| add0 = Add(mul0, mul1) | |||
| add1 = Add(mul2, mul3) | |||
| sqrt0 = Sqrt(add1) | |||
| add2 = Add(sqrt0, add2_y) | |||
| true_div0 = RealDiv(add0, add2) | |||
| mul4 = Mul(true_div0, input4) | |||
| sub0 = Sub(input3, mul4) | |||
| outputs = make_tuple(add1, add0, sub0) | |||
| output = tuple_getitem(outputs, 0) | |||
| return output | |||
| @fns | |||
| def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||
| square0 = Square(input0) | |||
| mul1 = Mul(mul1_x, input0) | |||
| mul0 = Mul(mul0_x, input2) | |||
| mul2 = Mul(mul2_x, input1) | |||
| mul3 = Mul(mul3_x, square0) | |||
| add0 = Add(mul0, mul1) | |||
| add1 = Add(mul2, mul3) | |||
| sqrt0 = Sqrt(add1) | |||
| add2 = Add(sqrt0, add2_y) | |||
| true_div0 = RealDiv(add0, add2) | |||
| mul4 = Mul(true_div0, input4) | |||
| sub0 = Sub(input3, mul4) | |||
| outputs = make_tuple(add1, add0, sub0) | |||
| output = tuple_getitem(outputs, 0) | |||
| return output | |||
| @fns | |||
| def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||
| square0 = Square(input0) | |||
| mul1 = Mul(mul1_x, input0) | |||
| mul0 = Mul(mul0_x, input2) | |||
| mul2 = Mul(mul2_x, input1) | |||
| mul3 = Mul(mul3_x, square0) | |||
| add0 = Add(mul0, mul1) | |||
| add1 = Add(mul2, mul3) | |||
| sqrt0 = Sqrt(add1) | |||
| add2 = Add(add2_y, sqrt0) | |||
| true_div0 = RealDiv(add0, add2) | |||
| mul4 = Mul(true_div0, input4) | |||
| sub0 = Sub(input3, mul4) | |||
| outputs = make_tuple(add1, add0, sub0) | |||
| output = tuple_getitem(outputs, 0) | |||
| return output | |||
| @fns | |||
| def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||
| adam_apply_one = AdamApplyOne(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y) | |||