|
|
@@ -61,6 +61,7 @@ constexpr auto kParamDynamic = "dynamic"; |
|
|
constexpr auto kParamRequred = "required"; |
|
|
constexpr auto kParamRequred = "required"; |
|
|
constexpr auto kJDataType = "data_type"; |
|
|
constexpr auto kJDataType = "data_type"; |
|
|
constexpr auto kJOutputIndex = "output_index"; |
|
|
constexpr auto kJOutputIndex = "output_index"; |
|
|
|
|
|
constexpr auto kJOutputDataDesc = "output_data_desc"; |
|
|
constexpr auto kJOutputDesc = "output_desc"; |
|
|
constexpr auto kJOutputDesc = "output_desc"; |
|
|
constexpr auto kJInputDesc = "input_desc"; |
|
|
constexpr auto kJInputDesc = "input_desc"; |
|
|
constexpr auto kJRange = "range"; |
|
|
constexpr auto kJRange = "range"; |
|
|
@@ -84,13 +85,13 @@ constexpr auto kJAddrType = "addr_type"; |
|
|
constexpr auto kJSliceOffset = "slice_offset"; |
|
|
constexpr auto kJSliceOffset = "slice_offset"; |
|
|
constexpr auto kJSplitIndex = "split_index"; |
|
|
constexpr auto kJSplitIndex = "split_index"; |
|
|
constexpr auto kJTotalShape = "total_shape"; |
|
|
constexpr auto kJTotalShape = "total_shape"; |
|
|
|
|
|
constexpr auto kJDynamicCompileStatic = "dynamic_compile_static"; |
|
|
|
|
|
constexpr auto kJInt64Mode = "int64mode"; |
|
|
constexpr auto kJValidShape = "valid_shape"; |
|
|
constexpr auto kJValidShape = "valid_shape"; |
|
|
constexpr auto kJModuleName = "module_name"; |
|
|
constexpr auto kJModuleName = "module_name"; |
|
|
constexpr auto kJPattern = "pattern"; |
|
|
constexpr auto kJPattern = "pattern"; |
|
|
constexpr auto kJPyModulePath = "py_module_path"; |
|
|
constexpr auto kJPyModulePath = "py_module_path"; |
|
|
constexpr auto kJPreBuildOutsAttrs = "prebuild_outs_attrs"; |
|
|
|
|
|
constexpr auto kJKwdArgs = "kwds_args"; |
|
|
|
|
|
constexpr auto kJListArgs = "list_args"; |
|
|
|
|
|
|
|
|
constexpr auto kJAttrDesc = "attr_desc"; |
|
|
constexpr auto kJSocVersion = "socVersion"; |
|
|
constexpr auto kJSocVersion = "socVersion"; |
|
|
constexpr auto kSOC_VERSION = "SOC_VERSION"; |
|
|
constexpr auto kSOC_VERSION = "SOC_VERSION"; |
|
|
constexpr auto kJIsDynamicShape = "is_dynamic_shape"; |
|
|
constexpr auto kJIsDynamicShape = "is_dynamic_shape"; |
|
|
@@ -784,49 +785,32 @@ void TbeKernelBuild::GenFusionComputeCommonJson(const mindspore::CNodePtr &cnode |
|
|
// replace special op type for buffer fusion op |
|
|
// replace special op type for buffer fusion op |
|
|
auto type = GetRealOpType(origin_type); |
|
|
auto type = GetRealOpType(origin_type); |
|
|
(*compute_op_str)[kJtype] = type; |
|
|
(*compute_op_str)[kJtype] = type; |
|
|
auto kernel_name = op_info_ptr->kernel_name(); |
|
|
|
|
|
(*compute_op_str)[kJFuncName] = kernel_name; |
|
|
|
|
|
(*compute_op_str)[kJModuleName] = std::string("impl.") + kernel_name; |
|
|
|
|
|
|
|
|
(*compute_op_str)[kJDynamicCompileStatic] = false; |
|
|
|
|
|
auto func_name = op_info_ptr->kernel_name(); |
|
|
|
|
|
(*compute_op_str)[kJFuncName] = func_name; |
|
|
|
|
|
(*compute_op_str)[kJInt64Mode] = false; |
|
|
|
|
|
(*compute_op_str)[kJModuleName] = std::string("impl.") + func_name; |
|
|
(*compute_op_str)[kJName] = cnode->fullname_with_scope(); |
|
|
(*compute_op_str)[kJName] = cnode->fullname_with_scope(); |
|
|
(*compute_op_str)[kJPattern] = GetNodeFusionType(cnode); |
|
|
(*compute_op_str)[kJPattern] = GetNodeFusionType(cnode); |
|
|
(*compute_op_str)[kJPyModulePath] = "/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe"; |
|
|
(*compute_op_str)[kJPyModulePath] = "/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe"; |
|
|
(void)(*fusion_kernel_name).append("_"); |
|
|
(void)(*fusion_kernel_name).append("_"); |
|
|
(void)(*fusion_kernel_name).append(kernel_name); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void TbeKernelBuild::GenFusionComputePreBuildJson(const mindspore::CNodePtr &cnode, nlohmann::json *compute_op_str) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(compute_op_str); |
|
|
|
|
|
// kwds args |
|
|
|
|
|
nlohmann::json json_prebuild_args; |
|
|
|
|
|
json_prebuild_args[kJKwdArgs] = nlohmann::json::object(); |
|
|
|
|
|
// list_args |
|
|
|
|
|
nlohmann::json json_list_args; |
|
|
|
|
|
// list_args: output args |
|
|
|
|
|
auto output_size = AnfAlgo::GetOutputTensorNum(cnode); |
|
|
|
|
|
for (size_t i = 0; i < output_size; ++i) { |
|
|
|
|
|
nlohmann::json output_desc; |
|
|
|
|
|
GenDescJson(cnode, i, i, &output_desc); |
|
|
|
|
|
output_desc[kJDtype] = output_desc[kJDataType]; |
|
|
|
|
|
json_list_args.push_back(output_desc); |
|
|
|
|
|
} |
|
|
|
|
|
// list_args: attr args |
|
|
|
|
|
auto op_name = AnfAlgo::GetCNodeName(cnode); |
|
|
|
|
|
auto opinfo = OpLib::FindOp(op_name, OpImplyType::kTBE); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(opinfo); |
|
|
|
|
|
|
|
|
(void)(*fusion_kernel_name).append(func_name); |
|
|
|
|
|
// attr_desc |
|
|
TbeKernelJsonCreator json_creater(SINGLE_BUILD); |
|
|
TbeKernelJsonCreator json_creater(SINGLE_BUILD); |
|
|
nlohmann::json json_attr_args; |
|
|
nlohmann::json json_attr_args; |
|
|
if (!json_creater.GenTbeAttrJson(cnode, opinfo, &json_attr_args)) { |
|
|
|
|
|
|
|
|
if (!json_creater.GenTbeAttrJson(cnode, op_info_ptr, &json_attr_args)) { |
|
|
MS_LOG(INFO) << "Fusion warning: get prebuild args of attr failed."; |
|
|
MS_LOG(INFO) << "Fusion warning: get prebuild args of attr failed."; |
|
|
} |
|
|
} |
|
|
|
|
|
nlohmann::json attr_desc; |
|
|
for (const auto &attr : json_attr_args) { |
|
|
for (const auto &attr : json_attr_args) { |
|
|
// if(attr[kJName] != "isRef" && attr["valid"] == true) { |
|
|
// if(attr[kJName] != "isRef" && attr["valid"] == true) { |
|
|
if (attr[kJName] != "isRef" && attr[kJValid] == true) { |
|
|
if (attr[kJName] != "isRef" && attr[kJValid] == true) { |
|
|
json_list_args.push_back(attr[kJValue]); |
|
|
|
|
|
|
|
|
attr_desc.push_back(attr[kJValue]); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
json_prebuild_args[kJListArgs] = json_list_args; |
|
|
|
|
|
(*compute_op_str)[kJPreBuildOutsAttrs] = json_prebuild_args; |
|
|
|
|
|
|
|
|
if (!attr_desc.empty()) { |
|
|
|
|
|
(*compute_op_str)[kJAttrDesc] = attr_desc; |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void TbeKernelBuild::GenSuffixDescJson(nlohmann::json *output_desc) { |
|
|
void TbeKernelBuild::GenSuffixDescJson(nlohmann::json *output_desc) { |
|
|
@@ -902,6 +886,18 @@ void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_ |
|
|
GenSuffixDescJson(output_desc); |
|
|
GenSuffixDescJson(output_desc); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Builds the two JSON descriptions of one fusion-op output.
//
// `output_desc` is filled by GenDescJson(anf_node, node_out_idx, desc_output_idx, ...).
// `output_data_desc` then receives a copy of that description in which the value
// stored under kJDataType ("data_type") is also stored under kJDtype, and the
// kJDataType and kJName entries are removed. Only the copy is trimmed;
// `output_desc` keeps the entries GenDescJson produced.
//
// Both output pointers are checked with MS_EXCEPTION_IF_NULL before use.
void TbeKernelBuild::GenFusionOutputDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t node_out_idx,
                                             size_t desc_output_idx, nlohmann::json *output_desc,
                                             nlohmann::json *output_data_desc) {
  MS_EXCEPTION_IF_NULL(output_desc);
  MS_EXCEPTION_IF_NULL(output_data_desc);

  // Produce the full description first; the data description is derived from it.
  GenDescJson(anf_node, node_out_idx, desc_output_idx, output_desc);

  nlohmann::json &full_desc = *output_desc;
  nlohmann::json &data_desc = *output_data_desc;

  // Start from an exact copy, then re-key the data type and drop the fields
  // that are not carried in the data description.
  data_desc = full_desc;
  data_desc[kJDtype] = full_desc[kJDataType];
  data_desc.erase(kJDataType);
  data_desc.erase(kJName);
}
|
|
|
|
|
|
|
|
void TbeKernelBuild::GenReusedOutputDesc(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t index, |
|
|
void TbeKernelBuild::GenReusedOutputDesc(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t index, |
|
|
size_t output_index, nlohmann::json *output_desc) { |
|
|
size_t output_index, nlohmann::json *output_desc) { |
|
|
std::string output_desc_name = anf_node->fullname_with_scope() + "_" + std::to_string(index); |
|
|
std::string output_desc_name = anf_node->fullname_with_scope() + "_" + std::to_string(index); |
|
|
@@ -1154,8 +1150,10 @@ std::vector<size_t> TbeKernelBuild::GetDescOutputIndex(const std::vector<int64_t |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode, |
|
|
bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode, |
|
|
std::vector<nlohmann::json> *output_desc_list) { |
|
|
|
|
|
|
|
|
std::vector<nlohmann::json> *output_desc_list, |
|
|
|
|
|
std::vector<nlohmann::json> *output_data_desc_list) { |
|
|
MS_EXCEPTION_IF_NULL(output_desc_list); |
|
|
MS_EXCEPTION_IF_NULL(output_desc_list); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output_data_desc_list); |
|
|
auto output_size = AnfAlgo::GetOutputTensorNum(cnode); |
|
|
auto output_size = AnfAlgo::GetOutputTensorNum(cnode); |
|
|
if (AnfAlgo::HasNodeAttr(kAttrOutputUsedNum, cnode)) { |
|
|
if (AnfAlgo::HasNodeAttr(kAttrOutputUsedNum, cnode)) { |
|
|
auto output_used_nums = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrOutputUsedNum); |
|
|
auto output_used_nums = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrOutputUsedNum); |
|
|
@@ -1168,7 +1166,9 @@ bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode |
|
|
for (size_t i = 0; i < output_size; ++i) { |
|
|
for (size_t i = 0; i < output_size; ++i) { |
|
|
MS_LOG(INFO) << "Fusion index: " << i << ", desc_output_index: " << desc_output_index[i]; |
|
|
MS_LOG(INFO) << "Fusion index: " << i << ", desc_output_index: " << desc_output_index[i]; |
|
|
nlohmann::json output_desc; |
|
|
nlohmann::json output_desc; |
|
|
GenDescJson(cnode, i, desc_output_index[i], &output_desc); |
|
|
|
|
|
|
|
|
nlohmann::json output_data_desc; |
|
|
|
|
|
GenFusionOutputDescJson(cnode, i, desc_output_index[i], &output_desc, &output_data_desc); |
|
|
|
|
|
output_data_desc_list->emplace_back(output_data_desc); |
|
|
output_desc_list->emplace_back(output_desc); |
|
|
output_desc_list->emplace_back(output_desc); |
|
|
} |
|
|
} |
|
|
for (size_t j = output_size; j < desc_output_index.size(); ++j) { |
|
|
for (size_t j = output_size; j < desc_output_index.size(); ++j) { |
|
|
@@ -1180,8 +1180,10 @@ bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode |
|
|
} else { |
|
|
} else { |
|
|
for (size_t i = 0; i < output_size; ++i) { |
|
|
for (size_t i = 0; i < output_size; ++i) { |
|
|
nlohmann::json output_desc; |
|
|
nlohmann::json output_desc; |
|
|
GenDescJson(cnode, i, i, &output_desc); |
|
|
|
|
|
output_desc_list->push_back(output_desc); |
|
|
|
|
|
|
|
|
nlohmann::json output_data_desc; |
|
|
|
|
|
GenFusionOutputDescJson(cnode, i, i, &output_desc, &output_data_desc); |
|
|
|
|
|
output_data_desc_list->emplace_back(output_data_desc); |
|
|
|
|
|
output_desc_list->emplace_back(output_desc); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
return true; |
|
|
return true; |
|
|
@@ -1200,15 +1202,15 @@ bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_n |
|
|
(*compute_op_str)[kJInputDesc] = input_desc_list; |
|
|
(*compute_op_str)[kJInputDesc] = input_desc_list; |
|
|
// gen output desc |
|
|
// gen output desc |
|
|
std::vector<nlohmann::json> output_desc_list; |
|
|
std::vector<nlohmann::json> output_desc_list; |
|
|
if (!GenFusionComputeOutputJson(cnode, &output_desc_list)) { |
|
|
|
|
|
|
|
|
std::vector<nlohmann::json> output_data_desc_list; |
|
|
|
|
|
if (!GenFusionComputeOutputJson(cnode, &output_desc_list, &output_data_desc_list)) { |
|
|
MS_LOG(INFO) << "Fusion Error: gen fusion output desc failed, node full name: " << cnode->fullname_with_scope(); |
|
|
MS_LOG(INFO) << "Fusion Error: gen fusion output desc failed, node full name: " << cnode->fullname_with_scope(); |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
|
|
|
(*compute_op_str)[kJOutputDataDesc] = output_data_desc_list; |
|
|
(*compute_op_str)[kJOutputDesc] = output_desc_list; |
|
|
(*compute_op_str)[kJOutputDesc] = output_desc_list; |
|
|
// gen common desc |
|
|
// gen common desc |
|
|
GenFusionComputeCommonJson(cnode, compute_op_str, fusion_kernel_name); |
|
|
GenFusionComputeCommonJson(cnode, compute_op_str, fusion_kernel_name); |
|
|
// gen prebuild args |
|
|
|
|
|
GenFusionComputePreBuildJson(cnode, compute_op_str); |
|
|
|
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|