Browse Source

!8294 fix ub fusion key error

From: @jjfeing
Reviewed-by: @kisnwang,@limingqi107
Signed-off-by: @limingqi107
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
49e3aa35a2
2 changed files with 7 additions and 3 deletions
  1. +5
    -3
      mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc
  2. +2
    -0
      mindspore/ccsrc/utils/utils.h

+ 5
- 3
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc View File

@@ -483,7 +483,7 @@ void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspo
MS_EXCEPTION_IF_NULL(value_type);
auto value_type_str = value_type->ToString();
if (value_type_str == kVTypeInt64) {
int64_t data = GetValue<int64_t>(value);
auto data = GetValue<int64_t>(value);
attr_value.push_back(data);
} else {
auto vec =
@@ -737,10 +737,10 @@ void TbeKernelBuild::GenFusionComputeCommonJson(const mindspore::CNodePtr &cnode
(*compute_op_str)[kJtype] = type;
auto kernel_name = op_info_ptr->kernel_name();
(*compute_op_str)[kJFuncName] = kernel_name;
(*compute_op_str)[kJModuleName] = std::string("impl.") + type;
(*compute_op_str)[kJModuleName] = std::string("impl.") + kernel_name;
(*compute_op_str)[kJName] = cnode->fullname_with_scope();
(*compute_op_str)[kJPattern] = GetNodeFusionType(cnode);
(*compute_op_str)[kJPyModulePath] = "/usr/local/Ascend/opp/op_impl/build_in/ai_core/tbe";
(*compute_op_str)[kJPyModulePath] = "/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe";
(void)(*fusion_kernel_name).append("_");
(void)(*fusion_kernel_name).append(kernel_name);
}
@@ -1034,6 +1034,8 @@ std::string TbeKernelBuild::GetNodeFusionType(const mindspore::CNodePtr &cnode)
{kReluV2OpName, "ElemWise"},
{kTensorAddOpName, "ElemWise"},
{kConv2DBackpropInputOpName, "Conv2d_backprop_input"},
{kConv2DBackpropFilterOpName, "Conv2d_backprop_filter"},
{kDepthwiseConv2dNativeName, "DepthwiseConvolution"},
{kAddNOpName, "ElemWise"},
{kReluGradV2OpName, "ElemWise"},
{kRealDivOpName, "ElemWise"}};


+ 2
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -157,6 +157,8 @@ constexpr auto kSpaceToBatchOpName = "SpaceToBatch";
constexpr auto kBatchToSpaceOpName = "BatchToSpace";
constexpr auto kPadOpName = "Pad";
constexpr auto kConv2DBackpropInputOpName = "Conv2DBackpropInput";
constexpr auto kConv2DBackpropFilterOpName = "Conv2DBackpropFilter";
constexpr auto kDepthwiseConv2dNativeName = "DepthwiseConv2dNative";
constexpr auto kFusionOpConv2DBackpropInputReluGradV2Name = "FusionOp_Conv2DBackpropInput_ReluGradV2";
constexpr auto kFusionOpConv2DBackpropInputAddNReluGradV2Name = "FusionOp_Conv2DBackpropInput_AddN_ReluGradV2";
constexpr auto kLabelSetOpName = "LabelSet";


Loading…
Cancel
Save