diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
index d8b105dfca..4dbdcec55b 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
@@ -92,6 +92,7 @@
 #include "backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h"
 #include "backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h"
 #include "backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.h"
+#include "backend/optimizer/ascend/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h"
 #include "backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h"
 #include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h"
 #include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h"
@@ -435,6 +436,7 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   ub_fusion_pm->AddPass(std::make_shared<ConvDoubleInFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<MatmulEltwiseFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<MatmulConfusionTranposeFusionPass>(fusion_id_allocator));
+  ub_fusion_pm->AddPass(std::make_shared<BatchMatmulFusedMulAddFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<UbPatternFusion>());
   optimizer->AddPassManager(ub_fusion_pm);
   (void)optimizer->Optimize(kernel_graph);
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.cc
new file mode 100644
index 0000000000..06cc1ae313
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.cc
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h"
+#include <vector>
+#include <unordered_set>
+#include <memory>
+#include <string>
+#include "backend/kernel_compiler/kernel_fusion.h"
+#include "debug/anf_ir_dump.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "base/core_ops.h"
+#include "utils/ms_context.h"
+#include "backend/optimizer/common/fusion_id_allocator.h"
+
+namespace mindspore {
+namespace opt {
+void BatchMatmulFusedMulAddFusionPass::MatchBatchMatmulFusedMulAdd(const CNodePtr &cnode,
+                                                                   const session::KernelGraph &kernel_graph,
+                                                                   FusedNodeRecord *candidate_fusion) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(candidate_fusion);
+  auto manager = kernel_graph.manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto batch_matmul = cnode->input(2);
+  MS_EXCEPTION_IF_NULL(batch_matmul);
+  if (batch_matmul->isa<CNode>() && AnfAlgo::CheckPrimitiveType(batch_matmul, prim::kPrimBatchMatMul)) {
+    std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[batch_matmul].size())};
+    AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), batch_matmul);
+    std::unordered_set<AnfNodePtr> record{cnode, batch_matmul};
+    candidate_fusion->push_back(record);
+    SetRecordFusionId(record);
+  }
+}
+
+void BatchMatmulFusedMulAddFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
+                                                                FusedNodeRecord *candidate_fusion) {
+  MS_EXCEPTION_IF_NULL(candidate_fusion);
+  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
+  for (auto &node : node_list) {
+    if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
+        AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+
+    if (AnfAlgo::GetCNodeName(cnode) == kFusedMulAddOpName) {
+      MatchBatchMatmulFusedMulAdd(cnode, kernel_graph, candidate_fusion);
+    }
+  }
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h
new file mode 100644
index 0000000000..974e9900a1
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_FUSEDMULADD_PASS_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_FUSEDMULADD_PASS_H_
+
+#include <unordered_set>
+#include <vector>
+
+#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h"
+#include "ir/anf.h"
+#include "backend/optimizer/common/pass.h"
+#include "backend/optimizer/common/fusion_id_allocator.h"
+#include "runtime/device/kernel_info.h"
+#include "backend/kernel_compiler/kernel.h"
+#include "backend/session/kernel_graph.h"
+
+namespace mindspore {
+namespace opt {
+using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
+
+class BatchMatmulFusedMulAddFusionPass : public FusionBasePass {
+ public:
+  explicit BatchMatmulFusedMulAddFusionPass(FusionIdAllocatorPtr idAllocator)
+      : FusionBasePass("BatchMatmulFusedMulAddFusionPass", idAllocator) {}
+  ~BatchMatmulFusedMulAddFusionPass() override = default;
+  void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
+
+ private:
+  void MatchBatchMatmulFusedMulAdd(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
+                                   FusedNodeRecord *candidate_fusion);
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_FUSEDMULADD_PASS_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.cc
index abf0bb0074..7068e5e1ea 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.cc
@@ -36,7 +36,8 @@ void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNodePtr &cnode,
   MS_EXCEPTION_IF_NULL(manager);
   auto matmul = cnode->input(1);
   MS_EXCEPTION_IF_NULL(matmul);
-  if (matmul->isa<CNode>() && AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul)) {
+  if (matmul->isa<CNode>() && (AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul) ||
+                               AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimBatchMatMul))) {
     std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[matmul].size())};
     AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), matmul);
     std::unordered_set<AnfNodePtr> record{cnode, matmul};