From 7017ac3f7e1bca9686d4b0b5f44c4b39376ccf4b Mon Sep 17 00:00:00 2001
From: wangcong
Date: Wed, 13 May 2020 18:03:50 +0800
Subject: [PATCH] depthwisecon2d-eltwise pass

---
 .../ccsrc/kernel/tbe/tbe_kernel_build.cc      | 19 +++++++++++--------
 mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h |  2 +-
 mindspore/ccsrc/operator/ops.cc               |  1 +
 .../ascend/buffer_fusion/buffer_fusion.cc     |  5 ++---
 4 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc
index 1f8bb35974..825126672b 100644
--- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc
+++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc
@@ -38,11 +38,6 @@ constexpr auto kFusionKernelNamePrfix = "te_fusion";
 constexpr auto kOptional = "optional_";
 constexpr auto kOpFormat_FRACTAL_Z = "FRACTAL_Z";
 
-std::map<std::string, std::string> TbeKernelBuild::buffer_fussion_op_map_ = {
-  {"DepthwiseConv2dNative", "DepthwiseConv2D"},
-  {"TensorAdd", "Add"}
-};
-
 std::string NormalizeFullScopeName(const string &full_scope_name) {
   // exp:Default/ReLU-op0 -->Default_ReLU_op0
   string normal_ret = full_scope_name;
@@ -726,6 +721,16 @@ size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool i
   return (op_info->inputs_ptr().size() + 1 - cnode->inputs().size());
 }
 
+std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) {
+  static std::map<std::string, std::string> buffer_fussion_op_map = {
+    {"DepthwiseConv2dNative", "DepthwiseConv2D"}, {"TensorAdd", "Add"}};
+  string result = origin_type;
+  if (buffer_fussion_op_map.find(origin_type) != buffer_fussion_op_map.end()) {
+    result = buffer_fussion_op_map[origin_type];
+  }
+  return result;
+}
+
 bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
                                                std::vector<std::vector<mindspore::AnfNodePtr>>::iterator *layer_iter,
                                                std::vector<nlohmann::json> *input_desc_list, size_t *index) {
@@ -831,9 +836,7 @@ bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_n
   // gen others
   auto type = AnfAlgo::GetCNodeName(cnode);
   // replace special op type for buffer fusion op
-  if (buffer_fussion_op_map_.find(type) != buffer_fussion_op_map_.end()) {
-    type = buffer_fussion_op_map_[type];
-  }
+  type = GetRealOpType(type);
   (*compute_op_str)["type"] = type;
   tbe::TbeAdapter::NormalizeFuncName(&type);
   (*compute_op_str)["func_name"] = type;
diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h
index a2b100dc1b..8bd21e7c70 100644
--- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h
+++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h
@@ -76,7 +76,7 @@ class TbeKernelBuild {
                             std::map<size_t, size_t> *spec_data_input);
   static bool IsDynamicInput(const CNodePtr &cnode);
   static size_t GetOptionalInput(const CNodePtr &cnode, bool is_dynamic_input);
-  static std::map<std::string, std::string> buffer_fussion_op_map_;
+  static std::string GetRealOpType(const std::string &origin_type);
 };
 
 class TbeKernelJsonCreator {
diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc
index 59dd011c9d..c85e6a72ce 100755
--- a/mindspore/ccsrc/operator/ops.cc
+++ b/mindspore/ccsrc/operator/ops.cc
@@ -200,6 +200,7 @@ const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
 const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
 const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
 const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
+const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
 const PrimitivePtr kPrimZerosLikeTensor = std::make_shared<Primitive>("zeros_like_tensor");
 const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
 
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc
index cac3b425d3..e4a3ae9f46 100644
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc
@@ -555,7 +555,7 @@ void BufferFusion::MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::
   // DepthwiseConvolution--->Elemwise
   auto depthwise_conv = cnode->input(1);
   MS_EXCEPTION_IF_NULL(depthwise_conv);
-  if (cnode->isa<CNode>() && AnfAlgo::GetCNodeName(depthwise_conv) == prim::kPrimDepthwiseConv2dNative->name()) {
+  if (cnode->isa<CNode>() && IsPrimitiveCNode(depthwise_conv, prim::kPrimDepthwiseConv2dNative)) {
     std::vector<int> output_used_num{SizeToInt(manager->node_users()[depthwise_conv].size())};
     AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), depthwise_conv);
     std::unordered_set<AnfNodePtr> record{cnode, depthwise_conv};
@@ -566,8 +566,7 @@ void BufferFusion::MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::
   // Elemwise-->DepthwiseConvolution
   auto relu = cnode->input(1);
   MS_EXCEPTION_IF_NULL(relu);
-  if (cnode->isa<CNode>() &&
-      (AnfAlgo::GetCNodeName(relu) == prim::kPrimRelu->name() || AnfAlgo::GetCNodeName(relu) == kReluV2OpName)) {
+  if (cnode->isa<CNode>() && (IsPrimitiveCNode(relu, prim::kPrimRelu) || IsPrimitiveCNode(relu, prim::kPrimReluV2))) {
     std::vector<int> output_used_num{SizeToInt(manager->node_users()[relu].size())};
     AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu);
     std::unordered_set<AnfNodePtr> record{cnode, relu};