Merge pull request !5508 from YuJianfeng/adam_assigntags/v1.0.0
| @@ -24,7 +24,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| std::vector<AnfNodePtr> AdamApplyOneWithDecayRule::GetFusionNodeInputs(const EquivPtr &equiv) const { | |||||
| std::vector<AnfNodePtr> AdamApplyOneWithDecayRule::GetFusionNodeInputs(const EquivPtr &equiv, | |||||
| const AnfNodePtr &final_node) const { | |||||
| MS_EXCEPTION_IF_NULL(equiv); | MS_EXCEPTION_IF_NULL(equiv); | ||||
| auto input0 = utils::cast<AnfNodePtr>((*equiv)[input0_]); | auto input0 = utils::cast<AnfNodePtr>((*equiv)[input0_]); | ||||
| auto input1 = utils::cast<AnfNodePtr>((*equiv)[input1_]); | auto input1 = utils::cast<AnfNodePtr>((*equiv)[input1_]); | ||||
| @@ -37,7 +38,12 @@ std::vector<AnfNodePtr> AdamApplyOneWithDecayRule::GetFusionNodeInputs(const Equ | |||||
| auto mul3_x = utils::cast<AnfNodePtr>((*equiv)[mul3_x_]); | auto mul3_x = utils::cast<AnfNodePtr>((*equiv)[mul3_x_]); | ||||
| auto mul4_x = utils::cast<AnfNodePtr>((*equiv)[mul4_x_]); | auto mul4_x = utils::cast<AnfNodePtr>((*equiv)[mul4_x_]); | ||||
| auto add2_y = utils::cast<AnfNodePtr>((*equiv)[add2_y_]); | auto add2_y = utils::cast<AnfNodePtr>((*equiv)[add2_y_]); | ||||
| auto prim = std::make_shared<Primitive>(kAdamApplyOneWithDecayOpName); | |||||
| PrimitivePtr prim = nullptr; | |||||
| if (AnfAlgo::CheckPrimitiveType(final_node, prim::kPrimDepend)) { | |||||
| prim = std::make_shared<Primitive>(kAdamApplyOneWithDecayAssignOpName); | |||||
| } else { | |||||
| prim = std::make_shared<Primitive>(kAdamApplyOneWithDecayOpName); | |||||
| } | |||||
| return {NewValueNode(prim), input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y}; | return {NewValueNode(prim), input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y}; | ||||
| } | } | ||||
| @@ -141,18 +147,152 @@ const BaseRef AdamApplyOneWithDecayRuleCond5::DefinePattern() const { | |||||
| return sub0; | return sub0; | ||||
| } | } | ||||
| const BaseRef AdamApplyOneWithDecayAssignRuleCond1::DefinePattern() const { | |||||
| auto sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||||
| auto real_div = std::make_shared<Primitive>(kRealDivOpName); | |||||
| VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); | |||||
| VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); | |||||
| VectorRef square0({prim::kPrimSquare, input0_}); | |||||
| VectorRef add0({add0_var_, mul0, mul1}); | |||||
| VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); | |||||
| VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | |||||
| VectorRef add1({add1_var_, mul2, mul3}); | |||||
| VectorRef sqrt0({sqrt, add1}); | |||||
| VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); | |||||
| VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | |||||
| VectorRef real_div0({real_div, add0, add2}); | |||||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||||
| VectorRef mul5({prim::kPrimMul, input4_, add3}); | |||||
| VectorRef sub0({sub0_var_, input3_, mul5}); | |||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); | |||||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | |||||
| VectorRef assign1 = VectorRef({prim::kPrimAssign, input2_, add0}); | |||||
| VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); | |||||
| VectorRef assign2 = VectorRef({prim::kPrimAssign, input1_, add1}); | |||||
| return VectorRef({prim::kPrimDepend, depend1, assign2}); | |||||
| } | |||||
| const BaseRef AdamApplyOneWithDecayAssignRuleCond2::DefinePattern() const { | |||||
| auto sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||||
| auto real_div = std::make_shared<Primitive>(kRealDivOpName); | |||||
| VectorRef mul0({prim::kPrimMul, input2_, mul0_x_}); | |||||
| VectorRef mul1({prim::kPrimMul, input0_, mul1_x_}); | |||||
| VectorRef square0({prim::kPrimSquare, input0_}); | |||||
| VectorRef add0({add0_var_, mul0, mul1}); | |||||
| VectorRef mul2({prim::kPrimMul, input1_, mul2_x_}); | |||||
| VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | |||||
| VectorRef add1({add1_var_, mul2, mul3}); | |||||
| VectorRef sqrt0({sqrt, add1}); | |||||
| VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); | |||||
| VectorRef mul4({prim::kPrimMul, input3_, mul4_x_}); | |||||
| VectorRef real_div0({real_div, add0, add2}); | |||||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||||
| VectorRef mul5({prim::kPrimMul, add3, input4_}); | |||||
| VectorRef sub0({sub0_var_, input3_, mul5}); | |||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); | |||||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | |||||
| VectorRef assign1 = VectorRef({prim::kPrimAssign, input2_, add0}); | |||||
| VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); | |||||
| VectorRef assign2 = VectorRef({prim::kPrimAssign, input1_, add1}); | |||||
| return VectorRef({prim::kPrimDepend, depend1, assign2}); | |||||
| } | |||||
| const BaseRef AdamApplyOneWithDecayAssignRuleCond3::DefinePattern() const { | |||||
| auto sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||||
| auto real_div = std::make_shared<Primitive>(kRealDivOpName); | |||||
| VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); | |||||
| VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); | |||||
| VectorRef square0({prim::kPrimSquare, input0_}); | |||||
| VectorRef add0({add0_var_, mul0, mul1}); | |||||
| VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); | |||||
| VectorRef mul3({prim::kPrimMul, square0, mul3_x_}); | |||||
| VectorRef add1({add1_var_, mul2, mul3}); | |||||
| VectorRef sqrt0({sqrt, add1}); | |||||
| VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); | |||||
| VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | |||||
| VectorRef real_div0({real_div, add0, add2}); | |||||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||||
| VectorRef mul5({prim::kPrimMul, add3, input4_}); | |||||
| VectorRef sub0({sub0_var_, input3_, mul5}); | |||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); | |||||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | |||||
| VectorRef assign1 = VectorRef({prim::kPrimAssign, input2_, add0}); | |||||
| VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); | |||||
| VectorRef assign2 = VectorRef({prim::kPrimAssign, input1_, add1}); | |||||
| return VectorRef({prim::kPrimDepend, depend1, assign2}); | |||||
| } | |||||
| const BaseRef AdamApplyOneWithDecayAssignRuleCond4::DefinePattern() const { | |||||
| auto sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||||
| auto real_div = std::make_shared<Primitive>(kRealDivOpName); | |||||
| VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); | |||||
| VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); | |||||
| VectorRef square0({prim::kPrimSquare, input0_}); | |||||
| VectorRef add0({add0_var_, mul0, mul1}); | |||||
| VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); | |||||
| VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | |||||
| VectorRef add1({add1_var_, mul2, mul3}); | |||||
| VectorRef sqrt0({sqrt, add1}); | |||||
| VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); | |||||
| VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | |||||
| VectorRef real_div0({real_div, add0, add2}); | |||||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||||
| VectorRef mul5({prim::kPrimMul, add3, input4_}); | |||||
| VectorRef sub0({sub0_var_, input3_, mul5}); | |||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); | |||||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | |||||
| VectorRef assign1 = VectorRef({prim::kPrimAssign, input2_, add0}); | |||||
| VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); | |||||
| VectorRef assign2 = VectorRef({prim::kPrimAssign, input1_, add1}); | |||||
| return VectorRef({prim::kPrimDepend, depend1, assign2}); | |||||
| } | |||||
| const BaseRef AdamApplyOneWithDecayAssignRuleCond5::DefinePattern() const { | |||||
| auto sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||||
| auto real_div = std::make_shared<Primitive>(kRealDivOpName); | |||||
| VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); | |||||
| VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); | |||||
| VectorRef square0({prim::kPrimSquare, input0_}); | |||||
| VectorRef add0({add0_var_, mul0, mul1}); | |||||
| VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); | |||||
| VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); | |||||
| VectorRef add1({add1_var_, mul2, mul3}); | |||||
| VectorRef sqrt0({sqrt, add1}); | |||||
| VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); | |||||
| VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); | |||||
| VectorRef real_div0({real_div, add0, add2}); | |||||
| VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); | |||||
| VectorRef mul5({prim::kPrimMul, add3, input4_}); | |||||
| VectorRef sub0({sub0_var_, input3_, mul5}); | |||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0}); | |||||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | |||||
| VectorRef assign1 = VectorRef({prim::kPrimAssign, input2_, add0}); | |||||
| VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); | |||||
| VectorRef assign2 = VectorRef({prim::kPrimAssign, input1_, add1}); | |||||
| return VectorRef({prim::kPrimDepend, depend1, assign2}); | |||||
| } | |||||
| const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | ||||
| const EquivPtr &equiv) const { | const EquivPtr &equiv) const { | ||||
| if (graph == nullptr || node == nullptr || equiv == nullptr) { | if (graph == nullptr || node == nullptr || equiv == nullptr) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (!CheckSupportDataType(node, kFloatDataTypeSet)) { | |||||
| auto sub0 = node; | |||||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) { | |||||
| auto iter_sub0 = (*equiv).find(sub0_var_); | |||||
| if (iter_sub0 == (*equiv).end()) { | |||||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the sub0 var after matched."; | |||||
| } | |||||
| sub0 = utils::cast<AnfNodePtr>(iter_sub0->second); | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(sub0); | |||||
| if (!CheckSupportDataType(sub0, kFloatDataTypeSet)) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| std::vector<AnfNodePtr> inputs = GetFusionNodeInputs(equiv); | |||||
| std::vector<AnfNodePtr> inputs = GetFusionNodeInputs(equiv, node); | |||||
| auto fusion_node = graph->NewCNode(inputs); | auto fusion_node = graph->NewCNode(inputs); | ||||
| MS_EXCEPTION_IF_NULL(fusion_node); | MS_EXCEPTION_IF_NULL(fusion_node); | ||||
| fusion_node->set_scope(node->scope()); | |||||
| fusion_node->set_scope(sub0->scope()); | |||||
| auto iter_add0 = (*equiv).find(add0_var_); | auto iter_add0 = (*equiv).find(add0_var_); | ||||
| if (iter_add0 == (*equiv).end()) { | if (iter_add0 == (*equiv).end()) { | ||||
| @@ -167,9 +307,9 @@ const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, c | |||||
| auto add1 = utils::cast<AnfNodePtr>(iter_add1->second); | auto add1 = utils::cast<AnfNodePtr>(iter_add1->second); | ||||
| MS_EXCEPTION_IF_NULL(add1); | MS_EXCEPTION_IF_NULL(add1); | ||||
| auto types = {AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add0, 0), | auto types = {AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add0, 0), | ||||
| AnfAlgo::GetOutputInferDataType(node, 0)}; | |||||
| AnfAlgo::GetOutputInferDataType(sub0, 0)}; | |||||
| auto shapes = {AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add0, 0), | auto shapes = {AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add0, 0), | ||||
| AnfAlgo::GetOutputInferShape(node, 0)}; | |||||
| AnfAlgo::GetOutputInferShape(sub0, 0)}; | |||||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); | AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); | ||||
| std::vector<AnfNodePtr> fusion_node_outputs; | std::vector<AnfNodePtr> fusion_node_outputs; | ||||
| @@ -40,13 +40,14 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass { | |||||
| add2_y_ = std::make_shared<Var>(); | add2_y_ = std::make_shared<Var>(); | ||||
| add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | ||||
| add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | ||||
| sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimSub->name())); | |||||
| } | } | ||||
| ~AdamApplyOneWithDecayRule() override = default; | ~AdamApplyOneWithDecayRule() override = default; | ||||
| const BaseRef DefinePattern() const override = 0; | const BaseRef DefinePattern() const override = 0; | ||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | ||||
| protected: | protected: | ||||
| std::vector<AnfNodePtr> GetFusionNodeInputs(const EquivPtr &equiv) const; | |||||
| std::vector<AnfNodePtr> GetFusionNodeInputs(const EquivPtr &equiv, const AnfNodePtr &final_node) const; | |||||
| VarPtr input0_; | VarPtr input0_; | ||||
| VarPtr input1_; | VarPtr input1_; | ||||
| VarPtr input2_; | VarPtr input2_; | ||||
| @@ -60,6 +61,7 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass { | |||||
| VarPtr add2_y_; | VarPtr add2_y_; | ||||
| VarPtr add0_var_; | VarPtr add0_var_; | ||||
| VarPtr add1_var_; | VarPtr add1_var_; | ||||
| VarPtr sub0_var_; | |||||
| }; | }; | ||||
| class AdamApplyOneWithDecayRuleCond1 : public AdamApplyOneWithDecayRule { | class AdamApplyOneWithDecayRuleCond1 : public AdamApplyOneWithDecayRule { | ||||
| @@ -106,6 +108,51 @@ class AdamApplyOneWithDecayRuleCond5 : public AdamApplyOneWithDecayRule { | |||||
| ~AdamApplyOneWithDecayRuleCond5() override = default; | ~AdamApplyOneWithDecayRuleCond5() override = default; | ||||
| const BaseRef DefinePattern() const override; | const BaseRef DefinePattern() const override; | ||||
| }; | }; | ||||
| class AdamApplyOneWithDecayAssignRuleCond1 : public AdamApplyOneWithDecayRule { | |||||
| public: | |||||
| explicit AdamApplyOneWithDecayAssignRuleCond1(bool multigraph = true) | |||||
| : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond1", multigraph) {} | |||||
| ~AdamApplyOneWithDecayAssignRuleCond1() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| }; | |||||
| class AdamApplyOneWithDecayAssignRuleCond2 : public AdamApplyOneWithDecayRule { | |||||
| public: | |||||
| explicit AdamApplyOneWithDecayAssignRuleCond2(bool multigraph = true) | |||||
| : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond2", multigraph) {} | |||||
| ~AdamApplyOneWithDecayAssignRuleCond2() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| }; | |||||
| class AdamApplyOneWithDecayAssignRuleCond3 : public AdamApplyOneWithDecayRule { | |||||
| public: | |||||
| explicit AdamApplyOneWithDecayAssignRuleCond3(bool multigraph = true) | |||||
| : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond3", multigraph) {} | |||||
| ~AdamApplyOneWithDecayAssignRuleCond3() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| }; | |||||
| class AdamApplyOneWithDecayAssignRuleCond4 : public AdamApplyOneWithDecayRule { | |||||
| public: | |||||
| explicit AdamApplyOneWithDecayAssignRuleCond4(bool multigraph = true) | |||||
| : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond4", multigraph) {} | |||||
| ~AdamApplyOneWithDecayAssignRuleCond4() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| }; | |||||
| class AdamApplyOneWithDecayAssignRuleCond5 : public AdamApplyOneWithDecayRule { | |||||
| public: | |||||
| explicit AdamApplyOneWithDecayAssignRuleCond5(bool multigraph = true) | |||||
| : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond5", multigraph) {} | |||||
| ~AdamApplyOneWithDecayAssignRuleCond5() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| }; | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ | #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ | ||||
| @@ -122,6 +122,7 @@ constexpr auto kLayerNormBetaGammaBackpropOpName = "LayerNormBetaGammaBackprop"; | |||||
| constexpr auto kLambNextMVOpName = "LambNextMV"; | constexpr auto kLambNextMVOpName = "LambNextMV"; | ||||
| constexpr auto kConfusionTransposeDOpName = "ConfusionTransposeD"; | constexpr auto kConfusionTransposeDOpName = "ConfusionTransposeD"; | ||||
| constexpr auto kAdamApplyOneWithDecayOpName = "AdamApplyOneWithDecay"; | constexpr auto kAdamApplyOneWithDecayOpName = "AdamApplyOneWithDecay"; | ||||
| constexpr auto kAdamApplyOneWithDecayAssignOpName = "AdamApplyOneWithDecayAssign"; | |||||
| constexpr auto kBatchNormGradOpName = "BatchNormGrad"; | constexpr auto kBatchNormGradOpName = "BatchNormGrad"; | ||||
| constexpr auto kBNInferOpName = "BNInfer"; | constexpr auto kBNInferOpName = "BNInfer"; | ||||
| constexpr auto kAdamApplyOneOpName = "AdamApplyOne"; | constexpr auto kAdamApplyOneOpName = "AdamApplyOne"; | ||||
| @@ -31,7 +31,7 @@ class TestHWOptimizeAdamApplyOneWithDecayRule : public BackendCommon { | |||||
| }; | }; | ||||
| TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond1) { | TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond1) { | ||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond1", "before"); | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "before_cond1"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | std::vector<int> shp{2, 32, 224, 224}; | ||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | ||||
| @@ -47,12 +47,12 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_r | |||||
| optimizer->AddPassManager(pm); | optimizer->AddPassManager(pm); | ||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | FuncGraphPtr new_graph = optimizer->Optimize(fg); | ||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond1", "after"); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | ||||
| } | } | ||||
| TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond2) { | TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond2) { | ||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond2", "before"); | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "before_cond2"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | std::vector<int> shp{2, 32, 224, 224}; | ||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | ||||
| @@ -68,12 +68,12 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_r | |||||
| optimizer->AddPassManager(pm); | optimizer->AddPassManager(pm); | ||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | FuncGraphPtr new_graph = optimizer->Optimize(fg); | ||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond2", "after"); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | ||||
| } | } | ||||
| TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond3) { | TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond3) { | ||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond3", "before"); | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "before_cond3"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | std::vector<int> shp{2, 32, 224, 224}; | ||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | ||||
| @@ -89,12 +89,12 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_r | |||||
| optimizer->AddPassManager(pm); | optimizer->AddPassManager(pm); | ||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | FuncGraphPtr new_graph = optimizer->Optimize(fg); | ||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond3", "after"); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | ||||
| } | } | ||||
| TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond4) { | TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond4) { | ||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond4", "before"); | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "before_cond4"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | std::vector<int> shp{2, 32, 224, 224}; | ||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | ||||
| @@ -110,12 +110,12 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_r | |||||
| optimizer->AddPassManager(pm); | optimizer->AddPassManager(pm); | ||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | FuncGraphPtr new_graph = optimizer->Optimize(fg); | ||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond4", "after"); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | ||||
| } | } | ||||
| TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond5) { | TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond5) { | ||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond5", "before"); | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "before_cond5"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | std::vector<int> shp{2, 32, 224, 224}; | ||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | ||||
| @@ -131,7 +131,112 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_r | |||||
| optimizer->AddPassManager(pm); | optimizer->AddPassManager(pm); | ||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | FuncGraphPtr new_graph = optimizer->Optimize(fg); | ||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond5", "after"); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_assign_rule_cond1) { | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "before_cond1"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||||
| AbstractBasePtrList args_spec_list; | |||||
| for (size_t i = 0; i < 11; ++i) { | |||||
| args_spec_list.push_back(x_abstract); | |||||
| } | |||||
| auto fg = GetKernelGraph(g, args_spec_list); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayAssignRuleCond1>()); | |||||
| optimizer->AddPassManager(pm); | |||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_assign_rule_cond2) { | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "before_cond2"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||||
| AbstractBasePtrList args_spec_list; | |||||
| for (size_t i = 0; i < 11; ++i) { | |||||
| args_spec_list.push_back(x_abstract); | |||||
| } | |||||
| auto fg = GetKernelGraph(g, args_spec_list); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayAssignRuleCond2>()); | |||||
| optimizer->AddPassManager(pm); | |||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_assign_rule_cond3) { | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "before_cond3"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||||
| AbstractBasePtrList args_spec_list; | |||||
| for (size_t i = 0; i < 11; ++i) { | |||||
| args_spec_list.push_back(x_abstract); | |||||
| } | |||||
| auto fg = GetKernelGraph(g, args_spec_list); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayAssignRuleCond3>()); | |||||
| optimizer->AddPassManager(pm); | |||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_assign_rule_cond4) { | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "before_cond4"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||||
| AbstractBasePtrList args_spec_list; | |||||
| for (size_t i = 0; i < 11; ++i) { | |||||
| args_spec_list.push_back(x_abstract); | |||||
| } | |||||
| auto fg = GetKernelGraph(g, args_spec_list); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayAssignRuleCond4>()); | |||||
| optimizer->AddPassManager(pm); | |||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_assign_rule_cond5) { | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "before_cond5"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||||
| AbstractBasePtrList args_spec_list; | |||||
| for (size_t i = 0; i < 11; ++i) { | |||||
| args_spec_list.push_back(x_abstract); | |||||
| } | |||||
| auto fg = GetKernelGraph(g, args_spec_list); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayAssignRuleCond5>()); | |||||
| optimizer->AddPassManager(pm); | |||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | ||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||
| @@ -15,6 +15,7 @@ | |||||
| from mindspore.ops import Primitive | from mindspore.ops import Primitive | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import functional as F | |||||
| mul = P.Mul() | mul = P.Mul() | ||||
| add = P.TensorAdd() | add = P.TensorAdd() | ||||
| @@ -22,9 +23,11 @@ square = P.Square() | |||||
| sqrt = P.Sqrt() | sqrt = P.Sqrt() | ||||
| real_div = P.RealDiv() | real_div = P.RealDiv() | ||||
| sub = P.Sub() | sub = P.Sub() | ||||
| Assign = P.Assign() | |||||
| make_tuple = Primitive('make_tuple') | make_tuple = Primitive('make_tuple') | ||||
| tuple_getitem = Primitive('tuple_getitem') | tuple_getitem = Primitive('tuple_getitem') | ||||
| adam_apply_one_with_decay = Primitive('AdamApplyOneWithDecay') | adam_apply_one_with_decay = Primitive('AdamApplyOneWithDecay') | ||||
| adam_apply_one_with_decay_assign = Primitive('AdamApplyOneWithDecayAssign') | |||||
| class FnDict: | class FnDict: | ||||
| @@ -39,11 +42,10 @@ class FnDict: | |||||
| def test_adam_apply_one_with_decay_rule(tag): | def test_adam_apply_one_with_decay_rule(tag): | ||||
| """ test_adam_apply_one_with_decay_rule """ | |||||
| fns = FnDict() | fns = FnDict() | ||||
| @fns | @fns | ||||
| def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| mul0 = mul(mul0_x, input2) | mul0 = mul(mul0_x, input2) | ||||
| mul1 = mul(mul1_x, input0) | mul1 = mul(mul1_x, input0) | ||||
| square0 = square(input0) | square0 = square(input0) | ||||
| @@ -52,50 +54,70 @@ def test_adam_apply_one_with_decay_rule(tag): | |||||
| mul3 = mul(mul3_x, square0) | mul3 = mul(mul3_x, square0) | ||||
| add1 = add(mul2, mul3) | add1 = add(mul2, mul3) | ||||
| sqrt0 = sqrt(add1) | sqrt0 = sqrt(add1) | ||||
| add2 = add(sqrt0, add2_y) | |||||
| add2 = add(add2_y, sqrt0) | |||||
| mul4 = mul(mul4_x, input3) | mul4 = mul(mul4_x, input3) | ||||
| real_div0 = real_div(add0, add2) | real_div0 = real_div(add0, add2) | ||||
| add3 = add(real_div0, mul4) | |||||
| add3 = add(mul4, real_div0) | |||||
| mul5 = mul(input4, add3) | mul5 = mul(input4, add3) | ||||
| sub0 = sub(input3, mul5) | sub0 = sub(input3, mul5) | ||||
| return make_tuple(add1, add0, sub0) | return make_tuple(add1, add0, sub0) | ||||
| @fns | @fns | ||||
| def no_match(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| mul0 = mul(input2, mul0_x) | |||||
| mul1 = mul(input0, mul1_x) | |||||
| square0 = square(input0) | |||||
| add0 = add(mul0, mul1) | |||||
| mul2 = mul(input1, mul2_x) | |||||
| mul3 = mul(mul3_x, square0) | |||||
| add1 = add(mul2, mul3) | |||||
| sqrt0 = sqrt(add1) | |||||
| add2 = add(sqrt0, add2_y) | |||||
| mul4 = mul(input3, mul4_x) | |||||
| real_div0 = real_div(add0, add2) | |||||
| add3 = add(mul4, real_div0) | |||||
| mul5 = mul(add3, input4) | |||||
| sub0 = sub(input3, mul5) | |||||
| return make_tuple(add1, add0, sub0) | |||||
| @fns | |||||
| def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| mul0 = mul(mul0_x, input2) | mul0 = mul(mul0_x, input2) | ||||
| mul1 = mul(mul1_x, input0) | mul1 = mul(mul1_x, input0) | ||||
| square0 = square(input0) | square0 = square(input0) | ||||
| # diff mul from original add | |||||
| add0 = mul(mul0, mul1) | |||||
| add0 = add(mul0, mul1) | |||||
| mul2 = mul(mul2_x, input1) | mul2 = mul(mul2_x, input1) | ||||
| mul3 = mul(mul3_x, square0) | |||||
| mul3 = mul(square0, mul3_x) | |||||
| add1 = add(mul2, mul3) | add1 = add(mul2, mul3) | ||||
| sqrt0 = sqrt(add1) | sqrt0 = sqrt(add1) | ||||
| add2 = add(sqrt0, add2_y) | add2 = add(sqrt0, add2_y) | ||||
| mul4 = mul(mul4_x, input3) | mul4 = mul(mul4_x, input3) | ||||
| real_div0 = real_div(add0, add2) | real_div0 = real_div(add0, add2) | ||||
| add3 = add(real_div0, mul4) | |||||
| mul5 = mul(input4, add3) | |||||
| add3 = add(mul4, real_div0) | |||||
| mul5 = mul(add3, input4) | |||||
| sub0 = sub(input3, mul5) | sub0 = sub(input3, mul5) | ||||
| return make_tuple(add1, add0, sub0) | return make_tuple(add1, add0, sub0) | ||||
| @fns | @fns | ||||
| def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, | |||||
| add2_y) | |||||
| item0 = tuple_getitem(res, 0) | |||||
| item1 = tuple_getitem(res, 1) | |||||
| item2 = tuple_getitem(res, 2) | |||||
| return make_tuple(make_tuple(item0, item1, item2)) | |||||
| return fns[tag] | |||||
| def test_adam_apply_one_with_decay_rule_cond1(tag): | |||||
| fns = FnDict() | |||||
| def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| mul0 = mul(mul0_x, input2) | |||||
| mul1 = mul(mul1_x, input0) | |||||
| square0 = square(input0) | |||||
| add0 = add(mul0, mul1) | |||||
| mul2 = mul(mul2_x, input1) | |||||
| mul3 = mul(mul3_x, square0) | |||||
| add1 = add(mul2, mul3) | |||||
| sqrt0 = sqrt(add1) | |||||
| add2 = add(add2_y, sqrt0) | |||||
| mul4 = mul(mul4_x, input3) | |||||
| real_div0 = real_div(add0, add2) | |||||
| add3 = add(mul4, real_div0) | |||||
| mul5 = mul(add3, input4) | |||||
| sub0 = sub(input3, mul5) | |||||
| return make_tuple(add1, add0, sub0) | |||||
| @fns | @fns | ||||
| def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| def before_cond5(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| mul0 = mul(mul0_x, input2) | mul0 = mul(mul0_x, input2) | ||||
| mul1 = mul(mul1_x, input0) | mul1 = mul(mul1_x, input0) | ||||
| square0 = square(input0) | square0 = square(input0) | ||||
| @@ -104,11 +126,11 @@ def test_adam_apply_one_with_decay_rule_cond1(tag): | |||||
| mul3 = mul(mul3_x, square0) | mul3 = mul(mul3_x, square0) | ||||
| add1 = add(mul2, mul3) | add1 = add(mul2, mul3) | ||||
| sqrt0 = sqrt(add1) | sqrt0 = sqrt(add1) | ||||
| add2 = add(add2_y, sqrt0) | |||||
| add2 = add(sqrt0, add2_y) | |||||
| mul4 = mul(mul4_x, input3) | mul4 = mul(mul4_x, input3) | ||||
| real_div0 = real_div(add0, add2) | real_div0 = real_div(add0, add2) | ||||
| add3 = add(mul4, real_div0) | add3 = add(mul4, real_div0) | ||||
| mul5 = mul(input4, add3) | |||||
| mul5 = mul(add3, input4) | |||||
| sub0 = sub(input3, mul5) | sub0 = sub(input3, mul5) | ||||
| return make_tuple(add1, add0, sub0) | return make_tuple(add1, add0, sub0) | ||||
| @@ -124,11 +146,35 @@ def test_adam_apply_one_with_decay_rule_cond1(tag): | |||||
| return fns[tag] | return fns[tag] | ||||
| def test_adam_apply_one_with_decay_rule_cond2(tag): | |||||
| def test_adam_apply_one_with_decay_assign_rule(tag): | |||||
| fns = FnDict() | fns = FnDict() | ||||
| @fns | @fns | ||||
| def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| mul0 = mul(mul0_x, input2) | |||||
| mul1 = mul(mul1_x, input0) | |||||
| square0 = square(input0) | |||||
| add0 = add(mul0, mul1) | |||||
| mul2 = mul(mul2_x, input1) | |||||
| mul3 = mul(mul3_x, square0) | |||||
| add1 = add(mul2, mul3) | |||||
| sqrt0 = sqrt(add1) | |||||
| add2 = add(add2_y, sqrt0) | |||||
| mul4 = mul(mul4_x, input3) | |||||
| real_div0 = real_div(add0, add2) | |||||
| add3 = add(mul4, real_div0) | |||||
| mul5 = mul(input4, add3) | |||||
| sub0 = sub(input3, mul5) | |||||
| assign0 = Assign(input3, sub0) | |||||
| depend0 = F.depend(sub0, assign0) | |||||
| assign1 = Assign(input2, add0) | |||||
| depend1 = F.depend(depend0, assign1) | |||||
| assign2 = Assign(input1, add1) | |||||
| depend2 = F.depend(depend1, assign2) | |||||
| return make_tuple(add1, add0, depend2) | |||||
| @fns | |||||
| def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| mul0 = mul(input2, mul0_x) | mul0 = mul(input2, mul0_x) | ||||
| mul1 = mul(input0, mul1_x) | mul1 = mul(input0, mul1_x) | ||||
| square0 = square(input0) | square0 = square(input0) | ||||
| @@ -143,25 +189,16 @@ def test_adam_apply_one_with_decay_rule_cond2(tag): | |||||
| add3 = add(mul4, real_div0) | add3 = add(mul4, real_div0) | ||||
| mul5 = mul(add3, input4) | mul5 = mul(add3, input4) | ||||
| sub0 = sub(input3, mul5) | sub0 = sub(input3, mul5) | ||||
| return make_tuple(add1, add0, sub0) | |||||
| @fns | |||||
| def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, | |||||
| add2_y) | |||||
| item0 = tuple_getitem(res, 0) | |||||
| item1 = tuple_getitem(res, 1) | |||||
| item2 = tuple_getitem(res, 2) | |||||
| return make_tuple(make_tuple(item0, item1, item2)) | |||||
| return fns[tag] | |||||
| def test_adam_apply_one_with_decay_rule_cond3(tag): | |||||
| fns = FnDict() | |||||
| assign0 = Assign(input3, sub0) | |||||
| depend0 = F.depend(sub0, assign0) | |||||
| assign1 = Assign(input2, add0) | |||||
| depend1 = F.depend(depend0, assign1) | |||||
| assign2 = Assign(input1, add1) | |||||
| depend2 = F.depend(depend1, assign2) | |||||
| return make_tuple(add1, add0, depend2) | |||||
| @fns | @fns | ||||
| def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| mul0 = mul(mul0_x, input2) | mul0 = mul(mul0_x, input2) | ||||
| mul1 = mul(mul1_x, input0) | mul1 = mul(mul1_x, input0) | ||||
| square0 = square(input0) | square0 = square(input0) | ||||
| @@ -176,25 +213,16 @@ def test_adam_apply_one_with_decay_rule_cond3(tag): | |||||
| add3 = add(mul4, real_div0) | add3 = add(mul4, real_div0) | ||||
| mul5 = mul(add3, input4) | mul5 = mul(add3, input4) | ||||
| sub0 = sub(input3, mul5) | sub0 = sub(input3, mul5) | ||||
| return make_tuple(add1, add0, sub0) | |||||
| @fns | |||||
| def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, | |||||
| add2_y) | |||||
| item0 = tuple_getitem(res, 0) | |||||
| item1 = tuple_getitem(res, 1) | |||||
| item2 = tuple_getitem(res, 2) | |||||
| return make_tuple(make_tuple(item0, item1, item2)) | |||||
| return fns[tag] | |||||
| def test_adam_apply_one_with_decay_rule_cond4(tag): | |||||
| fns = FnDict() | |||||
| assign0 = Assign(input3, sub0) | |||||
| depend0 = F.depend(sub0, assign0) | |||||
| assign1 = Assign(input2, add0) | |||||
| depend1 = F.depend(depend0, assign1) | |||||
| assign2 = Assign(input1, add1) | |||||
| depend2 = F.depend(depend1, assign2) | |||||
| return make_tuple(add1, add0, depend2) | |||||
| @fns | @fns | ||||
| def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| mul0 = mul(mul0_x, input2) | mul0 = mul(mul0_x, input2) | ||||
| mul1 = mul(mul1_x, input0) | mul1 = mul(mul1_x, input0) | ||||
| square0 = square(input0) | square0 = square(input0) | ||||
| @@ -209,25 +237,16 @@ def test_adam_apply_one_with_decay_rule_cond4(tag): | |||||
| add3 = add(mul4, real_div0) | add3 = add(mul4, real_div0) | ||||
| mul5 = mul(add3, input4) | mul5 = mul(add3, input4) | ||||
| sub0 = sub(input3, mul5) | sub0 = sub(input3, mul5) | ||||
| return make_tuple(add1, add0, sub0) | |||||
| @fns | |||||
| def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, | |||||
| add2_y) | |||||
| item0 = tuple_getitem(res, 0) | |||||
| item1 = tuple_getitem(res, 1) | |||||
| item2 = tuple_getitem(res, 2) | |||||
| return make_tuple(make_tuple(item0, item1, item2)) | |||||
| return fns[tag] | |||||
| def test_adam_apply_one_with_decay_rule_cond5(tag): | |||||
| fns = FnDict() | |||||
| assign0 = Assign(input3, sub0) | |||||
| depend0 = F.depend(sub0, assign0) | |||||
| assign1 = Assign(input2, add0) | |||||
| depend1 = F.depend(depend0, assign1) | |||||
| assign2 = Assign(input1, add1) | |||||
| depend2 = F.depend(depend1, assign2) | |||||
| return make_tuple(add1, add0, depend2) | |||||
| @fns | @fns | ||||
| def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| def before_cond5(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | |||||
| mul0 = mul(mul0_x, input2) | mul0 = mul(mul0_x, input2) | ||||
| mul1 = mul(mul1_x, input0) | mul1 = mul(mul1_x, input0) | ||||
| square0 = square(input0) | square0 = square(input0) | ||||
| @@ -242,12 +261,18 @@ def test_adam_apply_one_with_decay_rule_cond5(tag): | |||||
| add3 = add(mul4, real_div0) | add3 = add(mul4, real_div0) | ||||
| mul5 = mul(add3, input4) | mul5 = mul(add3, input4) | ||||
| sub0 = sub(input3, mul5) | sub0 = sub(input3, mul5) | ||||
| return make_tuple(add1, add0, sub0) | |||||
| assign0 = Assign(input3, sub0) | |||||
| depend0 = F.depend(sub0, assign0) | |||||
| assign1 = Assign(input2, add0) | |||||
| depend1 = F.depend(depend0, assign1) | |||||
| assign2 = Assign(input1, add1) | |||||
| depend2 = F.depend(depend1, assign2) | |||||
| return make_tuple(add1, add0, depend2) | |||||
| @fns | @fns | ||||
| def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): | ||||
| res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, | |||||
| add2_y) | |||||
| res = adam_apply_one_with_decay_assign(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, | |||||
| mul4_x, add2_y) | |||||
| item0 = tuple_getitem(res, 0) | item0 = tuple_getitem(res, 0) | ||||
| item1 = tuple_getitem(res, 1) | item1 = tuple_getitem(res, 1) | ||||
| item2 = tuple_getitem(res, 2) | item2 = tuple_getitem(res, 2) | ||||