|
|
|
@@ -28,7 +28,7 @@ class TestHWBatchNormBertFission : public BackendCommon { |
|
|
|
UT::PyFuncGraphFetcher get_py_fun_; |
|
|
|
}; |
|
|
|
|
|
|
|
TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fusion) { |
|
|
|
TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fission) { |
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_bert_fission", "before"); |
|
|
|
EXPECT_NE(g, nullptr); |
|
|
|
std::vector<int> shp_x{32, 64, 112, 112}; |
|
|
|
@@ -40,6 +40,23 @@ TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fusion) { |
|
|
|
args_spec_list.push_back(y_abstract); |
|
|
|
} |
|
|
|
auto kg = GetKernelGraph(g, args_spec_list); |
|
|
|
auto ret = kg->get_return(); |
|
|
|
EXPECT_NE(ret, nullptr); |
|
|
|
auto make_tuple0 = ret->input(1); |
|
|
|
EXPECT_NE(make_tuple0, nullptr); |
|
|
|
auto tuple_getitem0 = make_tuple0->cast<CNodePtr>()->input(1); |
|
|
|
EXPECT_NE(tuple_getitem0, nullptr); |
|
|
|
auto make_tuple1 = tuple_getitem0->cast<CNodePtr>()->input(1); |
|
|
|
EXPECT_NE(make_tuple1, nullptr); |
|
|
|
auto tuple_getitem1 = make_tuple1->cast<CNodePtr>()->input(1); |
|
|
|
EXPECT_NE(tuple_getitem1, nullptr); |
|
|
|
auto bn = tuple_getitem1->cast<CNodePtr>()->input(1); |
|
|
|
EXPECT_NE(bn, nullptr); |
|
|
|
auto bn_cnode = bn->cast<CNodePtr>(); |
|
|
|
EXPECT_NE(bn_cnode, nullptr); |
|
|
|
auto inputs = bn_cnode->inputs(); |
|
|
|
std::vector<AnfNodePtr> new_inputs(inputs.begin(), inputs.begin() + 4); |
|
|
|
bn_cnode->set_inputs(new_inputs); |
|
|
|
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>(); |
|
|
|
auto pm = std::make_shared<opt::PassManager>(); |
|
|
|
@@ -50,5 +67,27 @@ TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fusion) { |
|
|
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_batch_norm_bert_fission", "after"); |
|
|
|
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_no_fission) { |
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_bert_fission", "before"); |
|
|
|
EXPECT_NE(g, nullptr); |
|
|
|
std::vector<int> shp_x{32, 64, 112, 112}; |
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); |
|
|
|
std::vector<int> shp_y{64}; |
|
|
|
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y); |
|
|
|
AbstractBasePtrList args_spec_list{x_abstract}; |
|
|
|
for (size_t i = 0; i < 4; ++i) { |
|
|
|
args_spec_list.push_back(y_abstract); |
|
|
|
} |
|
|
|
auto kg = GetKernelGraph(g, args_spec_list); |
|
|
|
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>(); |
|
|
|
auto pm = std::make_shared<opt::PassManager>(); |
|
|
|
pm->AddPass(std::make_shared<opt::BatchNormBertFission>()); |
|
|
|
optimizer->AddPassManager(pm); |
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(kg); |
|
|
|
|
|
|
|
EXPECT_TRUE(CheckEqualGraph(kg, new_graph)); |
|
|
|
} |
|
|
|
} // namespace opt |
|
|
|
} // namespace mindspore |