Add batch norm fusion pattern for mixed precision

tags/v0.5.0-beta
yujianfeng 5 years ago
commit e87ac6525e
5 changed files with 80 additions and 9 deletions
1. mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc (+2, -1)
2. mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc (+24, -1)
3. mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h (+12, -3)
4. tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc (+26, -3)
5. tests/ut/cpp/python_input/gtest_input/pre_activate/fused_batch_norm_fusion_test.py (+16, -1)

mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc (+2, -1)

@@ -239,7 +239,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
   } else {
     ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
     ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
-    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion>());
+    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
+    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
   }
   ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());

mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc (+24, -1)

@@ -291,7 +291,7 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c
   return bn_training_update_outputs[0];
 }
 
-const BaseRef FusedBatchNormMixPrecisionFusion::DefinePattern() const {
+const BaseRef FusedBatchNormMixPrecisionFusion0::DefinePattern() const {
   std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
   VarPtr index0 = std::make_shared<CondVar>(IsC);
   VarPtr index1 = std::make_shared<CondVar>(IsC);
@@ -313,5 +313,28 @@ const BaseRef FusedBatchNormMixPrecisionFusion::DefinePattern() const {
   VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
   return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
 }
+
+const BaseRef FusedBatchNormMixPrecisionFusion1::DefinePattern() const {
+  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
+  VarPtr index0 = std::make_shared<CondVar>(IsC);
+  VarPtr index1 = std::make_shared<CondVar>(IsC);
+  VarPtr index2 = std::make_shared<CondVar>(IsC);
+  VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
+  VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
+  VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
+  VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2});
+  VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_});
+  VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_});
+  VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1});
+  VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2});
+  VectorRef cast0 = VectorRef({prim::kPrimCast, sub0});
+  VectorRef cast1 = VectorRef({prim::kPrimCast, sub1});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, cast0, constant_input0_var_});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, cast1, constant_input1_var_});
+  VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0});
+  VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1});
+  VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
+  return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
+}
 }  // namespace opt
 }  // namespace mindspore
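For reference, the Sub -> Mul -> AssignSub chain these patterns match is the usual batch-norm moving-statistics update, var = var - (var - batch_stat) * momentum, with the Cast nodes that mixed-precision training inserts around float16 variables; pattern 1 additionally matches a Cast between the Sub and the Mul (cast0/cast1 above). A minimal NumPy sketch of the arithmetic behind pattern 1 (names and dtypes here are illustrative assumptions, not MindSpore APIs):

import numpy as np

def moving_stat_update(var_fp16, batch_stat_fp32, momentum):
    # Cast the float16 variable up to float32 before the Sub
    # (the Cast on variable_input0_var_/variable_input1_var_ above).
    sub = var_fp16.astype(np.float32) - batch_stat_fp32
    # The extra Cast matched by pattern 1 sits between the Sub and the
    # Mul; numerically the update is the same.
    mul = sub.astype(np.float32) * momentum
    # AssignSub applies the update in place: var -= mul.
    return (var_fp16 - mul).astype(np.float16)

Both input graphs lower to the same fused form, which is why the two unit tests below compare against the same expected "after" graph.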

mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h (+12, -3)

@@ -61,12 +61,21 @@ class FusedBatchNormFusion : public PatternProcessPass {
   VarPtr batch_norm_var_;
 };
 
-class FusedBatchNormMixPrecisionFusion : public FusedBatchNormFusion {
+class FusedBatchNormMixPrecisionFusion0 : public FusedBatchNormFusion {
  public:
-  explicit FusedBatchNormMixPrecisionFusion(bool multigraph = true)
+  explicit FusedBatchNormMixPrecisionFusion0(bool multigraph = true)
       : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {}
 
-  ~FusedBatchNormMixPrecisionFusion() override = default;
+  ~FusedBatchNormMixPrecisionFusion0() override = default;
+  const BaseRef DefinePattern() const override;
+};
+
+class FusedBatchNormMixPrecisionFusion1 : public FusedBatchNormFusion {
+ public:
+  explicit FusedBatchNormMixPrecisionFusion1(bool multigraph = true)
+      : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {}
+
+  ~FusedBatchNormMixPrecisionFusion1() override = default;
   const BaseRef DefinePattern() const override;
 };
 }  // namespace opt

tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc (+26, -3)

@@ -51,8 +51,8 @@ TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_fusion) {
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
 
-TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_mix_precision_fusion) {
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "before_mix_precision");
+TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_mix_precision_fusion0) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "before_mix_precision0");
   EXPECT_NE(g, nullptr);
   std::vector<int> shp_x{32, 64, 112, 112};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
@@ -66,7 +66,30 @@ TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_mix_precision_fusion) {
 
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::FusedBatchNormMixPrecisionFusion>());
+  pm->AddPass(std::make_shared<opt::FusedBatchNormMixPrecisionFusion0>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(kg);
+
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
+
+TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_mix_precision_fusion1) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "before_mix_precision1");
+  EXPECT_NE(g, nullptr);
+  std::vector<int> shp_x{32, 64, 112, 112};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
+  std::vector<int> shp_y{64};
+  auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
+  AbstractBasePtrList args_spec_list{x_abstract};
+  for (size_t i = 0; i < 6; ++i) {
+    args_spec_list.push_back(y_abstract);
+  }
+  auto kg = GetKernelGraph(g, args_spec_list);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::FusedBatchNormMixPrecisionFusion1>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(kg);
 

tests/ut/cpp/python_input/gtest_input/pre_activate/fused_batch_norm_fusion_test.py (+16, -1)

@@ -61,7 +61,7 @@ def test_fused_batch_norm_fusion(tag):
         return output
 
     @fns
-    def before_mix_precision(input0, input1, input2, input3, input4, var0, var1):
+    def before_mix_precision0(input0, input1, input2, input3, input4, var0, var1):
         batch_norm = BatchNorm(input0, input1, input2, input3, input4)
         sub0 = Sub(Cast(var0, mstype.float32), tuple_getitem(batch_norm, 1))
         sub1 = Sub(Cast(var1, mstype.float32), tuple_getitem(batch_norm, 2))
@@ -75,6 +75,21 @@ def test_fused_batch_norm_fusion(tag):
         output = tuple_getitem(outputs, 0)
         return output
 
+    @fns
+    def before_mix_precision1(input0, input1, input2, input3, input4, var0, var1):
+        batch_norm = BatchNorm(input0, input1, input2, input3, input4)
+        sub0 = Sub(Cast(var0, mstype.float32), tuple_getitem(batch_norm, 1))
+        sub1 = Sub(Cast(var1, mstype.float32), tuple_getitem(batch_norm, 2))
+        mul0 = Mul(Cast(sub0, mstype.float32), constant0)
+        mul1 = Mul(Cast(sub1, mstype.float32), constant1)
+        assign_sub0 = AssignSub(var0, mul0)
+        assign_sub1 = AssignSub(var1, mul1)
+        depend0 = depend(tuple_getitem(batch_norm, 0), assign_sub0)
+        depend1 = depend(depend0, assign_sub1)
+        outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4))
+        output = tuple_getitem(outputs, 0)
+        return output
+
     @fns
     def after(input0, input1, input2, input3, input4, var0, var1):
         bn_training_reduce = BNTrainingReduce(input0)

