|
|
|
@@ -1084,12 +1084,14 @@ std::string TbeKernelBuild::GetNodeFusionType(const mindspore::CNodePtr &cnode) |
|
|
|
{kDepthwiseConv2dNativeOpName, "DepthwiseConvolution"}, |
|
|
|
{kAddNOpName, "ElemWise"}, |
|
|
|
{kReluGradV2OpName, "ElemWise"}, |
|
|
|
{kRealDivOpName, "ElemWise"}}; |
|
|
|
{kRealDivOpName, "ElemWise"}, |
|
|
|
{kBiasAddOpName, "BiasAdd"}}; |
|
|
|
auto find = fusion_type_map.find(node_type); |
|
|
|
if (find == fusion_type_map.end()) { |
|
|
|
MS_LOG(INFO) << "Fusion warning: get node fusion type failed, origin node type: " << node_type |
|
|
|
<< " return null string."; |
|
|
|
return ""; |
|
|
|
MS_LOG(INFO) << "Fusion warning: get node fusion type failed from lists, origin node type: " << node_type; |
|
|
|
auto op_info = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(node_type, cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(op_info); |
|
|
|
return op_info->fusion_type(); |
|
|
|
} else { |
|
|
|
return find->second; |
|
|
|
} |
|
|
|
|