Merge pull request !4197 from YuJianfeng/adam_assigntags/v0.7.0-beta
| @@ -125,6 +125,10 @@ void AddAscendIRFusionRulesPass(PassManager *ir_fusion_pm) { | |||||
| ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond4>()); | ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond4>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>()); | ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>()); | ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneAssignCond1Fusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneAssignCond2Fusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneAssignCond3Fusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneAssignCond4Fusion>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond1Fusion>()); | ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond1Fusion>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond2Fusion>()); | ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond2Fusion>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond3Fusion>()); | ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond3Fusion>()); | ||||
| @@ -15,30 +15,9 @@ | |||||
| */ | */ | ||||
| #include "backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h" | #include "backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h" | ||||
| #include "backend/optimizer/common/helper.h" | #include "backend/optimizer/common/helper.h" | ||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| MS_EXCEPTION_IF_NULL(equiv); | |||||
| auto prim = std::make_shared<Primitive>(kAdamApplyOneOpName); | |||||
| std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(prim)}; | |||||
| for (const auto &input_var : input_vars_) { | |||||
| auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_var]); | |||||
| MS_EXCEPTION_IF_NULL(input_node); | |||||
| new_node_inputs.push_back(input_node); | |||||
| } | |||||
| for (const auto &mul_x_input_var : mul_x_input_vars_) { | |||||
| auto mul_x_input_node = utils::cast<AnfNodePtr>((*equiv)[mul_x_input_var]); | |||||
| MS_EXCEPTION_IF_NULL(mul_x_input_node); | |||||
| new_node_inputs.push_back(mul_x_input_node); | |||||
| } | |||||
| auto add2_y_node = utils::cast<AnfNodePtr>((*equiv)[add2_y_]); | |||||
| MS_EXCEPTION_IF_NULL(add2_y_node); | |||||
| new_node_inputs.push_back(add2_y_node); | |||||
| auto new_node = func_graph->NewCNode(new_node_inputs); | |||||
| return new_node; | |||||
| } | |||||
| const BaseRef AdamApplyOneFusion::DefinePattern() const { | const BaseRef AdamApplyOneFusion::DefinePattern() const { | ||||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | ||||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | ||||
| @@ -104,16 +83,152 @@ const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const { | |||||
| return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | ||||
| } | } | ||||
| const BaseRef AdamApplyOneAssignFusion::DefinePattern() const { | |||||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||||
| VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); | |||||
| VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); | |||||
| VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); | |||||
| VectorRef sqrt0 = VectorRef({prim_sqrt, add1}); | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | |||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | |||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | |||||
| VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); | |||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); | |||||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | |||||
| VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0}); | |||||
| VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); | |||||
| VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1}); | |||||
| return VectorRef({prim::kPrimDepend, depend1, assign2}); | |||||
| } | |||||
| const BaseRef AdamApplyOneAssignCond1Fusion::DefinePattern() const { | |||||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||||
| VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); | |||||
| VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); | |||||
| VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); | |||||
| VectorRef sqrt0 = VectorRef({prim_sqrt, add1}); | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | |||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | |||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); | |||||
| VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); | |||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); | |||||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | |||||
| VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0}); | |||||
| VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); | |||||
| VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1}); | |||||
| return VectorRef({prim::kPrimDepend, depend1, assign2}); | |||||
| } | |||||
| const BaseRef AdamApplyOneAssignCond2Fusion::DefinePattern() const { | |||||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||||
| VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); | |||||
| VectorRef mul3 = VectorRef({prim::kPrimMul, VectorRef({prim::kPrimSquare, input_vars_[0]}), mul_x_input_vars_[3]}); | |||||
| VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); | |||||
| VectorRef sqrt0 = VectorRef({prim_sqrt, add1}); | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | |||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | |||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | |||||
| VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | |||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); | |||||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | |||||
| VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0}); | |||||
| VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); | |||||
| VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1}); | |||||
| return VectorRef({prim::kPrimDepend, depend1, assign2}); | |||||
| } | |||||
| const BaseRef AdamApplyOneAssignCond3Fusion::DefinePattern() const { | |||||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||||
| VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); | |||||
| VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); | |||||
| VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); | |||||
| VectorRef sqrt0 = VectorRef({prim_sqrt, add1}); | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | |||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | |||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); | |||||
| VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | |||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); | |||||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | |||||
| VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0}); | |||||
| VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); | |||||
| VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1}); | |||||
| return VectorRef({prim::kPrimDepend, depend1, assign2}); | |||||
| } | |||||
| const BaseRef AdamApplyOneAssignCond4Fusion::DefinePattern() const { | |||||
| const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); | |||||
| const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); | |||||
| VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); | |||||
| VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); | |||||
| VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); | |||||
| VectorRef sqrt0 = VectorRef({prim_sqrt, add1}); | |||||
| VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); | |||||
| VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); | |||||
| VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); | |||||
| VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); | |||||
| VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); | |||||
| VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0}); | |||||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0}); | |||||
| VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0}); | |||||
| VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1}); | |||||
| VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1}); | |||||
| return VectorRef({prim::kPrimDepend, depend1, assign2}); | |||||
| } | |||||
| AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, | |||||
| const AnfNodePtr &final_node) const { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| MS_EXCEPTION_IF_NULL(equiv); | |||||
| PrimitivePtr prim = nullptr; | |||||
| if (AnfAlgo::CheckPrimitiveType(final_node, prim::kPrimDepend)) { | |||||
| prim = std::make_shared<Primitive>(kAdamApplyOneAssignOpName); | |||||
| } else { | |||||
| prim = std::make_shared<Primitive>(kAdamApplyOneOpName); | |||||
| } | |||||
| std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(prim)}; | |||||
| for (const auto &input_var : input_vars_) { | |||||
| auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_var]); | |||||
| MS_EXCEPTION_IF_NULL(input_node); | |||||
| new_node_inputs.push_back(input_node); | |||||
| } | |||||
| for (const auto &mul_x_input_var : mul_x_input_vars_) { | |||||
| auto mul_x_input_node = utils::cast<AnfNodePtr>((*equiv)[mul_x_input_var]); | |||||
| MS_EXCEPTION_IF_NULL(mul_x_input_node); | |||||
| new_node_inputs.push_back(mul_x_input_node); | |||||
| } | |||||
| auto add2_y_node = utils::cast<AnfNodePtr>((*equiv)[add2_y_]); | |||||
| MS_EXCEPTION_IF_NULL(add2_y_node); | |||||
| new_node_inputs.push_back(add2_y_node); | |||||
| auto new_node = func_graph->NewCNode(new_node_inputs); | |||||
| return new_node; | |||||
| } | |||||
| const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | ||||
| const EquivPtr &equiv) const { | const EquivPtr &equiv) const { | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!CheckSupportDataType(node, kFloatDataTypeSet)) { | |||||
| auto sub0 = node; | |||||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) { | |||||
| auto iter_sub0 = (*equiv).find(sub0_var_); | |||||
| if (iter_sub0 == (*equiv).end()) { | |||||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the sub0 var after matched."; | |||||
| } | |||||
| sub0 = utils::cast<AnfNodePtr>(iter_sub0->second); | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(sub0); | |||||
| if (!CheckSupportDataType(sub0, kFloatDataTypeSet)) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto new_node = CreateAdamApplyOneNode(func_graph, equiv); | |||||
| auto new_node = CreateAdamApplyOneNode(func_graph, equiv, node); | |||||
| MS_EXCEPTION_IF_NULL(new_node); | MS_EXCEPTION_IF_NULL(new_node); | ||||
| new_node->set_scope(node->scope()); | |||||
| new_node->set_scope(sub0->scope()); | |||||
| // Set abstract of new node | // Set abstract of new node | ||||
| AbstractBasePtrList new_node_abstract_list; | AbstractBasePtrList new_node_abstract_list; | ||||
| auto iter_add0 = (*equiv).find(add0_var_); | auto iter_add0 = (*equiv).find(add0_var_); | ||||
| @@ -130,7 +245,7 @@ const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, con | |||||
| MS_EXCEPTION_IF_NULL(add1); | MS_EXCEPTION_IF_NULL(add1); | ||||
| new_node_abstract_list.push_back(add1->abstract()); | new_node_abstract_list.push_back(add1->abstract()); | ||||
| new_node_abstract_list.push_back(add0->abstract()); | new_node_abstract_list.push_back(add0->abstract()); | ||||
| new_node_abstract_list.push_back(node->abstract()); | |||||
| new_node_abstract_list.push_back(sub0->abstract()); | |||||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(new_node_abstract_list); | auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(new_node_abstract_list); | ||||
| new_node->set_abstract(abstract_tuple); | new_node->set_abstract(abstract_tuple); | ||||
| // Create tuple_getitem node for outputs | // Create tuple_getitem node for outputs | ||||
| @@ -40,6 +40,7 @@ class AdamApplyOneFusion : public PatternProcessPass { | |||||
| add2_y_ = std::make_shared<Var>(); | add2_y_ = std::make_shared<Var>(); | ||||
| add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | ||||
| add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); | ||||
| sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimSub->name())); | |||||
| } | } | ||||
| ~AdamApplyOneFusion() override = default; | ~AdamApplyOneFusion() override = default; | ||||
| @@ -47,12 +48,14 @@ class AdamApplyOneFusion : public PatternProcessPass { | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | ||||
| protected: | protected: | ||||
| AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const; | |||||
| AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, | |||||
| const AnfNodePtr &final_node) const; | |||||
| std::vector<VarPtr> input_vars_; | std::vector<VarPtr> input_vars_; | ||||
| std::vector<VarPtr> mul_x_input_vars_; | std::vector<VarPtr> mul_x_input_vars_; | ||||
| VarPtr add2_y_; | VarPtr add2_y_; | ||||
| VarPtr add0_var_; | VarPtr add0_var_; | ||||
| VarPtr add1_var_; | VarPtr add1_var_; | ||||
| VarPtr sub0_var_; | |||||
| }; | }; | ||||
| class AdamApplyOneCond1Fusion : public AdamApplyOneFusion { | class AdamApplyOneCond1Fusion : public AdamApplyOneFusion { | ||||
| @@ -90,6 +93,51 @@ class AdamApplyOneCond4Fusion : public AdamApplyOneFusion { | |||||
| ~AdamApplyOneCond4Fusion() override = default; | ~AdamApplyOneCond4Fusion() override = default; | ||||
| const BaseRef DefinePattern() const override; | const BaseRef DefinePattern() const override; | ||||
| }; | }; | ||||
| class AdamApplyOneAssignFusion : public AdamApplyOneFusion { | |||||
| public: | |||||
| explicit AdamApplyOneAssignFusion(bool multigraph = true) | |||||
| : AdamApplyOneFusion("adam_apply_one_assign_fusion", multigraph) {} | |||||
| ~AdamApplyOneAssignFusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| }; | |||||
| class AdamApplyOneAssignCond1Fusion : public AdamApplyOneFusion { | |||||
| public: | |||||
| explicit AdamApplyOneAssignCond1Fusion(bool multigraph = true) | |||||
| : AdamApplyOneFusion("adam_apply_one_assign_cond1_fusion", multigraph) {} | |||||
| ~AdamApplyOneAssignCond1Fusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| }; | |||||
| class AdamApplyOneAssignCond2Fusion : public AdamApplyOneFusion { | |||||
| public: | |||||
| explicit AdamApplyOneAssignCond2Fusion(bool multigraph = true) | |||||
| : AdamApplyOneFusion("adam_apply_one_assign_cond2_fusion", multigraph) {} | |||||
| ~AdamApplyOneAssignCond2Fusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| }; | |||||
| class AdamApplyOneAssignCond3Fusion : public AdamApplyOneFusion { | |||||
| public: | |||||
| explicit AdamApplyOneAssignCond3Fusion(bool multigraph = true) | |||||
| : AdamApplyOneFusion("adam_apply_one_assign_cond3_fusion", multigraph) {} | |||||
| ~AdamApplyOneAssignCond3Fusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| }; | |||||
| class AdamApplyOneAssignCond4Fusion : public AdamApplyOneFusion { | |||||
| public: | |||||
| explicit AdamApplyOneAssignCond4Fusion(bool multigraph = true) | |||||
| : AdamApplyOneFusion("adam_apply_one_assign_cond4_fusion", multigraph) {} | |||||
| ~AdamApplyOneAssignCond4Fusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| }; | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ | #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ | ||||
| @@ -119,6 +119,7 @@ constexpr auto kAdamApplyOneWithDecayOpName = "AdamApplyOneWithDecay"; | |||||
| constexpr auto kBatchNormGradOpName = "BatchNormGrad"; | constexpr auto kBatchNormGradOpName = "BatchNormGrad"; | ||||
| constexpr auto kBNInferOpName = "BNInfer"; | constexpr auto kBNInferOpName = "BNInfer"; | ||||
| constexpr auto kAdamApplyOneOpName = "AdamApplyOne"; | constexpr auto kAdamApplyOneOpName = "AdamApplyOne"; | ||||
| constexpr auto kAdamApplyOneAssignOpName = "AdamApplyOneAssign"; | |||||
| constexpr auto kResizeNearestNeighborGradOpName = "ResizeNearestNeighborGrad"; | constexpr auto kResizeNearestNeighborGradOpName = "ResizeNearestNeighborGrad"; | ||||
| constexpr auto kFusedMulAddOpName = "FusedMulAdd"; | constexpr auto kFusedMulAddOpName = "FusedMulAdd"; | ||||
| constexpr auto kFusedMulAddNOpName = "FusedMulAddN"; | constexpr auto kFusedMulAddNOpName = "FusedMulAddN"; | ||||
| @@ -217,5 +217,105 @@ TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond4_fusion) { | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after"); | FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after"); | ||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | ||||
| } | } | ||||
| TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_assign_fusion) { | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "before"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||||
| AbstractBasePtrList args_spec_list; | |||||
| for (size_t i = 0; i < 10; ++i) { | |||||
| args_spec_list.push_back(x_abstract); | |||||
| } | |||||
| auto fg = GetKernelGraph(g, args_spec_list); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| pm->AddPass(std::make_shared<opt::AdamApplyOneAssignFusion>()); | |||||
| optimizer->AddPassManager(pm); | |||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_assign_cond1_fusion) { | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "before_cond1"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||||
| AbstractBasePtrList args_spec_list; | |||||
| for (size_t i = 0; i < 10; ++i) { | |||||
| args_spec_list.push_back(x_abstract); | |||||
| } | |||||
| auto fg = GetKernelGraph(g, args_spec_list); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| pm->AddPass(std::make_shared<opt::AdamApplyOneAssignCond1Fusion>()); | |||||
| optimizer->AddPassManager(pm); | |||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_assign_cond2_fusion) { | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "before_cond2"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||||
| AbstractBasePtrList args_spec_list; | |||||
| for (size_t i = 0; i < 10; ++i) { | |||||
| args_spec_list.push_back(x_abstract); | |||||
| } | |||||
| auto fg = GetKernelGraph(g, args_spec_list); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| pm->AddPass(std::make_shared<opt::AdamApplyOneAssignCond2Fusion>()); | |||||
| optimizer->AddPassManager(pm); | |||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_assign_cond3_fusion) { | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "before_cond3"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||||
| AbstractBasePtrList args_spec_list; | |||||
| for (size_t i = 0; i < 10; ++i) { | |||||
| args_spec_list.push_back(x_abstract); | |||||
| } | |||||
| auto fg = GetKernelGraph(g, args_spec_list); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| pm->AddPass(std::make_shared<opt::AdamApplyOneAssignCond3Fusion>()); | |||||
| optimizer->AddPassManager(pm); | |||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_assign_cond4_fusion) { | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "before_cond4"); | |||||
| std::vector<int> shp{2, 32, 224, 224}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||||
| AbstractBasePtrList args_spec_list; | |||||
| for (size_t i = 0; i < 10; ++i) { | |||||
| args_spec_list.push_back(x_abstract); | |||||
| } | |||||
| auto fg = GetKernelGraph(g, args_spec_list); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| pm->AddPass(std::make_shared<opt::AdamApplyOneAssignCond4Fusion>()); | |||||
| optimizer->AddPassManager(pm); | |||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_assign_fusion", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| from mindspore.ops import Primitive | from mindspore.ops import Primitive | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import functional as F | |||||
| Add = P.TensorAdd() | Add = P.TensorAdd() | ||||
| Sub = P.Sub() | Sub = P.Sub() | ||||
| @@ -21,9 +22,11 @@ Mul = P.Mul() | |||||
| RealDiv = P.RealDiv() | RealDiv = P.RealDiv() | ||||
| Sqrt = P.Sqrt() | Sqrt = P.Sqrt() | ||||
| Square = P.Square() | Square = P.Square() | ||||
| Assign = P.Assign() | |||||
| make_tuple = Primitive('make_tuple') | make_tuple = Primitive('make_tuple') | ||||
| tuple_getitem = Primitive('tuple_getitem') | tuple_getitem = Primitive('tuple_getitem') | ||||
| AdamApplyOne = Primitive('AdamApplyOne') | AdamApplyOne = Primitive('AdamApplyOne') | ||||
| AdamApplyOneAssign = Primitive('AdamApplyOneAssign') | |||||
| class FnDict: | class FnDict: | ||||
| @@ -139,3 +142,138 @@ def test_adam_apply_one_fusion(tag): | |||||
| return make_tuple(output) | return make_tuple(output) | ||||
| return fns[tag] | return fns[tag] | ||||
| def test_adam_apply_one_assign_fusion(tag): | |||||
| fns = FnDict() | |||||
| @fns | |||||
| def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||||
| square0 = Square(input0) | |||||
| mul1 = Mul(mul1_x, input0) | |||||
| mul0 = Mul(mul0_x, input2) | |||||
| mul2 = Mul(mul2_x, input1) | |||||
| mul3 = Mul(mul3_x, square0) | |||||
| add0 = Add(mul0, mul1) | |||||
| add1 = Add(mul2, mul3) | |||||
| sqrt0 = Sqrt(add1) | |||||
| add2 = Add(sqrt0, add2_y) | |||||
| true_div0 = RealDiv(add0, add2) | |||||
| mul4 = Mul(input4, true_div0) | |||||
| sub0 = Sub(input3, mul4) | |||||
| assign0 = Assign(input3, sub0) | |||||
| depend0 = F.depend(sub0, assign0) | |||||
| assign1 = Assign(input2, add0) | |||||
| depend1 = F.depend(depend0, assign1) | |||||
| assign2 = Assign(input1, add1) | |||||
| depend2 = F.depend(depend1, assign2) | |||||
| outputs = make_tuple(add1, add0, depend2) | |||||
| output = tuple_getitem(outputs, 0) | |||||
| return output | |||||
| @fns | |||||
| def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||||
| square0 = Square(input0) | |||||
| mul1 = Mul(mul1_x, input0) | |||||
| mul0 = Mul(mul0_x, input2) | |||||
| mul2 = Mul(mul2_x, input1) | |||||
| mul3 = Mul(mul3_x, square0) | |||||
| add0 = Add(mul0, mul1) | |||||
| add1 = Add(mul2, mul3) | |||||
| sqrt0 = Sqrt(add1) | |||||
| add2 = Add(add2_y, sqrt0) | |||||
| true_div0 = RealDiv(add0, add2) | |||||
| mul4 = Mul(input4, true_div0) | |||||
| sub0 = Sub(input3, mul4) | |||||
| assign0 = Assign(input3, sub0) | |||||
| depend0 = F.depend(sub0, assign0) | |||||
| assign1 = Assign(input2, add0) | |||||
| depend1 = F.depend(depend0, assign1) | |||||
| assign2 = Assign(input1, add1) | |||||
| depend2 = F.depend(depend1, assign2) | |||||
| outputs = make_tuple(add1, add0, depend2) | |||||
| output = tuple_getitem(outputs, 0) | |||||
| return output | |||||
| @fns | |||||
| def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||||
| square0 = Square(input0) | |||||
| mul1 = Mul(mul1_x, input0) | |||||
| mul0 = Mul(mul0_x, input2) | |||||
| mul2 = Mul(mul2_x, input1) | |||||
| mul3 = Mul(square0, mul3_x) | |||||
| add0 = Add(mul0, mul1) | |||||
| add1 = Add(mul2, mul3) | |||||
| sqrt0 = Sqrt(add1) | |||||
| add2 = Add(sqrt0, add2_y) | |||||
| true_div0 = RealDiv(add0, add2) | |||||
| mul4 = Mul(true_div0, input4) | |||||
| sub0 = Sub(input3, mul4) | |||||
| assign0 = Assign(input3, sub0) | |||||
| depend0 = F.depend(sub0, assign0) | |||||
| assign1 = Assign(input2, add0) | |||||
| depend1 = F.depend(depend0, assign1) | |||||
| assign2 = Assign(input1, add1) | |||||
| depend2 = F.depend(depend1, assign2) | |||||
| outputs = make_tuple(add1, add0, depend2) | |||||
| output = tuple_getitem(outputs, 0) | |||||
| return output | |||||
| @fns | |||||
| def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||||
| square0 = Square(input0) | |||||
| mul1 = Mul(mul1_x, input0) | |||||
| mul0 = Mul(mul0_x, input2) | |||||
| mul2 = Mul(mul2_x, input1) | |||||
| mul3 = Mul(mul3_x, square0) | |||||
| add0 = Add(mul0, mul1) | |||||
| add1 = Add(mul2, mul3) | |||||
| sqrt0 = Sqrt(add1) | |||||
| add2 = Add(sqrt0, add2_y) | |||||
| true_div0 = RealDiv(add0, add2) | |||||
| mul4 = Mul(true_div0, input4) | |||||
| sub0 = Sub(input3, mul4) | |||||
| assign0 = Assign(input3, sub0) | |||||
| depend0 = F.depend(sub0, assign0) | |||||
| assign1 = Assign(input2, add0) | |||||
| depend1 = F.depend(depend0, assign1) | |||||
| assign2 = Assign(input1, add1) | |||||
| depend2 = F.depend(depend1, assign2) | |||||
| outputs = make_tuple(add1, add0, depend2) | |||||
| output = tuple_getitem(outputs, 0) | |||||
| return output | |||||
| @fns | |||||
| def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||||
| square0 = Square(input0) | |||||
| mul1 = Mul(mul1_x, input0) | |||||
| mul0 = Mul(mul0_x, input2) | |||||
| mul2 = Mul(mul2_x, input1) | |||||
| mul3 = Mul(mul3_x, square0) | |||||
| add0 = Add(mul0, mul1) | |||||
| add1 = Add(mul2, mul3) | |||||
| sqrt0 = Sqrt(add1) | |||||
| add2 = Add(add2_y, sqrt0) | |||||
| true_div0 = RealDiv(add0, add2) | |||||
| mul4 = Mul(true_div0, input4) | |||||
| sub0 = Sub(input3, mul4) | |||||
| assign0 = Assign(input3, sub0) | |||||
| depend0 = F.depend(sub0, assign0) | |||||
| assign1 = Assign(input2, add0) | |||||
| depend1 = F.depend(depend0, assign1) | |||||
| assign2 = Assign(input1, add1) | |||||
| depend2 = F.depend(depend1, assign2) | |||||
| outputs = make_tuple(add1, add0, depend2) | |||||
| output = tuple_getitem(outputs, 0) | |||||
| return output | |||||
| @fns | |||||
| def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y): | |||||
| adam_apply_one_assign = AdamApplyOneAssign(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, | |||||
| mul3_x, add2_y) | |||||
| outputs = make_tuple(tuple_getitem(adam_apply_one_assign, 0), tuple_getitem(adam_apply_one_assign, 1), | |||||
| tuple_getitem(adam_apply_one_assign, 2)) | |||||
| output = tuple_getitem(outputs, 0) | |||||
| return make_tuple(output) | |||||
| return fns[tag] | |||||