| @@ -92,6 +92,7 @@ | |||||
| #include "backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h" | #include "backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h" | ||||
| #include "backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" | #include "backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" | ||||
| #include "backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.h" | #include "backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.h" | ||||
| #include "backend/optimizer/ascend/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h" | |||||
| #include "backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h" | #include "backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h" | ||||
| #include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h" | #include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h" | ||||
| #include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" | #include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" | ||||
| @@ -435,6 +436,7 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap | |||||
| ub_fusion_pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator)); | ub_fusion_pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator)); | ||||
| ub_fusion_pm->AddPass(std::make_shared<DepthwiseConvEltwiseFusionPass>(fusion_id_allocator)); | ub_fusion_pm->AddPass(std::make_shared<DepthwiseConvEltwiseFusionPass>(fusion_id_allocator)); | ||||
| ub_fusion_pm->AddPass(std::make_shared<MatmulConfusionTranposeFusionPass>(fusion_id_allocator)); | ub_fusion_pm->AddPass(std::make_shared<MatmulConfusionTranposeFusionPass>(fusion_id_allocator)); | ||||
| ub_fusion_pm->AddPass(std::make_shared<BatchMatmulFusedMulAddFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<UbPatternFusion>()); | ub_fusion_pm->AddPass(std::make_shared<UbPatternFusion>()); | ||||
| optimizer->AddPassManager(ub_fusion_pm); | optimizer->AddPassManager(ub_fusion_pm); | ||||
| (void)optimizer->Optimize(kernel_graph); | (void)optimizer->Optimize(kernel_graph); | ||||
| @@ -0,0 +1,66 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/optimizer/ascend/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h" | |||||
| #include <vector> | |||||
| #include <unordered_set> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "backend/kernel_compiler/kernel_fusion.h" | |||||
| #include "debug/anf_ir_dump.h" | |||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| #include "base/core_ops.h" | |||||
| #include "utils/ms_context.h" | |||||
| #include "backend/optimizer/common/fusion_id_allocator.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| void BatchMatmulFusedMulAddFusionPass::MatchBatchMatmulFusedMulAdd(const CNodePtr &cnode, | |||||
| const session::KernelGraph &kernel_graph, | |||||
| FusedNodeRecord *candidate_fusion) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | |||||
| auto manager = kernel_graph.manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| auto batch_matmul = cnode->input(2); | |||||
| MS_EXCEPTION_IF_NULL(batch_matmul); | |||||
| if (batch_matmul->isa<CNode>() && AnfAlgo::CheckPrimitiveType(batch_matmul, prim::kPrimBatchMatMul)) { | |||||
| std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[batch_matmul].size())}; | |||||
| AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), batch_matmul); | |||||
| std::unordered_set<AnfNodePtr> record{cnode, batch_matmul}; | |||||
| candidate_fusion->push_back(record); | |||||
| SetRecordFusionId(record); | |||||
| } | |||||
| } | |||||
| void BatchMatmulFusedMulAddFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, | |||||
| FusedNodeRecord *candidate_fusion) { | |||||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | |||||
| std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return()); | |||||
| for (auto &node : node_list) { | |||||
| if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || | |||||
| AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { | |||||
| continue; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (AnfAlgo::GetCNodeName(cnode) == kFusedMulAddOpName) { | |||||
| MatchBatchMatmulFusedMulAdd(cnode, kernel_graph, candidate_fusion); | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,48 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_FUSEDMULADD_PASS_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_FUSEDMULADD_PASS_H_ | |||||
| #include <unordered_set> | |||||
| #include <vector> | |||||
| #include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" | |||||
| #include "ir/anf.h" | |||||
| #include "backend/optimizer/common/pass.h" | |||||
| #include "backend/optimizer/common/fusion_id_allocator.h" | |||||
| #include "runtime/device/kernel_info.h" | |||||
| #include "backend/kernel_compiler/kernel.h" | |||||
| #include "backend/session/kernel_graph.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||||
// Ascend UB buffer-fusion pass that fuses a BatchMatMul with a downstream
// FusedMulAdd consumer into one fused kernel candidate.
class BatchMatmulFusedMulAddFusionPass : public FusionBasePass {
 public:
  explicit BatchMatmulFusedMulAddFusionPass(FusionIdAllocatorPtr idAllocator)
      : FusionBasePass("BatchMatmulFusedMulAddFusionPass", idAllocator) {}
  ~BatchMatmulFusedMulAddFusionPass() override = default;
  // Walks the kernel graph and collects every BatchMatMul + FusedMulAdd pair
  // into `candidate_fusion`.
  void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
 private:
  // Checks whether `cnode` (a FusedMulAdd) is fed by a BatchMatMul and, if so,
  // records the pair as a fusion candidate.
  void MatchBatchMatmulFusedMulAdd(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
                                   FusedNodeRecord *candidate_fusion);
};
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_FUSEDMULADD_PASS_H_ | |||||
| @@ -36,7 +36,8 @@ void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNode | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| auto matmul = cnode->input(1); | auto matmul = cnode->input(1); | ||||
| MS_EXCEPTION_IF_NULL(matmul); | MS_EXCEPTION_IF_NULL(matmul); | ||||
| if (matmul->isa<CNode>() && AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul)) { | |||||
| if (matmul->isa<CNode>() && (AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul) || | |||||
| AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimBatchMatMul))) { | |||||
| std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[matmul].size())}; | std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[matmul].size())}; | ||||
| AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), matmul); | AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), matmul); | ||||
| std::unordered_set<AnfNodePtr> record{cnode, matmul}; | std::unordered_set<AnfNodePtr> record{cnode, matmul}; | ||||