| @@ -201,6 +201,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap | |||
| } else { | |||
| ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion>()); | |||
| } | |||
| ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>()); | |||
| if (context_ptr->ir_fusion_flag()) { | |||
| @@ -277,5 +277,28 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c | |||
| } | |||
| return bn_training_update_outputs[0]; | |||
| } | |||
| const BaseRef FusedBatchNormMixPrecisionFusion::DefinePattern() const { | |||
| std::shared_ptr<Var> Xs = std::make_shared<SeqVar>(); | |||
| VarPtr index0 = std::make_shared<CondVar>(IsC); | |||
| VarPtr index1 = std::make_shared<CondVar>(IsC); | |||
| VarPtr index2 = std::make_shared<CondVar>(IsC); | |||
| VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs}); | |||
| VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0}); | |||
| VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1}); | |||
| VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2}); | |||
| VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_}); | |||
| VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_}); | |||
| VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1}); | |||
| VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2}); | |||
| VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_}); | |||
| VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_}); | |||
| VectorRef cast2 = VectorRef({prim::kPrimCast, mul0}); | |||
| VectorRef cast3 = VectorRef({prim::kPrimCast, mul1}); | |||
| VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, cast2}); | |||
| VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, cast3}); | |||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); | |||
| return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -18,6 +18,7 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "pre_activate/common/optimizer.h" | |||
| #include "utils/utils.h" | |||
| @@ -25,8 +26,8 @@ namespace mindspore { | |||
| namespace opt { | |||
| class FusedBatchNormFusion : public PatternProcessPass { | |||
| public: | |||
| explicit FusedBatchNormFusion(bool multigraph = true) | |||
| : PatternProcessPass("fused_batch_norm_fusion", multigraph), | |||
| explicit FusedBatchNormFusion(const std::string &name = "fused_batch_norm_fusion", bool multigraph = true) | |||
| : PatternProcessPass(name, multigraph), | |||
| data_input0_var_(std::make_shared<Var>()), | |||
| data_input1_var_(std::make_shared<Var>()), | |||
| data_input2_var_(std::make_shared<Var>()), | |||
| @@ -39,7 +40,7 @@ class FusedBatchNormFusion : public PatternProcessPass { | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| protected: | |||
| AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &equiv) const; | |||
| void GetBNTrainingUpdateInputs(const EquivPtr &equiv, const std::vector<AnfNodePtr> &bn_training_reduce_outputs, | |||
| @@ -59,6 +60,15 @@ class FusedBatchNormFusion : public PatternProcessPass { | |||
| VarPtr constant_input1_var_; | |||
| VarPtr batch_norm_var_; | |||
| }; | |||
// Variant of FusedBatchNormFusion that matches the mixed-precision form of the
// moving-average update, where the graph contains extra Cast nodes around the
// Sub/AssignSub pair (see DefinePattern in the .cc file).  Only the matched
// pattern differs; the replacement logic (Process) is inherited unchanged from
// FusedBatchNormFusion, which is why the base pass exposes its helpers and
// pattern variables as protected.
class FusedBatchNormMixPrecisionFusion : public FusedBatchNormFusion {
 public:
  explicit FusedBatchNormMixPrecisionFusion(bool multigraph = true)
      : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {}
  ~FusedBatchNormMixPrecisionFusion() override = default;
  const BaseRef DefinePattern() const override;
};
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_ | |||
| @@ -50,5 +50,28 @@ TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_fusion) { | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_mix_precision_fusion) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "before_mix_precision"); | |||
| EXPECT_NE(g, nullptr); | |||
| std::vector<int> shp_x{32, 64, 112, 112}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| std::vector<int> shp_y{64}; | |||
| auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y); | |||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||
| for (size_t i = 0; i < 6; ++i) { | |||
| args_spec_list.push_back(y_abstract); | |||
| } | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::FusedBatchNormMixPrecisionFusion>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -24,6 +24,7 @@ make_tuple = Primitive('make_tuple') | |||
| tuple_getitem = Primitive('tuple_getitem') | |||
| depend = Primitive('depend') | |||
| BatchNorm = P.BatchNorm() | |||
| Cast = P.Cast() | |||
| BNTrainingReduce = Primitive('BNTrainingReduce') | |||
| BNTrainingUpdate = Primitive('BNTrainingUpdate') | |||
| constant0 = Tensor(0.1, mstype.float32) | |||
| @@ -59,6 +60,21 @@ def test_fused_batch_norm_fusion(tag): | |||
| output = tuple_getitem(outputs, 0) | |||
| return output | |||
    @fns
    def before_mix_precision(input0, input1, input2, input3, input4, var0, var1):
        # Mixed-precision moving-average update around BatchNorm; the moving
        # statistics (var0/var1) are cast before the Sub and the scaled deltas
        # are cast again before AssignSub.  The node structure must mirror
        # FusedBatchNormMixPrecisionFusion::DefinePattern exactly, since the
        # pass matches on graph shape.
        batch_norm = BatchNorm(input0, input1, input2, input3, input4)
        # Outputs 1 and 2 are presumably batch mean/variance (they are
        # subtracted from the moving statistics) — the C++ pattern only
        # requires the indices to be constants.
        sub0 = Sub(Cast(var0, mstype.float32), tuple_getitem(batch_norm, 1))
        sub1 = Sub(Cast(var1, mstype.float32), tuple_getitem(batch_norm, 2))
        mul0 = Mul(sub0, constant0)
        mul1 = Mul(sub1, constant1)
        # NOTE(review): both casts target float32, so the "mixed precision"
        # here is structural only (the pattern checks for Cast nodes, not
        # dtypes) — confirm this is intentional for the test.
        assign_sub0 = AssignSub(var0, Cast(mul0, mstype.float32))
        assign_sub1 = AssignSub(var1, Cast(mul1, mstype.float32))
        # Depend chain orders the BN output after both variable updates,
        # matching the Depend(Depend(y, assign_sub0), assign_sub1) pattern.
        depend0 = depend(tuple_getitem(batch_norm, 0), assign_sub0)
        depend1 = depend(depend0, assign_sub1)
        outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4))
        output = tuple_getitem(outputs, 0)
        return output
| @fns | |||
| def after(input0, input1, input2, input3, input4, var0, var1): | |||
| bn_training_reduce = BNTrainingReduce(input0) | |||