diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index c1bc8ec638..8cbfecaf66 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -80,6 +80,7 @@ #include "backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.h" #include "backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h" #include "backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.h" #include "backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h" #include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h" #include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.cc new file mode 100644 index 0000000000..d75511d790 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNodePtr &cnode, + const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + auto matmul = cnode->input(1); + MS_EXCEPTION_IF_NULL(matmul); + if (matmul->isa() && AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul)) { + std::vector output_used_num{SizeToInt(manager->node_users()[matmul].size())}; + AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), matmul); + std::unordered_set record{cnode, matmul}; + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void MatmulConfusionTranposeFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + + if (AnfAlgo::GetCNodeName(cnode) == kConfusionTransposeDOpName) { + MatchMatmulConfusionTranpose(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.h new file mode 100644 index 0000000000..b5d599a66c --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_MATMUL_CONFUSIONTRANSPOSE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_MATMUL_CONFUSIONTRANSPOSE_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class MatmulConfusionTranposeFusionPass : public FusionBasePass { + public: + explicit MatmulConfusionTranposeFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("MatmulConfusionTranposeFusionPass", idAllocator) {} + ~MatmulConfusionTranposeFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchMatmulConfusionTranpose(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_MATMUL_CONFUSIONTRANSPOSE_FUSION_PASS_H_