| @@ -65,6 +65,8 @@ | |||||
| #include "pre_activate/ascend/buffer_fusion/buffer_fusion.h" | #include "pre_activate/ascend/buffer_fusion/buffer_fusion.h" | ||||
| #include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h" | #include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h" | ||||
| #include "pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h" | #include "pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h" | ||||
| #include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" | |||||
| #include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" | |||||
| #include "pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h" | #include "pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h" | ||||
| #include "pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h" | #include "pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h" | ||||
| #include "pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" | #include "pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" | ||||
| @@ -336,16 +338,18 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap | |||||
| fusion_id_allocator->Init(); | fusion_id_allocator->Init(); | ||||
| auto optimizer = std::make_shared<GraphOptimizer>(); | auto optimizer = std::make_shared<GraphOptimizer>(); | ||||
| auto ub_fusion_pm = std::make_shared<PassManager>("ub_fusion_pm"); | auto ub_fusion_pm = std::make_shared<PassManager>("ub_fusion_pm"); | ||||
| ub_fusion_pm->AddPass(std::make_shared<ConvDoubleInFusionPass>(fusion_id_allocator.get())); | |||||
| ub_fusion_pm->AddPass(std::make_shared<ConvSingleInFusionPass>(fusion_id_allocator.get())); | |||||
| ub_fusion_pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator.get())); | |||||
| ub_fusion_pm->AddPass(std::make_shared<MatmulEltwiseFusionPass>(fusion_id_allocator.get())); | |||||
| ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseFusionPass>(fusion_id_allocator.get())); | |||||
| ub_fusion_pm->AddPass(std::make_shared<DepthwiseConvEltwiseFusionPass>(fusion_id_allocator.get())); | |||||
| ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseEltwiseFusionPass>(fusion_id_allocator.get())); | |||||
| ub_fusion_pm->AddPass(std::make_shared<ConvBnReduceFusionPass>(fusion_id_allocator.get())); | |||||
| ub_fusion_pm->AddPass(std::make_shared<ReduceEltwiseFusionPass>(fusion_id_allocator.get())); | |||||
| ub_fusion_pm->AddPass(std::make_shared<SegmentEltwiseFusionPass>(fusion_id_allocator.get())); | |||||
| ub_fusion_pm->AddPass(std::make_shared<Conv2DBackpropEltwiseEltwiseFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<Conv2DBackpropEltwiseFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<ConvBnReduceFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<ConvSingleInFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseEltwiseFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<MatmulEltwiseFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<ConvDoubleInFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<ReduceEltwiseFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<SegmentEltwiseFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<DepthwiseConvEltwiseFusionPass>(fusion_id_allocator)); | |||||
| ub_fusion_pm->AddPass(std::make_shared<UbPatternFusion>()); | ub_fusion_pm->AddPass(std::make_shared<UbPatternFusion>()); | ||||
| optimizer->AddPassManager(ub_fusion_pm); | optimizer->AddPassManager(ub_fusion_pm); | ||||
| (void)optimizer->Optimize(kernel_graph); | (void)optimizer->Optimize(kernel_graph); | ||||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||||
| class BnupdateEltwiseEltwiseFusionPass : public FusionBasePass { | class BnupdateEltwiseEltwiseFusionPass : public FusionBasePass { | ||||
| public: | public: | ||||
| explicit BnupdateEltwiseEltwiseFusionPass(FusionIdAllocator *idAllocator) | |||||
| explicit BnupdateEltwiseEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) | |||||
| : FusionBasePass("BnupdateEltwiseEltwiseFusionPass", idAllocator) {} | : FusionBasePass("BnupdateEltwiseEltwiseFusionPass", idAllocator) {} | ||||
| ~BnupdateEltwiseEltwiseFusionPass() override = default; | ~BnupdateEltwiseEltwiseFusionPass() override = default; | ||||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | ||||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||||
| class BnupdateEltwiseFusionPass : public FusionBasePass { | class BnupdateEltwiseFusionPass : public FusionBasePass { | ||||
| public: | public: | ||||
| explicit BnupdateEltwiseFusionPass(FusionIdAllocator *idAllocator) | |||||
| explicit BnupdateEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) | |||||
| : FusionBasePass("BnupdateEltwiseFusionPass", idAllocator) {} | : FusionBasePass("BnupdateEltwiseFusionPass", idAllocator) {} | ||||
| ~BnupdateEltwiseFusionPass() override = default; | ~BnupdateEltwiseFusionPass() override = default; | ||||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | ||||
| @@ -0,0 +1,76 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" | |||||
| #include <vector> | |||||
| #include <unordered_set> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "kernel/kernel_fusion.h" | |||||
| #include "debug/anf_ir_dump.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "operator/ops.h" | |||||
| #include "utils/context/ms_context.h" | |||||
| #include "pre_activate/common/fusion_id_allocator.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltwise( | |||||
| const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | |||||
| auto manager = kernel_graph.manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| std::unordered_set<AnfNodePtr> record{cnode}; | |||||
| auto eltwise_input = cnode->input(1); | |||||
| if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) { | |||||
| (void)record.insert(eltwise_input); | |||||
| } else { | |||||
| return; | |||||
| } | |||||
| auto input_cnode = eltwise_input->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(input_cnode); | |||||
| auto double_in_eltwise_input = input_cnode->input(1); | |||||
| if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) || | |||||
| fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) { | |||||
| return; | |||||
| } | |||||
| if (AnfAlgo::CheckPrimitiveType(double_in_eltwise_input, prim::kPrimConv2DBackpropInput)) { | |||||
| (void)record.insert(double_in_eltwise_input); | |||||
| candidate_fusion->push_back(record); | |||||
| SetRecordFusionId(record); | |||||
| } | |||||
| } | |||||
| void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, | |||||
| FusedNodeRecord *candidate_fusion) { | |||||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | |||||
| std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return()); | |||||
| for (auto &node : node_list) { | |||||
| if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || | |||||
| AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { | |||||
| continue; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && | |||||
| AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && | |||||
| (cnode->inputs().size() == ELTWISE_INPUT_SIZE || cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE)) { | |||||
| MatchConv2DBackpropInputEltwiseEltwise(cnode, kernel_graph, candidate_fusion); | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,47 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ | |||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ | |||||
| #include <unordered_set> | |||||
| #include <vector> | |||||
| #include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" | |||||
| #include "ir/anf.h" | |||||
| #include "pre_activate/common/pass.h" | |||||
| #include "pre_activate/common/fusion_id_allocator.h" | |||||
| #include "device/kernel_info.h" | |||||
| #include "kernel/kernel.h" | |||||
| #include "session/kernel_graph.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||||
| class Conv2DBackpropEltwiseEltwiseFusionPass : public FusionBasePass { | |||||
| public: | |||||
| explicit Conv2DBackpropEltwiseEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) | |||||
| : FusionBasePass("Conv2DBackpropEltwiseEltwiseFusionPass", idAllocator) {} | |||||
| ~Conv2DBackpropEltwiseEltwiseFusionPass() override = default; | |||||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | |||||
| private: | |||||
| void MatchConv2DBackpropInputEltwiseEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, | |||||
| FusedNodeRecord *candidate_fusion); | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ | |||||
| @@ -0,0 +1,69 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" | |||||
| #include <vector> | |||||
| #include <unordered_set> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "kernel/kernel_fusion.h" | |||||
| #include "debug/anf_ir_dump.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "operator/ops.h" | |||||
| #include "utils/context/ms_context.h" | |||||
| #include "pre_activate/common/fusion_id_allocator.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNodePtr &cnode, | |||||
| const session::KernelGraph &kernel_graph, | |||||
| FusedNodeRecord *candidate_fusion) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | |||||
| auto manager = kernel_graph.manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| std::unordered_set<AnfNodePtr> record{cnode}; | |||||
| auto eltwise_input = cnode->input(1); | |||||
| if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || | |||||
| fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { | |||||
| return; | |||||
| } | |||||
| if (AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimConv2DBackpropInput)) { | |||||
| (void)record.insert(eltwise_input); | |||||
| candidate_fusion->push_back(record); | |||||
| SetRecordFusionId(record); | |||||
| } | |||||
| } | |||||
| void Conv2DBackpropEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, | |||||
| FusedNodeRecord *candidate_fusion) { | |||||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | |||||
| std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return()); | |||||
| for (auto &node : node_list) { | |||||
| if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || | |||||
| AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { | |||||
| continue; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && | |||||
| AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && | |||||
| (cnode->inputs().size() == ELTWISE_INPUT_SIZE || cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE)) { | |||||
| MatchConv2DBackpropInputEltwise(cnode, kernel_graph, candidate_fusion); | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,47 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ | |||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ | |||||
| #include <unordered_set> | |||||
| #include <vector> | |||||
| #include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" | |||||
| #include "ir/anf.h" | |||||
| #include "pre_activate/common/pass.h" | |||||
| #include "pre_activate/common/fusion_id_allocator.h" | |||||
| #include "device/kernel_info.h" | |||||
| #include "kernel/kernel.h" | |||||
| #include "session/kernel_graph.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||||
| class Conv2DBackpropEltwiseFusionPass : public FusionBasePass { | |||||
| public: | |||||
| explicit Conv2DBackpropEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) | |||||
| : FusionBasePass("Conv2DBackpropEltwiseFusionPass", idAllocator) {} | |||||
| ~Conv2DBackpropEltwiseFusionPass() override = default; | |||||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | |||||
| private: | |||||
| void MatchConv2DBackpropInputEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, | |||||
| FusedNodeRecord *candidate_fusion); | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ | |||||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||||
| class ConvBnReduceFusionPass : public FusionBasePass { | class ConvBnReduceFusionPass : public FusionBasePass { | ||||
| public: | public: | ||||
| explicit ConvBnReduceFusionPass(FusionIdAllocator *idAllocator) | |||||
| explicit ConvBnReduceFusionPass(FusionIdAllocatorPtr idAllocator) | |||||
| : FusionBasePass("ConvBnReduceFusionPass", idAllocator) {} | : FusionBasePass("ConvBnReduceFusionPass", idAllocator) {} | ||||
| ~ConvBnReduceFusionPass() override = default; | ~ConvBnReduceFusionPass() override = default; | ||||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | ||||
| @@ -27,19 +27,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| bool ConvDoubleInFusionPass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { | |||||
| return false; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto user_nodes = manager->node_users()[node]; | |||||
| return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL && | |||||
| AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && user_nodes.size() == ELTWISE_USE && | |||||
| cnode->inputs().size() == ELTWISE_INPUT_SIZE; | |||||
| } | |||||
| void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, | void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, | ||||
| FusedNodeRecord *candidate_fusion) { | FusedNodeRecord *candidate_fusion) { | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||||
| class ConvDoubleInFusionPass : public FusionBasePass { | class ConvDoubleInFusionPass : public FusionBasePass { | ||||
| public: | public: | ||||
| explicit ConvDoubleInFusionPass(FusionIdAllocator *idAllocator) | |||||
| explicit ConvDoubleInFusionPass(FusionIdAllocatorPtr idAllocator) | |||||
| : FusionBasePass("ConvDoubleInFusionPass", idAllocator) {} | : FusionBasePass("ConvDoubleInFusionPass", idAllocator) {} | ||||
| ~ConvDoubleInFusionPass() override = default; | ~ConvDoubleInFusionPass() override = default; | ||||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | ||||
| @@ -41,7 +41,6 @@ class ConvDoubleInFusionPass : public FusionBasePass { | |||||
| private: | private: | ||||
| void MatchConvDoubleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, | void MatchConvDoubleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, | ||||
| FusedNodeRecord *candidate_fusion); | FusedNodeRecord *candidate_fusion); | ||||
| bool CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); | |||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||||
| class ConvSingleInFusionPass : public FusionBasePass { | class ConvSingleInFusionPass : public FusionBasePass { | ||||
| public: | public: | ||||
| explicit ConvSingleInFusionPass(FusionIdAllocator *idAllocator) | |||||
| explicit ConvSingleInFusionPass(FusionIdAllocatorPtr idAllocator) | |||||
| : FusionBasePass("ConvSingleInFusionPass", idAllocator) {} | : FusionBasePass("ConvSingleInFusionPass", idAllocator) {} | ||||
| ~ConvSingleInFusionPass() override = default; | ~ConvSingleInFusionPass() override = default; | ||||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | ||||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||||
| class DepthwiseConvEltwiseFusionPass : public FusionBasePass { | class DepthwiseConvEltwiseFusionPass : public FusionBasePass { | ||||
| public: | public: | ||||
| explicit DepthwiseConvEltwiseFusionPass(FusionIdAllocator *idAllocator) | |||||
| explicit DepthwiseConvEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) | |||||
| : FusionBasePass("DepthwiseConvEltwiseFusionPass", idAllocator) {} | : FusionBasePass("DepthwiseConvEltwiseFusionPass", idAllocator) {} | ||||
| ~DepthwiseConvEltwiseFusionPass() override = default; | ~DepthwiseConvEltwiseFusionPass() override = default; | ||||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | ||||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||||
| class EltwiseFusionPass : public FusionBasePass { | class EltwiseFusionPass : public FusionBasePass { | ||||
| public: | public: | ||||
| explicit EltwiseFusionPass(FusionIdAllocator *idAllocator) : FusionBasePass("EltwiseFusionPass", idAllocator) {} | |||||
| explicit EltwiseFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("EltwiseFusionPass", idAllocator) {} | |||||
| ~EltwiseFusionPass() override = default; | ~EltwiseFusionPass() override = default; | ||||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | ||||
| @@ -36,6 +36,19 @@ bool FusionBasePass::CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePt | |||||
| cnode->inputs().size() == ELTWISE_INPUT_SIZE; | cnode->inputs().size() == ELTWISE_INPUT_SIZE; | ||||
| } | } | ||||
| bool FusionBasePass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { | |||||
| return false; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto user_nodes = manager->node_users()[node]; | |||||
| return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL && | |||||
| AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && user_nodes.size() == ELTWISE_USE && | |||||
| cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE; | |||||
| } | |||||
| void FusionBasePass::SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record) { | void FusionBasePass::SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record) { | ||||
| auto id = fusion_id_allocator->AllocateFusionId(); | auto id = fusion_id_allocator->AllocateFusionId(); | ||||
| for (auto node : record) { | for (auto node : record) { | ||||
| @@ -32,13 +32,14 @@ namespace opt { | |||||
| const int8_t MAX_ELTWISE_NUM = 3; | const int8_t MAX_ELTWISE_NUM = 3; | ||||
| const int8_t MIN_ELTWISE_SIZE = 2; | const int8_t MIN_ELTWISE_SIZE = 2; | ||||
| const int8_t ELTWISE_INPUT_SIZE = 2; | const int8_t ELTWISE_INPUT_SIZE = 2; | ||||
| const int8_t ELTWISE_DOUBLE_IN_INPUT_SIZE = 3; | |||||
| const int8_t ELTWISE_USE = 1; | const int8_t ELTWISE_USE = 1; | ||||
| const int8_t MAX_ELTWISE_SIZE = 6; | const int8_t MAX_ELTWISE_SIZE = 6; | ||||
| using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | ||||
| class FusionBasePass : public Pass { | class FusionBasePass : public Pass { | ||||
| public: | public: | ||||
| FusionBasePass(const std::string &name, FusionIdAllocator *idAllocator) | |||||
| FusionBasePass(const std::string &name, FusionIdAllocatorPtr idAllocator) | |||||
| : Pass(name), fusion_id_allocator(idAllocator) {} | : Pass(name), fusion_id_allocator(idAllocator) {} | ||||
| ~FusionBasePass() override = default; | ~FusionBasePass() override = default; | ||||
| bool Run(const FuncGraphPtr &graph) override; | bool Run(const FuncGraphPtr &graph) override; | ||||
| @@ -49,7 +50,8 @@ class FusionBasePass : public Pass { | |||||
| FusedNodeRecord *candidate_fusion) = 0; | FusedNodeRecord *candidate_fusion) = 0; | ||||
| void SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record); | void SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record); | ||||
| bool CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); | bool CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); | ||||
| FusionIdAllocator *fusion_id_allocator; | |||||
| bool CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); | |||||
| FusionIdAllocatorPtr fusion_id_allocator; | |||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||||
| class MatmulEltwiseFusionPass : public FusionBasePass { | class MatmulEltwiseFusionPass : public FusionBasePass { | ||||
| public: | public: | ||||
| explicit MatmulEltwiseFusionPass(FusionIdAllocator *idAllocator) | |||||
| explicit MatmulEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) | |||||
| : FusionBasePass("MatmulEltwiseFusionPass", idAllocator) {} | : FusionBasePass("MatmulEltwiseFusionPass", idAllocator) {} | ||||
| ~MatmulEltwiseFusionPass() override = default; | ~MatmulEltwiseFusionPass() override = default; | ||||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | ||||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||||
| class ReduceEltwiseFusionPass : public FusionBasePass { | class ReduceEltwiseFusionPass : public FusionBasePass { | ||||
| public: | public: | ||||
| explicit ReduceEltwiseFusionPass(FusionIdAllocator *idAllocator) | |||||
| explicit ReduceEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) | |||||
| : FusionBasePass("ReduceEltwiseFusionPass", idAllocator) {} | : FusionBasePass("ReduceEltwiseFusionPass", idAllocator) {} | ||||
| ~ReduceEltwiseFusionPass() override = default; | ~ReduceEltwiseFusionPass() override = default; | ||||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | ||||
| @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>; | |||||
| class SegmentEltwiseFusionPass : public FusionBasePass { | class SegmentEltwiseFusionPass : public FusionBasePass { | ||||
| public: | public: | ||||
| explicit SegmentEltwiseFusionPass(FusionIdAllocator *idAllocator) | |||||
| explicit SegmentEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) | |||||
| : FusionBasePass("SegmentEltwiseFusionPass", idAllocator) {} | : FusionBasePass("SegmentEltwiseFusionPass", idAllocator) {} | ||||
| ~SegmentEltwiseFusionPass() override = default; | ~SegmentEltwiseFusionPass() override = default; | ||||
| void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; | ||||
| @@ -16,6 +16,7 @@ | |||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ | #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ | ||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ | #define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ | ||||
| #include <memory> | |||||
| #include "ir/base.h" | #include "ir/base.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -36,6 +37,7 @@ class FusionIdAllocator { | |||||
| private: | private: | ||||
| int32_t fusion_id; | int32_t fusion_id; | ||||
| }; | }; | ||||
| using FusionIdAllocatorPtr = std::shared_ptr<FusionIdAllocator>; | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||