@@ -239,7 +239,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
   } else {
     ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
     ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
-    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion>());
+    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
+    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
   }
   ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
@@ -291,7 +291,7 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c
   return bn_training_update_outputs[0];
 }
 
-const BaseRef FusedBatchNormMixPrecisionFusion::DefinePattern() const {
+const BaseRef FusedBatchNormMixPrecisionFusion0::DefinePattern() const {
   std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
   VarPtr index0 = std::make_shared<CondVar>(IsC);
   VarPtr index1 = std::make_shared<CondVar>(IsC);
@@ -313,5 +313,28 @@ const BaseRef FusedBatchNormMixPrecisionFusion::DefinePattern() const {
   VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
   return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
 }
+
+const BaseRef FusedBatchNormMixPrecisionFusion1::DefinePattern() const {
+  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
+  VarPtr index0 = std::make_shared<CondVar>(IsC);
+  VarPtr index1 = std::make_shared<CondVar>(IsC);
+  VarPtr index2 = std::make_shared<CondVar>(IsC);
+  VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
+  VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
+  VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
+  VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2});
+  VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_});
+  VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_});
+  VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1});
+  VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2});
+  VectorRef cast0 = VectorRef({prim::kPrimCast, sub0});
+  VectorRef cast1 = VectorRef({prim::kPrimCast, sub1});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, cast0, constant_input0_var_});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, cast1, constant_input1_var_});
+  VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0});
+  VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1});
+  VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
+  return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
+}
 } // namespace opt
 } // namespace mindspore
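
Context for the pattern above: compared with FusedBatchNormMixPrecisionFusion0, the FusedBatchNormMixPrecisionFusion1 pattern additionally matches a Cast on each Sub result (cast0/cast1) before the Mul; the matched subgraph is otherwise the usual BatchNorm moving-statistics update applied to low-precision variables. A minimal standalone sketch of that arithmetic, in plain NumPy and purely illustrative (momentum, running_mean, and batch_mean are assumed names, not identifiers from this patch):

import numpy as np

# Illustrative sketch, not MindSpore code: the update that the matched
# Sub -> (Cast) -> Mul -> AssignSub chain performs on one running statistic.
momentum = np.float32(0.1)        # plays the role of constant_input0_var_ / constant0
running_mean = np.float16(0.5)    # plays the role of variable_input0_var_ / var0 (low precision)
batch_mean = np.float32(0.8)      # plays the role of tuple_getitem(batch_norm, 1)

diff = np.float32(running_mean) - batch_mean      # Cast + Sub
update = np.float32(diff) * momentum              # extra Cast (Fusion1 only) + Mul
running_mean = np.float16(running_mean - update)  # AssignSub writes back to the variable
print(running_mean)  # equals running_mean - momentum * (running_mean - batch_mean)
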
@@ -61,12 +61,21 @@ class FusedBatchNormFusion : public PatternProcessPass {
   VarPtr batch_norm_var_;
 };
 
-class FusedBatchNormMixPrecisionFusion : public FusedBatchNormFusion {
+class FusedBatchNormMixPrecisionFusion0 : public FusedBatchNormFusion {
  public:
-  explicit FusedBatchNormMixPrecisionFusion(bool multigraph = true)
+  explicit FusedBatchNormMixPrecisionFusion0(bool multigraph = true)
       : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {}
-  ~FusedBatchNormMixPrecisionFusion() override = default;
+  ~FusedBatchNormMixPrecisionFusion0() override = default;
+  const BaseRef DefinePattern() const override;
+};
+
+class FusedBatchNormMixPrecisionFusion1 : public FusedBatchNormFusion {
+ public:
+  explicit FusedBatchNormMixPrecisionFusion1(bool multigraph = true)
+      : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {}
+  ~FusedBatchNormMixPrecisionFusion1() override = default;
   const BaseRef DefinePattern() const override;
 };
 } // namespace opt
@@ -51,8 +51,8 @@ TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_fusion) {
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
 
-TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_mix_precision_fusion) {
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "before_mix_precision");
+TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_mix_precision_fusion0) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "before_mix_precision0");
   EXPECT_NE(g, nullptr);
   std::vector<int> shp_x{32, 64, 112, 112};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
@@ -66,7 +66,30 @@ TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_mix_precision_fusion) {
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::FusedBatchNormMixPrecisionFusion>());
+  pm->AddPass(std::make_shared<opt::FusedBatchNormMixPrecisionFusion0>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(kg);
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
+
+TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_mix_precision_fusion1) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "before_mix_precision1");
+  EXPECT_NE(g, nullptr);
+  std::vector<int> shp_x{32, 64, 112, 112};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
+  std::vector<int> shp_y{64};
+  auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
+  AbstractBasePtrList args_spec_list{x_abstract};
+  for (size_t i = 0; i < 6; ++i) {
+    args_spec_list.push_back(y_abstract);
+  }
+  auto kg = GetKernelGraph(g, args_spec_list);
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::FusedBatchNormMixPrecisionFusion1>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(kg);
@@ -61,7 +61,7 @@ def test_fused_batch_norm_fusion(tag):
         return output
 
     @fns
-    def before_mix_precision(input0, input1, input2, input3, input4, var0, var1):
+    def before_mix_precision0(input0, input1, input2, input3, input4, var0, var1):
         batch_norm = BatchNorm(input0, input1, input2, input3, input4)
         sub0 = Sub(Cast(var0, mstype.float32), tuple_getitem(batch_norm, 1))
         sub1 = Sub(Cast(var1, mstype.float32), tuple_getitem(batch_norm, 2))
@@ -75,6 +75,21 @@ def test_fused_batch_norm_fusion(tag):
         output = tuple_getitem(outputs, 0)
         return output
 
+    @fns
+    def before_mix_precision1(input0, input1, input2, input3, input4, var0, var1):
+        batch_norm = BatchNorm(input0, input1, input2, input3, input4)
+        sub0 = Sub(Cast(var0, mstype.float32), tuple_getitem(batch_norm, 1))
+        sub1 = Sub(Cast(var1, mstype.float32), tuple_getitem(batch_norm, 2))
+        mul0 = Mul(Cast(sub0, mstype.float32), constant0)
+        mul1 = Mul(Cast(sub1, mstype.float32), constant1)
+        assign_sub0 = AssignSub(var0, mul0)
+        assign_sub1 = AssignSub(var1, mul1)
+        depend0 = depend(tuple_getitem(batch_norm, 0), assign_sub0)
+        depend1 = depend(depend0, assign_sub1)
+        outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4))
+        output = tuple_getitem(outputs, 0)
+        return output
+
     @fns
     def after(input0, input1, input2, input3, input4, var0, var1):
         bn_training_reduce = BNTrainingReduce(input0)