| @@ -50,6 +50,7 @@ | |||||
| #include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h" | #include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h" | ||||
| #include "pre_activate/ascend/ir_fusion/derelu_fusion.h" | #include "pre_activate/ascend/ir_fusion/derelu_fusion.h" | ||||
| #include "pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h" | #include "pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h" | ||||
| #include "pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" | |||||
| #include "pre_activate/ascend/format_type/insert_trans_op.h" | #include "pre_activate/ascend/format_type/insert_trans_op.h" | ||||
| #include "pre_activate/pass/getitem_tuple.h" | #include "pre_activate/pass/getitem_tuple.h" | ||||
| #include "pre_activate/pass/optimize_dependence.h" | #include "pre_activate/pass/optimize_dependence.h" | ||||
| @@ -102,6 +103,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { | |||||
| ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>()); | ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>()); | ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>()); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -96,6 +96,7 @@ bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *bat | |||||
| AnfNodePtr batchnorm_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); | AnfNodePtr batchnorm_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); | ||||
| MS_EXCEPTION_IF_NULL(batchnorm_anf); | MS_EXCEPTION_IF_NULL(batchnorm_anf); | ||||
| MS_EXCEPTION_IF_NULL(batchnorm); | |||||
| *batchnorm = batchnorm_anf->cast<CNodePtr>(); | *batchnorm = batchnorm_anf->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(*batchnorm); | MS_EXCEPTION_IF_NULL(*batchnorm); | ||||
| return CheckBatchNorm(graph, *batchnorm); | return CheckBatchNorm(graph, *batchnorm); | ||||
| @@ -0,0 +1,127 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "ir/primitive.h" | |||||
| #include "utils/utils.h" | |||||
| #include "operator/ops.h" | |||||
| #include "pipeline/static_analysis/abstract_value.h" | |||||
| #include "pre_activate/common/helper.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace { | |||||
| CNodePtr CreateBNInferGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad, const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(batchnormgrad); | |||||
| auto prim = std::make_shared<Primitive>(kBNInferGradOpName); | |||||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim)}; | |||||
| inputs.push_back(batchnormgrad->input(1)); | |||||
| inputs.push_back(batchnormgrad->input(3)); | |||||
| inputs.push_back(batchnormgrad->input(5)); | |||||
| auto new_node = graph->NewCNode(inputs); | |||||
| MS_EXCEPTION_IF_NULL(new_node); | |||||
| new_node->set_scope(batchnormgrad->scope()); | |||||
| new_node->set_abstract(node->abstract()); | |||||
| AnfAlgo::CopyNodeAttr(kAttrIsTraining, batchnormgrad, new_node); | |||||
| AnfAlgo::CopyNodeAttr(kAttrEpsilon, batchnormgrad, new_node); | |||||
| return new_node; | |||||
| } | |||||
| bool CheckIndex(const AnfNodePtr &index_node) { | |||||
| MS_EXCEPTION_IF_NULL(index_node); | |||||
| if (!IsValueNode<Int32Imm>(index_node)) { | |||||
| return false; | |||||
| } | |||||
| ValueNodePtr value_node = index_node->cast<ValueNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(value_node); | |||||
| int index = GetValue<int>(value_node->value()); | |||||
| if (index != 0) { | |||||
| MS_LOG(DEBUG) << "tuple_getitem must be 0th output of BatchNormGrad"; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool CheckBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(batchnormgrad); | |||||
| if (batchnormgrad->size() < kBatchNormInputNum + 1) { | |||||
| MS_LOG(DEBUG) << "BatchNormGrad's input less than " << kBatchNormInputNum; | |||||
| return false; | |||||
| } | |||||
| if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnormgrad)) { | |||||
| return false; | |||||
| } | |||||
| auto is_training = AnfAlgo::GetNodeAttr<bool>(batchnormgrad, kAttrIsTraining); | |||||
| if (is_training) { | |||||
| MS_LOG(DEBUG) << "is_training is true, no need do fusion"; | |||||
| return false; | |||||
| } | |||||
| if (IsUsedByOthers(graph, batchnormgrad)) { | |||||
| MS_LOG(DEBUG) << "Only the 0th output of BatchNormGrad is used, then do fusion"; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnormgrad) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto tuple_getitem = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tuple_getitem); | |||||
| CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputSize); | |||||
| AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem); | |||||
| MS_EXCEPTION_IF_NULL(index_node); | |||||
| if (!CheckIndex(index_node)) { | |||||
| return false; | |||||
| } | |||||
| AnfNodePtr batchnormgrad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); | |||||
| MS_EXCEPTION_IF_NULL(batchnormgrad_anf); | |||||
| MS_EXCEPTION_IF_NULL(batchnormgrad); | |||||
| *batchnormgrad = batchnormgrad_anf->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(*batchnormgrad); | |||||
| return CheckBatchNormGrad(graph, *batchnormgrad); | |||||
| } | |||||
| } // namespace | |||||
| const BaseRef BatchNormGrad2BNInferGrad::DefinePattern() const { | |||||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||||
| VarPtr Y = std::make_shared<Var>(); | |||||
| MS_EXCEPTION_IF_NULL(Xs); | |||||
| MS_EXCEPTION_IF_NULL(Y); | |||||
| VectorRef batchnormgrad({prim::kPrimBatchNormGrad, Xs}); | |||||
| VectorRef pattern({prim::kPrimTupleGetItem, batchnormgrad, Y}); | |||||
| return pattern; | |||||
| } | |||||
| const AnfNodePtr BatchNormGrad2BNInferGrad::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||||
| const EquivPtr &) const { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| CNodePtr batchnormgrad = nullptr; | |||||
| if (!NeedFusion(graph, node, &batchnormgrad)) { | |||||
| return nullptr; | |||||
| } | |||||
| return CreateBNInferGrad(graph, batchnormgrad, node); | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ | |||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ | |||||
| #include <memory> | |||||
| #include "pre_activate/common/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class BatchNormGrad2BNInferGrad : public PatternProcessPass { | |||||
| public: | |||||
| explicit BatchNormGrad2BNInferGrad(bool multigraph = true) | |||||
| : PatternProcessPass("batchnormgrad_to_bninfergrad", multigraph) {} | |||||
| ~BatchNormGrad2BNInferGrad() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ | |||||
| @@ -0,0 +1,73 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/backend_common_test.h" | |||||
| #include "common/py_func_graph_fetcher.h" | |||||
| #include "pre_activate/common/optimizer.h" | |||||
| #include "pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" | |||||
| #include "debug/anf_ir_dump.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class TestHWOptimizeBatchNormGrad2BNInferGrad : public BackendCommon { | |||||
| public: | |||||
| TestHWOptimizeBatchNormGrad2BNInferGrad() | |||||
| : get_py_fun_("gtest_input.pre_activate.batchnormgrad_to_bninfergrad", true) {} | |||||
| ~TestHWOptimizeBatchNormGrad2BNInferGrad() override = default; | |||||
| UT::PyFuncGraphFetcher get_py_fun_; | |||||
| }; | |||||
| TEST_F(TestHWOptimizeBatchNormGrad2BNInferGrad, test_fusion) { | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batchnormgrad_to_bninfergrad", "before"); | |||||
| EXPECT_NE(g, nullptr); | |||||
| std::vector<int> shp_x{32, 64, 112, 112}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||||
| std::vector<int> shp_y{64}; | |||||
| auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y); | |||||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract, y_abstract, y_abstract, y_abstract}; | |||||
| auto fg = GetKernelGraph(g, args_spec_list); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| pm->AddPass(std::make_shared<opt::BatchNormGrad2BNInferGrad>()); | |||||
| optimizer->AddPassManager(pm); | |||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_batchnormgrad_to_bninfergrad", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| TEST_F(TestHWOptimizeBatchNormGrad2BNInferGrad, test_no_fusion) { | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batchnormgrad_to_bninfergrad", "no_fusion"); | |||||
| EXPECT_NE(g, nullptr); | |||||
| std::vector<int> shp_x{32, 64, 112, 112}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||||
| std::vector<int> shp_y{64}; | |||||
| auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y); | |||||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract, y_abstract, y_abstract, y_abstract}; | |||||
| auto fg = GetKernelGraph(g, args_spec_list); | |||||
| auto origin_graph = std::make_shared<session::KernelGraph>(*fg); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| pm->AddPass(std::make_shared<opt::BatchNormGrad2BNInferGrad>()); | |||||
| optimizer->AddPassManager(pm); | |||||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||||
| EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph)); | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,57 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops.operations import _grad_ops as G | |||||
| from mindspore.ops import Primitive | |||||
| batch_norm_grad = G.BatchNormGrad(is_training=False) | |||||
| bn_infer_grad = Primitive('BNInferGrad') | |||||
| make_tuple = Primitive('make_tuple') | |||||
| tuple_getitem = Primitive('tuple_getitem') | |||||
| class FnDict: | |||||
| def __init__(self): | |||||
| self.fnDict = {} | |||||
| def __call__(self, fn): | |||||
| self.fnDict[fn.__name__] = fn | |||||
| def __getitem__(self, name): | |||||
| return self.fnDict[name] | |||||
| def test_batchnormgrad_to_bninfergrad(tag): | |||||
| fns = FnDict() | |||||
| @fns | |||||
| def before(input0, input1, input2, input3, input4): | |||||
| res = batch_norm_grad(input0, input1, input2, input3, input4) | |||||
| res = tuple_getitem(res, 0) | |||||
| return res | |||||
| @fns | |||||
| def after(input0, input1, input2, input3, input4): | |||||
| res = bn_infer_grad(input0, input2, input4) | |||||
| return make_tuple(res) | |||||
| @fns | |||||
| def no_fusion(input0, input1, input2, input3, input4): | |||||
| res = batch_norm_grad(input0, input1, input2, input3, input4) | |||||
| item0 = tuple_getitem(res, 0) | |||||
| item1 = tuple_getitem(res, 1) | |||||
| item2 = tuple_getitem(res, 2) | |||||
| return make_tuple(item0, item1, item2) | |||||
| return fns[tag] | |||||