| @@ -65,6 +65,8 @@ | |||
| #include "pre_activate/ascend/buffer_fusion/buffer_fusion.h" | |||
| #include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h" | |||
| #include "pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h" | |||
| #include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" | |||
| #include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" | |||
| #include "pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h" | |||
| #include "pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h" | |||
| #include "pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" | |||
| @@ -336,16 +338,18 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap | |||
| fusion_id_allocator->Init(); | |||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||
| auto ub_fusion_pm = std::make_shared<PassManager>("ub_fusion_pm"); | |||
| ub_fusion_pm->AddPass(std::make_shared<ConvDoubleInFusionPass>(fusion_id_allocator.get())); | |||
| ub_fusion_pm->AddPass(std::make_shared<ConvSingleInFusionPass>(fusion_id_allocator.get())); | |||
| ub_fusion_pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator.get())); | |||
| ub_fusion_pm->AddPass(std::make_shared<MatmulEltwiseFusionPass>(fusion_id_allocator.get())); | |||
| ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseFusionPass>(fusion_id_allocator.get())); | |||
| ub_fusion_pm->AddPass(std::make_shared<DepthwiseConvEltwiseFusionPass>(fusion_id_allocator.get())); | |||
| ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseEltwiseFusionPass>(fusion_id_allocator.get())); | |||
| ub_fusion_pm->AddPass(std::make_shared<ConvBnReduceFusionPass>(fusion_id_allocator.get())); | |||
| ub_fusion_pm->AddPass(std::make_shared<ReduceEltwiseFusionPass>(fusion_id_allocator.get())); | |||
| ub_fusion_pm->AddPass(std::make_shared<SegmentEltwiseFusionPass>(fusion_id_allocator.get())); | |||
| ub_fusion_pm->AddPass(std::make_shared<Conv2DBackpropEltwiseEltwiseFusionPass>(fusion_id_allocator)); | |||
| ub_fusion_pm->AddPass(std::make_shared<Conv2DBackpropEltwiseFusionPass>(fusion_id_allocator)); | |||
| ub_fusion_pm->AddPass(std::make_shared<ConvBnReduceFusionPass>(fusion_id_allocator)); | |||
| ub_fusion_pm->AddPass(std::make_shared<ConvSingleInFusionPass>(fusion_id_allocator)); | |||
| ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseFusionPass>(fusion_id_allocator)); | |||
| ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseEltwiseFusionPass>(fusion_id_allocator)); | |||
| ub_fusion_pm->AddPass(std::make_shared<MatmulEltwiseFusionPass>(fusion_id_allocator)); | |||
| ub_fusion_pm->AddPass(std::make_shared<ConvDoubleInFusionPass>(fusion_id_allocator)); | |||
| ub_fusion_pm->AddPass(std::make_shared<ReduceEltwiseFusionPass>(fusion_id_allocator)); | |||
| ub_fusion_pm->AddPass(std::make_shared<SegmentEltwiseFusionPass>(fusion_id_allocator)); | |||
| ub_fusion_pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator)); | |||
| ub_fusion_pm->AddPass(std::make_shared<DepthwiseConvEltwiseFusionPass>(fusion_id_allocator)); | |||
| ub_fusion_pm->AddPass(std::make_shared<UbPatternFusion>()); | |||
| optimizer->AddPassManager(ub_fusion_pm); | |||
| (void)optimizer->Optimize(kernel_graph); | |||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||
| class BnupdateEltwiseEltwiseFusionPass : public FusionBasePass { | |||
| public: | |||
| explicit BnupdateEltwiseEltwiseFusionPass(FusionIdAllocator *idAllocator) | |||
| explicit BnupdateEltwiseEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) | |||
| : FusionBasePass("BnupdateEltwiseEltwiseFusionPass", idAllocator) {} | |||
| ~BnupdateEltwiseEltwiseFusionPass() override = default; | |||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | |||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||
| class BnupdateEltwiseFusionPass : public FusionBasePass { | |||
| public: | |||
| explicit BnupdateEltwiseFusionPass(FusionIdAllocator *idAllocator) | |||
| explicit BnupdateEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) | |||
| : FusionBasePass("BnupdateEltwiseFusionPass", idAllocator) {} | |||
| ~BnupdateEltwiseFusionPass() override = default; | |||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | |||
| @@ -0,0 +1,76 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" | |||
| #include <vector> | |||
| #include <unordered_set> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "kernel/kernel_fusion.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "operator/ops.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "pre_activate/common/fusion_id_allocator.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltwise( | |||
| const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | |||
| auto manager = kernel_graph.manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| std::unordered_set<AnfNodePtr> record{cnode}; | |||
| auto eltwise_input = cnode->input(1); | |||
| if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) { | |||
| (void)record.insert(eltwise_input); | |||
| } else { | |||
| return; | |||
| } | |||
| auto input_cnode = eltwise_input->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(input_cnode); | |||
| auto double_in_eltwise_input = input_cnode->input(1); | |||
| if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) || | |||
| fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) { | |||
| return; | |||
| } | |||
| if (AnfAlgo::CheckPrimitiveType(double_in_eltwise_input, prim::kPrimConv2DBackpropInput)) { | |||
| (void)record.insert(double_in_eltwise_input); | |||
| candidate_fusion->push_back(record); | |||
| SetRecordFusionId(record); | |||
| } | |||
| } | |||
| void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, | |||
| FusedNodeRecord *candidate_fusion) { | |||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | |||
| std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return()); | |||
| for (auto &node : node_list) { | |||
| if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || | |||
| AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && | |||
| AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && | |||
| (cnode->inputs().size() == ELTWISE_INPUT_SIZE || cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE)) { | |||
| MatchConv2DBackpropInputEltwiseEltwise(cnode, kernel_graph, candidate_fusion); | |||
| } | |||
| } | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ | |||
| #include <unordered_set> | |||
| #include <vector> | |||
| #include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" | |||
| #include "ir/anf.h" | |||
| #include "pre_activate/common/pass.h" | |||
| #include "pre_activate/common/fusion_id_allocator.h" | |||
| #include "device/kernel_info.h" | |||
| #include "kernel/kernel.h" | |||
| #include "session/kernel_graph.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||
| class Conv2DBackpropEltwiseEltwiseFusionPass : public FusionBasePass { | |||
| public: | |||
| explicit Conv2DBackpropEltwiseEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) | |||
| : FusionBasePass("Conv2DBackpropEltwiseEltwiseFusionPass", idAllocator) {} | |||
| ~Conv2DBackpropEltwiseEltwiseFusionPass() override = default; | |||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | |||
| private: | |||
| void MatchConv2DBackpropInputEltwiseEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, | |||
| FusedNodeRecord *candidate_fusion); | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ | |||
| @@ -0,0 +1,69 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" | |||
| #include <vector> | |||
| #include <unordered_set> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "kernel/kernel_fusion.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "operator/ops.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "pre_activate/common/fusion_id_allocator.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNodePtr &cnode, | |||
| const session::KernelGraph &kernel_graph, | |||
| FusedNodeRecord *candidate_fusion) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | |||
| auto manager = kernel_graph.manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| std::unordered_set<AnfNodePtr> record{cnode}; | |||
| auto eltwise_input = cnode->input(1); | |||
| if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || | |||
| fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { | |||
| return; | |||
| } | |||
| if (AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimConv2DBackpropInput)) { | |||
| (void)record.insert(eltwise_input); | |||
| candidate_fusion->push_back(record); | |||
| SetRecordFusionId(record); | |||
| } | |||
| } | |||
| void Conv2DBackpropEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, | |||
| FusedNodeRecord *candidate_fusion) { | |||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | |||
| std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return()); | |||
| for (auto &node : node_list) { | |||
| if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || | |||
| AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && | |||
| AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && | |||
| (cnode->inputs().size() == ELTWISE_INPUT_SIZE || cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE)) { | |||
| MatchConv2DBackpropInputEltwise(cnode, kernel_graph, candidate_fusion); | |||
| } | |||
| } | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ | |||
| #include <unordered_set> | |||
| #include <vector> | |||
| #include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" | |||
| #include "ir/anf.h" | |||
| #include "pre_activate/common/pass.h" | |||
| #include "pre_activate/common/fusion_id_allocator.h" | |||
| #include "device/kernel_info.h" | |||
| #include "kernel/kernel.h" | |||
| #include "session/kernel_graph.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||
| class Conv2DBackpropEltwiseFusionPass : public FusionBasePass { | |||
| public: | |||
| explicit Conv2DBackpropEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) | |||
| : FusionBasePass("Conv2DBackpropEltwiseFusionPass", idAllocator) {} | |||
| ~Conv2DBackpropEltwiseFusionPass() override = default; | |||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | |||
| private: | |||
| void MatchConv2DBackpropInputEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, | |||
| FusedNodeRecord *candidate_fusion); | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ | |||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||
| class ConvBnReduceFusionPass : public FusionBasePass { | |||
| public: | |||
| explicit ConvBnReduceFusionPass(FusionIdAllocator *idAllocator) | |||
| explicit ConvBnReduceFusionPass(FusionIdAllocatorPtr idAllocator) | |||
| : FusionBasePass("ConvBnReduceFusionPass", idAllocator) {} | |||
| ~ConvBnReduceFusionPass() override = default; | |||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | |||
| @@ -27,19 +27,6 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| bool ConvDoubleInFusionPass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { | |||
| return false; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto user_nodes = manager->node_users()[node]; | |||
| return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL && | |||
| AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && user_nodes.size() == ELTWISE_USE && | |||
| cnode->inputs().size() == ELTWISE_INPUT_SIZE; | |||
| } | |||
| void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, | |||
| FusedNodeRecord *candidate_fusion) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||
| class ConvDoubleInFusionPass : public FusionBasePass { | |||
| public: | |||
| explicit ConvDoubleInFusionPass(FusionIdAllocator *idAllocator) | |||
| explicit ConvDoubleInFusionPass(FusionIdAllocatorPtr idAllocator) | |||
| : FusionBasePass("ConvDoubleInFusionPass", idAllocator) {} | |||
| ~ConvDoubleInFusionPass() override = default; | |||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | |||
| @@ -41,7 +41,6 @@ class ConvDoubleInFusionPass : public FusionBasePass { | |||
| private: | |||
| void MatchConvDoubleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, | |||
| FusedNodeRecord *candidate_fusion); | |||
| bool CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||
| class ConvSingleInFusionPass : public FusionBasePass { | |||
| public: | |||
| explicit ConvSingleInFusionPass(FusionIdAllocator *idAllocator) | |||
| explicit ConvSingleInFusionPass(FusionIdAllocatorPtr idAllocator) | |||
| : FusionBasePass("ConvSingleInFusionPass", idAllocator) {} | |||
| ~ConvSingleInFusionPass() override = default; | |||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | |||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||
| class DepthwiseConvEltwiseFusionPass : public FusionBasePass { | |||
| public: | |||
| explicit DepthwiseConvEltwiseFusionPass(FusionIdAllocator *idAllocator) | |||
| explicit DepthwiseConvEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) | |||
| : FusionBasePass("DepthwiseConvEltwiseFusionPass", idAllocator) {} | |||
| ~DepthwiseConvEltwiseFusionPass() override = default; | |||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | |||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||
| class EltwiseFusionPass : public FusionBasePass { | |||
| public: | |||
| explicit EltwiseFusionPass(FusionIdAllocator *idAllocator) : FusionBasePass("EltwiseFusionPass", idAllocator) {} | |||
| explicit EltwiseFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("EltwiseFusionPass", idAllocator) {} | |||
| ~EltwiseFusionPass() override = default; | |||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | |||
| @@ -36,6 +36,19 @@ bool FusionBasePass::CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePt | |||
| cnode->inputs().size() == ELTWISE_INPUT_SIZE; | |||
| } | |||
| bool FusionBasePass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { | |||
| return false; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto user_nodes = manager->node_users()[node]; | |||
| return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL && | |||
| AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && user_nodes.size() == ELTWISE_USE && | |||
| cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE; | |||
| } | |||
| void FusionBasePass::SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record) { | |||
| auto id = fusion_id_allocator->AllocateFusionId(); | |||
| for (auto node : record) { | |||
| @@ -32,13 +32,14 @@ namespace opt { | |||
| const int8_t MAX_ELTWISE_NUM = 3; | |||
| const int8_t MIN_ELTWISE_SIZE = 2; | |||
| const int8_t ELTWISE_INPUT_SIZE = 2; | |||
| const int8_t ELTWISE_DOUBLE_IN_INPUT_SIZE = 3; | |||
| const int8_t ELTWISE_USE = 1; | |||
| const int8_t MAX_ELTWISE_SIZE = 6; | |||
| using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||
| class FusionBasePass : public Pass { | |||
| public: | |||
| FusionBasePass(const std::string &name, FusionIdAllocator *idAllocator) | |||
| FusionBasePass(const std::string &name, FusionIdAllocatorPtr idAllocator) | |||
| : Pass(name), fusion_id_allocator(idAllocator) {} | |||
| ~FusionBasePass() override = default; | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| @@ -49,7 +50,8 @@ class FusionBasePass : public Pass { | |||
| FusedNodeRecord *candidate_fusion) = 0; | |||
| void SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record); | |||
| bool CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); | |||
| FusionIdAllocator *fusion_id_allocator; | |||
| bool CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); | |||
| FusionIdAllocatorPtr fusion_id_allocator; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||
| class MatmulEltwiseFusionPass : public FusionBasePass { | |||
| public: | |||
| explicit MatmulEltwiseFusionPass(FusionIdAllocator *idAllocator) | |||
| explicit MatmulEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) | |||
| : FusionBasePass("MatmulEltwiseFusionPass", idAllocator) {} | |||
| ~MatmulEltwiseFusionPass() override = default; | |||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | |||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||
| class ReduceEltwiseFusionPass : public FusionBasePass { | |||
| public: | |||
| explicit ReduceEltwiseFusionPass(FusionIdAllocator *idAllocator) | |||
| explicit ReduceEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) | |||
| : FusionBasePass("ReduceEltwiseFusionPass", idAllocator) {} | |||
| ~ReduceEltwiseFusionPass() override = default; | |||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | |||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||
| class SegmentEltwiseFusionPass : public FusionBasePass { | |||
| public: | |||
| explicit SegmentEltwiseFusionPass(FusionIdAllocator *idAllocator) | |||
| explicit SegmentEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) | |||
| : FusionBasePass("SegmentEltwiseFusionPass", idAllocator) {} | |||
| ~SegmentEltwiseFusionPass() override = default; | |||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | |||
| @@ -16,6 +16,7 @@ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ | |||
| #include <memory> | |||
| #include "ir/base.h" | |||
| namespace mindspore { | |||
| @@ -36,6 +37,7 @@ class FusionIdAllocator { | |||
| private: | |||
| int32_t fusion_id; | |||
| }; | |||
| using FusionIdAllocatorPtr = std::shared_ptr<FusionIdAllocator>; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||