| @@ -21,6 +21,7 @@ | |||
| #include "pre_activate/ascend/ir_fission/bn_grad_split.h" | |||
| #include "pre_activate/ascend/ir_fission/batch_norm_grad_split.h" | |||
| #include "pre_activate/ascend/ir_fission/batch_norm_bert_fission.h" | |||
| #include "pre_activate/ascend/ir_fission/single_batch_norm_fission.h" | |||
| #include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h" | |||
| #include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h" | |||
| #include "pre_activate/pass/communication_op_fusion.h" | |||
| @@ -240,6 +241,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap | |||
| ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>()); | |||
| } | |||
| ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>()); | |||
| @@ -27,24 +27,6 @@ const std::vector<int> kOutputIndex{0, 3, 4, 5}; | |||
| constexpr size_t kBatchNormRealOutputNum = 3; | |||
| constexpr size_t kBatchNormRealInputNum = 3; | |||
| bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) { | |||
| MS_EXCEPTION_IF_NULL(n1); | |||
| MS_EXCEPTION_IF_NULL(n2); | |||
| auto n1_cnode = n1->cast<CNodePtr>(); | |||
| auto n2_cnode = n2->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(n1_cnode); | |||
| MS_EXCEPTION_IF_NULL(n2_cnode); | |||
| auto index_input1 = n1_cnode->input(kInputNodeOutputIndexInTupleGetItem); | |||
| MS_EXCEPTION_IF_NULL(index_input1); | |||
| auto value_node1 = index_input1->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node1); | |||
| auto index_input2 = n2_cnode->input(kInputNodeOutputIndexInTupleGetItem); | |||
| MS_EXCEPTION_IF_NULL(index_input2); | |||
| auto value_node2 = index_input2->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node2); | |||
| return GetValue<int>(value_node1->value()) < GetValue<int>(value_node2->value()); | |||
| } | |||
| bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(bn_outputs); | |||
| @@ -0,0 +1,169 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/ir_fission/single_batch_norm_fission.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "pre_activate/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| const std::vector<int> kOutputIndex{0, 1, 2, 3, 4}; | |||
| constexpr size_t kBatchNormRealOutputNum = 5; | |||
| constexpr size_t kBatchNormRealInputNum = 3; | |||
| bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(bn_outputs); | |||
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto iter = manager->node_users().find(bn); | |||
| if (iter == manager->node_users().end()) { | |||
| return false; | |||
| } | |||
| size_t output_num = 0; | |||
| for (const auto &node_index : iter->second) { | |||
| AnfNodePtr output = node_index.first; | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { | |||
| continue; | |||
| } | |||
| auto tuple_getiterm_cnode = output->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode); | |||
| auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem); | |||
| MS_EXCEPTION_IF_NULL(index_node); | |||
| auto value_node = index_node->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| int index = GetValue<int>(value_node->value()); | |||
| if (std::find(kOutputIndex.begin(), kOutputIndex.end(), index) == kOutputIndex.end()) { | |||
| return false; | |||
| } | |||
| bn_outputs->push_back(output); | |||
| output_num++; | |||
| } | |||
| return output_num == kBatchNormRealOutputNum; | |||
| } | |||
| AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(bn); | |||
| auto bn_cnode = bn->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(bn_cnode); | |||
| if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) { | |||
| MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than " | |||
| << kBatchNormRealInputNum + 1; | |||
| } | |||
| std::vector<AnfNodePtr> bn_training_reduce_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), bn_cnode->input(1)}; | |||
| auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs); | |||
| MS_EXCEPTION_IF_NULL(bn_training_reduce); | |||
| // set abstract | |||
| auto bn_input1 = bn_cnode->input(2); | |||
| MS_EXCEPTION_IF_NULL(bn_input1); | |||
| AbstractBasePtrList abstract_list{bn_input1->abstract(), bn_input1->abstract()}; | |||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | |||
| bn_training_reduce->set_abstract(abstract_tuple); | |||
| bn_training_reduce->set_scope(bn->scope()); | |||
| return bn_training_reduce; | |||
| } | |||
| AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, | |||
| const std::vector<AnfNodePtr> &bn_training_reduce_outputs) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(bn); | |||
| auto bn_cnode = bn->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(bn_cnode); | |||
| if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) { | |||
| MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than " | |||
| << kBatchNormRealInputNum + 1; | |||
| } | |||
| if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) { | |||
| MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum | |||
| << ", but it is " << bn_training_reduce_outputs.size(); | |||
| } | |||
| std::vector<AnfNodePtr> bn_training_update_v3_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateV3OpName)), | |||
| bn_cnode->input(1), | |||
| bn_training_reduce_outputs[0], | |||
| bn_training_reduce_outputs[1], | |||
| bn_cnode->input(2), | |||
| bn_cnode->input(3)}; | |||
| auto bn_training_update_v3 = func_graph->NewCNode(bn_training_update_v3_inputs); | |||
| MS_EXCEPTION_IF_NULL(bn_training_update_v3); | |||
| auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract()); | |||
| MS_EXCEPTION_IF_NULL(bn_abstract_tuple); | |||
| if (bn_abstract_tuple->elements().size() != kBatchNormOutputNum) { | |||
| MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBatchNormOutputNum << ", but it is " | |||
| << bn_abstract_tuple->elements().size(); | |||
| } | |||
| bn_training_update_v3->set_abstract(bn->abstract()); | |||
| bn_training_update_v3->set_scope(bn->scope()); | |||
| AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_cnode, bn_training_update_v3); | |||
| return bn_training_update_v3; | |||
| } | |||
| } // namespace | |||
| const BaseRef SingleBatchNormFission::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| return VectorRef({prim::kPrimBatchNorm, Xs}); | |||
| } | |||
| const AnfNodePtr SingleBatchNormFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| std::vector<AnfNodePtr> bn_outputs; | |||
| if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) { | |||
| MS_LOG(INFO) << "The BatchNorm node should only have output 0, 3 and 4. The node should not be changed"; | |||
| return nullptr; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->inputs().size() < kBatchNormRealInputNum + 1) { | |||
| MS_LOG(INFO) << "The input num of BatchNorm less than" << kBatchNormRealInputNum | |||
| << ". The node should not be changed"; | |||
| return nullptr; | |||
| } | |||
| AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node); | |||
| std::vector<AnfNodePtr> bn_training_reduce_outputs; | |||
| CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum, | |||
| &bn_training_reduce_outputs); | |||
| AnfNodePtr bn_training_update_v3 = CreateBNTrainingUpdateV3(func_graph, node, bn_training_reduce_outputs); | |||
| std::vector<AnfNodePtr> bn_training_update_v3_outputs; | |||
| CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update_v3, kBNTrainingUpdateV3OutputNum, | |||
| &bn_training_update_v3_outputs); | |||
| if (bn_training_update_v3_outputs.size() != kBNTrainingUpdateV3OutputNum) { | |||
| MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingUpdateV2OutputNum | |||
| << ", but it is " << bn_training_update_v3_outputs.size(); | |||
| } | |||
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| sort(bn_outputs.begin(), bn_outputs.end(), CompareTupleGetitem); | |||
| size_t output_index = 0; | |||
| for (const auto &output : bn_outputs) { | |||
| (void)manager->Replace(output, bn_training_update_v3_outputs[output_index]); | |||
| output_index++; | |||
| } | |||
| // Return the new node for control depends. | |||
| return bn_training_update_v3; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_ | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_ | |||
| #include "pre_activate/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class SingleBatchNormFission : public PatternProcessPass { | |||
| public: | |||
| explicit SingleBatchNormFission(bool multigraph = true) | |||
| : PatternProcessPass("single_batch_norm_fission", multigraph) {} | |||
| ~SingleBatchNormFission() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_ | |||
| @@ -704,5 +704,23 @@ AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) { | |||
| } | |||
| return res; | |||
| } | |||
| bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) { | |||
| MS_EXCEPTION_IF_NULL(n1); | |||
| MS_EXCEPTION_IF_NULL(n2); | |||
| auto n1_cnode = n1->cast<CNodePtr>(); | |||
| auto n2_cnode = n2->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(n1_cnode); | |||
| MS_EXCEPTION_IF_NULL(n2_cnode); | |||
| auto index_input1 = n1_cnode->input(kInputNodeOutputIndexInTupleGetItem); | |||
| MS_EXCEPTION_IF_NULL(index_input1); | |||
| auto value_node1 = index_input1->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node1); | |||
| auto index_input2 = n2_cnode->input(kInputNodeOutputIndexInTupleGetItem); | |||
| MS_EXCEPTION_IF_NULL(index_input2); | |||
| auto value_node2 = index_input2->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node2); | |||
| return GetValue<int>(value_node1->value()) < GetValue<int>(value_node2->value()); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -65,6 +65,7 @@ constexpr size_t kBNGrad3OutputNum = 1; | |||
| constexpr size_t kBNTrainingReduceOutputNum = 2; | |||
| constexpr size_t kBNTrainingUpdateOutputNum = 5; | |||
| constexpr size_t kBNTrainingUpdateV2OutputNum = 3; | |||
| constexpr size_t kBNTrainingUpdateV3OutputNum = 5; | |||
| constexpr size_t kBNTrainingUpdateGradOutputNum = 2; | |||
| constexpr size_t kSingleOutputNum = 1; | |||
| @@ -176,6 +177,9 @@ bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &va | |||
| // Get anf_node from equiv by var_node | |||
| AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node); | |||
| // Compare tuple getitem's index, return bool[n1's index < n2's index] | |||
| bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2); | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ | |||
| @@ -55,6 +55,7 @@ constexpr auto kExtractImagePatchesOpName = "ExtractImagePatches"; | |||
| constexpr auto kBNTrainingReduceOpName = "BNTrainingReduce"; | |||
| constexpr auto kBNTrainingUpdateOpName = "BNTrainingUpdate"; | |||
| constexpr auto kBNTrainingUpdateV2OpName = "BNTrainingUpdateV2"; | |||
| constexpr auto kBNTrainingUpdateV3OpName = "BNTrainingUpdateV3"; | |||
| constexpr auto kSimpleMeanGradOpName = "SimpleMeanGrad"; | |||
| constexpr auto kMeanGradOpName = "MeanGrad"; | |||
| constexpr auto kSliceOpName = "Slice"; | |||
| @@ -80,6 +80,7 @@ TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_no_fission) { | |||
| args_spec_list.push_back(y_abstract); | |||
| } | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| auto origin_graph = std::make_shared<session::KernelGraph>(*kg); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| @@ -87,7 +88,7 @@ TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_no_fission) { | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||
| EXPECT_TRUE(CheckEqualGraph(kg, new_graph)); | |||
| EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,78 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/ir_fission/single_batch_norm_fission.h" | |||
| #include "common/backend_common_test.h" | |||
| #include "common/py_func_graph_fetcher.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class TestHWSingleBatchNormFission : public BackendCommon { | |||
| public: | |||
| TestHWSingleBatchNormFission() : get_py_fun_("gtest_input.pre_activate.single_batch_norm_fission_test", true) {} | |||
| ~TestHWSingleBatchNormFission() override = default; | |||
| UT::PyFuncGraphFetcher get_py_fun_; | |||
| }; | |||
| TEST_F(TestHWSingleBatchNormFission, test_fission) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_single_batch_norm_fission", "before"); | |||
| EXPECT_NE(g, nullptr); | |||
| std::vector<int> shp_x{32, 64, 112, 112}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| std::vector<int> shp_y{64}; | |||
| auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y); | |||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||
| for (size_t i = 0; i < 4; ++i) { | |||
| args_spec_list.push_back(y_abstract); | |||
| } | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::SingleBatchNormFission>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_single_batch_norm_fission", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWSingleBatchNormFission, test_no_fission) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_single_batch_norm_fission", "before"); | |||
| EXPECT_NE(g, nullptr); | |||
| std::vector<int> shp_x{32, 64, 112, 112}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| std::vector<int> shp_y{64}; | |||
| auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y); | |||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||
| for (size_t i = 0; i < 4; ++i) { | |||
| args_spec_list.push_back(y_abstract); | |||
| } | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| auto origin_graph = std::make_shared<session::KernelGraph>(*kg); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::SingleBatchNormFission>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||
| EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,64 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore.ops import Primitive | |||
| from mindspore.ops import operations as P | |||
| make_tuple = Primitive('make_tuple') | |||
| tuple_getitem = Primitive('tuple_getitem') | |||
| BatchNorm = P.BatchNorm() | |||
| BNTrainingReduce = Primitive('BNTrainingReduce') | |||
| BNTrainingUpdateV3 = Primitive('BNTrainingUpdateV3') | |||
| class FnDict: | |||
| def __init__(self): | |||
| self.fnDict = {} | |||
| def __call__(self, fn): | |||
| self.fnDict[fn.__name__] = fn | |||
| def __getitem__(self, name): | |||
| return self.fnDict[name] | |||
| def test_single_batch_norm_fission(tag): | |||
| fns = FnDict() | |||
| @fns | |||
| def before(input0, input1, input2, input3, input4): | |||
| batch_norm = BatchNorm(input0, input1, input2, input3, input4) | |||
| item0 = tuple_getitem(batch_norm, 0) | |||
| item1 = tuple_getitem(batch_norm, 1) | |||
| item2 = tuple_getitem(batch_norm, 2) | |||
| item3 = tuple_getitem(batch_norm, 3) | |||
| item4 = tuple_getitem(batch_norm, 4) | |||
| output = make_tuple(item0, item1, item2, item3, item4) | |||
| return output | |||
| @fns | |||
| def after(input0, input1, input2, input3, input4): | |||
| reduce = BNTrainingReduce(input0) | |||
| reduce_item0 = tuple_getitem(reduce, 0) | |||
| reduce_item1 = tuple_getitem(reduce, 1) | |||
| update = BNTrainingUpdateV3(input0, reduce_item0, reduce_item1, input1, input2) | |||
| update_item0 = tuple_getitem(update, 0) | |||
| update_item1 = tuple_getitem(update, 1) | |||
| update_item2 = tuple_getitem(update, 2) | |||
| update_item3 = tuple_getitem(update, 3) | |||
| update_item4 = tuple_getitem(update, 4) | |||
| output = make_tuple(update_item0, update_item1, update_item2, update_item3, update_item4) | |||
| return make_tuple(output) | |||
| return fns[tag] | |||