Browse Source

add BatchMatMul&FusedMulAdd and BatchMatMul&ConfusionTranpose UB fusion passes

pull/14264/head
yuchaojie 4 years ago
parent
commit
50f7f6b3de
4 changed files with 118 additions and 1 deletions
  1. +2
    -0
      mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
  2. +66
    -0
      mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.cc
  3. +48
    -0
      mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h
  4. +2
    -1
      mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.cc

+ 2
- 0
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc View File

@@ -92,6 +92,7 @@
#include "backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h" #include "backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h"
#include "backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" #include "backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h"
#include "backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.h" #include "backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.h"
#include "backend/optimizer/ascend/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h"
#include "backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h" #include "backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h"
#include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h" #include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h"
#include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" #include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h"
@@ -435,6 +436,7 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap
ub_fusion_pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator)); ub_fusion_pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<DepthwiseConvEltwiseFusionPass>(fusion_id_allocator)); ub_fusion_pm->AddPass(std::make_shared<DepthwiseConvEltwiseFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<MatmulConfusionTranposeFusionPass>(fusion_id_allocator)); ub_fusion_pm->AddPass(std::make_shared<MatmulConfusionTranposeFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<BatchMatmulFusedMulAddFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<UbPatternFusion>()); ub_fusion_pm->AddPass(std::make_shared<UbPatternFusion>());
optimizer->AddPassManager(ub_fusion_pm); optimizer->AddPassManager(ub_fusion_pm);
(void)optimizer->Optimize(kernel_graph); (void)optimizer->Optimize(kernel_graph);


+ 66
- 0
mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.cc View File

@@ -0,0 +1,66 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h"
#include <vector>
#include <unordered_set>
#include <memory>
#include <string>
#include "backend/kernel_compiler/kernel_fusion.h"
#include "debug/anf_ir_dump.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "base/core_ops.h"
#include "utils/ms_context.h"
#include "backend/optimizer/common/fusion_id_allocator.h"

namespace mindspore {
namespace opt {
void BatchMatmulFusedMulAddFusionPass::MatchBatchMatmulFusedMulAdd(const CNodePtr &cnode,
                                                                   const session::KernelGraph &kernel_graph,
                                                                   FusedNodeRecord *candidate_fusion) {
  // Records {FusedMulAdd, BatchMatMul} as a buffer-fusion candidate when the
  // FusedMulAdd node's second input is produced by a BatchMatMul kernel.
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(candidate_fusion);
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);
  // NOTE(review): only input(2) is inspected; if a BatchMatMul may also feed
  // input(1) of FusedMulAdd, that arrangement is never matched -- confirm
  // against the FusedMulAdd operator's input order.
  auto producer = cnode->input(2);
  MS_EXCEPTION_IF_NULL(producer);
  if (!producer->isa<CNode>() || !AnfAlgo::CheckPrimitiveType(producer, prim::kPrimBatchMatMul)) {
    return;
  }
  // Record how many consumers the BatchMatMul output has; the fusion builder
  // reads this attribute when generating the fused kernel.
  std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[producer].size())};
  AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), producer);
  std::unordered_set<AnfNodePtr> record{cnode, producer};
  candidate_fusion->push_back(record);
  SetRecordFusionId(record);
}

void BatchMatmulFusedMulAddFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
                                                                FusedNodeRecord *candidate_fusion) {
  // Scans the kernel graph in topological order and tries to match every
  // FusedMulAdd kernel against the BatchMatMul + FusedMulAdd pattern.
  MS_EXCEPTION_IF_NULL(candidate_fusion);
  const auto sorted_nodes = TopoSort(kernel_graph.get_return());
  for (const auto &node : sorted_nodes) {
    // Skip non-kernel nodes, nodes already claimed by another fusion, and Return.
    const bool not_candidate = !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
                               AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn);
    if (not_candidate) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (AnfAlgo::GetCNodeName(cnode) != kFusedMulAddOpName) {
      continue;
    }
    MatchBatchMatmulFusedMulAdd(cnode, kernel_graph, candidate_fusion);
  }
}
} // namespace opt
} // namespace mindspore

+ 48
- 0
mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h View File

@@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_FUSEDMULADD_PASS_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_FUSEDMULADD_PASS_H_

#include <unordered_set>
#include <vector>

#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h"
#include "ir/anf.h"
#include "backend/optimizer/common/pass.h"
#include "backend/optimizer/common/fusion_id_allocator.h"
#include "runtime/device/kernel_info.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/session/kernel_graph.h"

namespace mindspore {
namespace opt {
using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;

// UB (unified buffer) fusion pass that pairs a BatchMatMul kernel with the
// FusedMulAdd kernel consuming its output, recording the pair as a single
// fusion candidate for the Ascend buffer-fusion builder.
class BatchMatmulFusedMulAddFusionPass : public FusionBasePass {
public:
// idAllocator hands out the shared fusion id attached to each matched node set.
explicit BatchMatmulFusedMulAddFusionPass(FusionIdAllocatorPtr idAllocator)
: FusionBasePass("BatchMatmulFusedMulAddFusionPass", idAllocator) {}
~BatchMatmulFusedMulAddFusionPass() override = default;
// Walks kernel_graph and appends every matched {FusedMulAdd, BatchMatMul}
// node set to *candidate_fusion.
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

private:
// Checks whether cnode (a FusedMulAdd) is fed by a BatchMatMul and, if so,
// records the pair and tags both nodes with a fusion id.
void MatchBatchMatmulFusedMulAdd(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion);
};
} // namespace opt
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_FUSEDMULADD_PASS_H_

+ 2
- 1
mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.cc View File

@@ -36,7 +36,8 @@ void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNode
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
auto matmul = cnode->input(1); auto matmul = cnode->input(1);
MS_EXCEPTION_IF_NULL(matmul); MS_EXCEPTION_IF_NULL(matmul);
if (matmul->isa<CNode>() && AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul)) {
if (matmul->isa<CNode>() && (AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul) ||
AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimBatchMatMul))) {
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[matmul].size())}; std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[matmul].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), matmul); AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), matmul);
std::unordered_set<AnfNodePtr> record{cnode, matmul}; std::unordered_set<AnfNodePtr> record{cnode, matmul};


Loading…
Cancel
Save