diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.cc
index be8aa21854..640f84aa44 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.cc
@@ -149,8 +149,17 @@ const BaseRef BatchNormBertFission::DefinePattern() const {
 const AnfNodePtr BatchNormBertFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                const EquivPtr &) const {
   MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
   std::vector<AnfNodePtr> bn_outputs;
   if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) {
+    MS_LOG(INFO) << "The BatchNorm node should only have output 0, 3 and 4. The node should not be changed";
+    return nullptr;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (cnode->inputs().size() != kBatchNormRealInputNum + 1) {
+    MS_LOG(INFO) << "The input size of BatchNorm should be " << kBatchNormRealInputNum
+                 << ". The node should not be changed";
     return nullptr;
   }
   AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node);
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_bert_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_bert_fission_test.cc
index e5abf56c2e..d3998f0736 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_bert_fission_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_bert_fission_test.cc
@@ -28,7 +28,7 @@ class TestHWBatchNormBertFission : public BackendCommon {
   UT::PyFuncGraphFetcher get_py_fun_;
 };
 
-TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fusion) {
+TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fission) {
   FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_bert_fission", "before");
   EXPECT_NE(g, nullptr);
   std::vector<int> shp_x{32, 64, 112, 112};
@@ -40,6 +40,23 @@ TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fission) {
     args_spec_list.push_back(y_abstract);
   }
   auto kg = GetKernelGraph(g, args_spec_list);
+  auto ret = kg->get_return();
+  EXPECT_NE(ret, nullptr);
+  auto make_tuple0 = ret->input(1);
+  EXPECT_NE(make_tuple0, nullptr);
+  auto tuple_getitem0 = make_tuple0->cast<CNodePtr>()->input(1);
+  EXPECT_NE(tuple_getitem0, nullptr);
+  auto make_tuple1 = tuple_getitem0->cast<CNodePtr>()->input(1);
+  EXPECT_NE(make_tuple1, nullptr);
+  auto tuple_getitem1 = make_tuple1->cast<CNodePtr>()->input(1);
+  EXPECT_NE(tuple_getitem1, nullptr);
+  auto bn = tuple_getitem1->cast<CNodePtr>()->input(1);
+  EXPECT_NE(bn, nullptr);
+  auto bn_cnode = bn->cast<CNodePtr>();
+  EXPECT_NE(bn_cnode, nullptr);
+  auto inputs = bn_cnode->inputs();
+  std::vector<AnfNodePtr> new_inputs(inputs.begin(), inputs.begin() + 4);
+  bn_cnode->set_inputs(new_inputs);
 
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
@@ -50,5 +67,27 @@ TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fission) {
   FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_batch_norm_bert_fission", "after");
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
+
+TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_no_fission) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_bert_fission", "before");
+  EXPECT_NE(g, nullptr);
+  std::vector<int> shp_x{32, 64, 112, 112};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
+  std::vector<int> shp_y{64};
+  auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
+  AbstractBasePtrList args_spec_list{x_abstract};
+  for (size_t i = 0; i < 4; ++i) {
+    args_spec_list.push_back(y_abstract);
+  }
+  auto kg = GetKernelGraph(g, args_spec_list);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::BatchNormBertFission>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(kg);
+
+  EXPECT_TRUE(CheckEqualGraph(kg, new_graph));
+}
 }  // namespace opt
 }  // namespace mindspore