diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc index 6c6ab73335..8fe80ce78f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc @@ -61,6 +61,7 @@ constexpr auto kParamDynamic = "dynamic"; constexpr auto kParamRequred = "required"; constexpr auto kJDataType = "data_type"; constexpr auto kJOutputIndex = "output_index"; +constexpr auto kJOutputDataDesc = "output_data_desc"; constexpr auto kJOutputDesc = "output_desc"; constexpr auto kJInputDesc = "input_desc"; constexpr auto kJRange = "range"; @@ -84,13 +85,13 @@ constexpr auto kJAddrType = "addr_type"; constexpr auto kJSliceOffset = "slice_offset"; constexpr auto kJSplitIndex = "split_index"; constexpr auto kJTotalShape = "total_shape"; +constexpr auto kJDynamicCompileStatic = "dynamic_compile_static"; +constexpr auto kJInt64Mode = "int64mode"; constexpr auto kJValidShape = "valid_shape"; constexpr auto kJModuleName = "module_name"; constexpr auto kJPattern = "pattern"; constexpr auto kJPyModulePath = "py_module_path"; -constexpr auto kJPreBuildOutsAttrs = "prebuild_outs_attrs"; -constexpr auto kJKwdArgs = "kwds_args"; -constexpr auto kJListArgs = "list_args"; +constexpr auto kJAttrDesc = "attr_desc"; constexpr auto kJSocVersion = "socVersion"; constexpr auto kSOC_VERSION = "SOC_VERSION"; constexpr auto kJIsDynamicShape = "is_dynamic_shape"; @@ -784,49 +785,32 @@ void TbeKernelBuild::GenFusionComputeCommonJson(const mindspore::CNodePtr &cnode // replace special op type for buffer fusion op auto type = GetRealOpType(origin_type); (*compute_op_str)[kJtype] = type; - auto kernel_name = op_info_ptr->kernel_name(); - (*compute_op_str)[kJFuncName] = kernel_name; - (*compute_op_str)[kJModuleName] = std::string("impl.") + kernel_name; + (*compute_op_str)[kJDynamicCompileStatic] = false; + auto func_name = op_info_ptr->kernel_name(); + (*compute_op_str)[kJFuncName] = func_name; + (*compute_op_str)[kJInt64Mode] = false; + (*compute_op_str)[kJModuleName] = std::string("impl.") + func_name; (*compute_op_str)[kJName] = cnode->fullname_with_scope(); (*compute_op_str)[kJPattern] = GetNodeFusionType(cnode); (*compute_op_str)[kJPyModulePath] = "/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe"; (void)(*fusion_kernel_name).append("_"); - (void)(*fusion_kernel_name).append(kernel_name); -} - -void TbeKernelBuild::GenFusionComputePreBuildJson(const mindspore::CNodePtr &cnode, nlohmann::json *compute_op_str) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(compute_op_str); - // kwds args - nlohmann::json json_prebuild_args; - json_prebuild_args[kJKwdArgs] = nlohmann::json::object(); - // list_args - nlohmann::json json_list_args; - // list_args: output args - auto output_size = AnfAlgo::GetOutputTensorNum(cnode); - for (size_t i = 0; i < output_size; ++i) { - nlohmann::json output_desc; - GenDescJson(cnode, i, i, &output_desc); - output_desc[kJDtype] = output_desc[kJDataType]; - json_list_args.push_back(output_desc); - } - // list_args: attr args - auto op_name = AnfAlgo::GetCNodeName(cnode); - auto opinfo = OpLib::FindOp(op_name, OpImplyType::kTBE); - MS_EXCEPTION_IF_NULL(opinfo); + (void)(*fusion_kernel_name).append(func_name); + // attr_desc TbeKernelJsonCreator json_creater(SINGLE_BUILD); nlohmann::json json_attr_args; - if (!json_creater.GenTbeAttrJson(cnode, opinfo, &json_attr_args)) { + if (!json_creater.GenTbeAttrJson(cnode, op_info_ptr, &json_attr_args)) { MS_LOG(INFO) << "Fusion warning: get prebuild args of attr failed."; } + nlohmann::json attr_desc; for (const auto &attr : json_attr_args) { // if(attr[kJName] != "isRef" && attr["valid"] == true) { if (attr[kJName] != "isRef" && attr[kJValid] == true) { - json_list_args.push_back(attr[kJValue]); + attr_desc.push_back(attr[kJValue]); } } - json_prebuild_args[kJListArgs] = json_list_args; - (*compute_op_str)[kJPreBuildOutsAttrs] = json_prebuild_args; + if (!attr_desc.empty()) { + (*compute_op_str)[kJAttrDesc] = attr_desc; + } } void TbeKernelBuild::GenSuffixDescJson(nlohmann::json *output_desc) { @@ -902,6 +886,18 @@ void TbeKernelBuild::GenDescJson(const std::shared_ptr &anf_ GenSuffixDescJson(output_desc); } +void TbeKernelBuild::GenFusionOutputDescJson(const std::shared_ptr &anf_node, size_t node_out_idx, + size_t desc_output_idx, nlohmann::json *output_desc, + nlohmann::json *output_data_desc) { + MS_EXCEPTION_IF_NULL(output_desc); + MS_EXCEPTION_IF_NULL(output_data_desc); + GenDescJson(anf_node, node_out_idx, desc_output_idx, output_desc); + *output_data_desc = *output_desc; + (*output_data_desc)[kJDtype] = (*output_desc)[kJDataType]; + output_data_desc->erase(kJDataType); + output_data_desc->erase(kJName); +} + void TbeKernelBuild::GenReusedOutputDesc(const std::shared_ptr &anf_node, size_t index, size_t output_index, nlohmann::json *output_desc) { std::string output_desc_name = anf_node->fullname_with_scope() + "_" + std::to_string(index); @@ -1154,8 +1150,10 @@ std::vector TbeKernelBuild::GetDescOutputIndex(const std::vector *output_desc_list) { + std::vector *output_desc_list, + std::vector *output_data_desc_list) { MS_EXCEPTION_IF_NULL(output_desc_list); + MS_EXCEPTION_IF_NULL(output_data_desc_list); auto output_size = AnfAlgo::GetOutputTensorNum(cnode); if (AnfAlgo::HasNodeAttr(kAttrOutputUsedNum, cnode)) { auto output_used_nums = AnfAlgo::GetNodeAttr>(cnode, kAttrOutputUsedNum); @@ -1168,7 +1166,9 @@ bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode for (size_t i = 0; i < output_size; ++i) { MS_LOG(INFO) << "Fusion index: " << i << ", desc_output_index: " << desc_output_index[i]; nlohmann::json output_desc; - GenDescJson(cnode, i, desc_output_index[i], &output_desc); + nlohmann::json output_data_desc; + GenFusionOutputDescJson(cnode, i, desc_output_index[i], &output_desc, &output_data_desc); + output_data_desc_list->emplace_back(output_data_desc); output_desc_list->emplace_back(output_desc); } for (size_t j = output_size; j < desc_output_index.size(); ++j) { @@ -1180,8 +1180,10 @@ bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode } else { for (size_t i = 0; i < output_size; ++i) { nlohmann::json output_desc; - GenDescJson(cnode, i, i, &output_desc); - output_desc_list->push_back(output_desc); + nlohmann::json output_data_desc; + GenFusionOutputDescJson(cnode, i, i, &output_desc, &output_data_desc); + output_data_desc_list->emplace_back(output_data_desc); + output_desc_list->emplace_back(output_desc); } } return true; @@ -1200,15 +1202,15 @@ bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_n (*compute_op_str)[kJInputDesc] = input_desc_list; // gen output desc std::vector output_desc_list; - if (!GenFusionComputeOutputJson(cnode, &output_desc_list)) { + std::vector output_data_desc_list; + if (!GenFusionComputeOutputJson(cnode, &output_desc_list, &output_data_desc_list)) { MS_LOG(INFO) << "Fusion Error: gen fusion output desc failed, node full name: " << cnode->fullname_with_scope(); return false; } + (*compute_op_str)[kJOutputDataDesc] = output_data_desc_list; (*compute_op_str)[kJOutputDesc] = output_desc_list; // gen common desc GenFusionComputeCommonJson(cnode, compute_op_str, fusion_kernel_name); - // gen prebuild args - GenFusionComputePreBuildJson(cnode, compute_op_str); return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h index 23df11d3ee..f0bec0b61a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h @@ -60,14 +60,17 @@ class TbeKernelBuild { std::vector *input_desc_list, size_t *index); static std::vector GetDescOutputIndex(const std::vector &output_used_nums); static bool GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode, - std::vector *output_desc_list); + std::vector *output_desc_list, + std::vector *output_data_desc_list); static void GenPreDescJson(nlohmann::json *output_desc); static void GenFusionComputeCommonJson(const mindspore::CNodePtr &cnode, nlohmann::json *compute_op_str, std::string *fusion_kernel_name); - static void GenFusionComputePreBuildJson(const mindspore::CNodePtr &cnode, nlohmann::json *compute_op_str); static void GenDescJson(const std::shared_ptr &anf_node, size_t node_out_idx, size_t desc_output_idx, nlohmann::json *output_desc, FusionDataType fusion_data_type = kFusionNormal); + static void GenFusionOutputDescJson(const std::shared_ptr &anf_node, size_t node_out_idx, + size_t desc_output_idx, nlohmann::json *output_desc, + nlohmann::json *output_data_desc); static void GenSuffixDescJson(nlohmann::json *output_desc); static void GenReusedOutputDesc(const std::shared_ptr &anf_node, size_t index, size_t output_index, nlohmann::json *output_desc);