|
|
|
@@ -28,6 +28,7 @@ |
|
|
|
#include "backend/kernel_compiler/kernel_query.h" |
|
|
|
#include "backend/kernel_compiler/oplib/oplib.h" |
|
|
|
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" |
|
|
|
#include "backend/kernel_compiler/aicpu/aicpu_attr_to_input_registry.h" |
|
|
|
#include "backend/optimizer/common/helper.h" |
|
|
|
#include "backend/session/anf_runtime_algorithm.h" |
|
|
|
#include "utils/ms_device_shape_transfer.h" |
|
|
|
@@ -689,16 +690,7 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node, |
|
|
|
return select_status; |
|
|
|
} |
|
|
|
|
|
|
|
bool ConvertAttrToInput(const CNodePtr &kernel_node) { |
|
|
|
std::vector<size_t> input_attr_idx = AnfAlgo::GetNodeAttr<std::vector<size_t>>(kernel_node, kAttrInputToAttrIdx); |
|
|
|
std::vector<string> input_attr_name = AnfAlgo::GetNodeAttr<std::vector<string>>(kernel_node, kAttrInputToAttrName); |
|
|
|
if (input_attr_idx.size() != input_attr_name.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "The size of input_to_attr_index should be equal to the size of input_to_attr_name, but got " |
|
|
|
<< "input_to_attr_index size: " << input_attr_idx.size() |
|
|
|
<< ", input_to_attr_name size: " << input_attr_name.size() << ". Node:[" |
|
|
|
<< kernel_node->fullname_with_scope() << "]."; |
|
|
|
} |
|
|
|
|
|
|
|
void ConvertAttrToInput(const CNodePtr &kernel_node, std::vector<std::pair<string, size_t>> *infos) { |
|
|
|
auto graph = kernel_node->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
auto kernel_graph = graph->cast<KernelGraphPtr>(); |
|
|
|
@@ -706,18 +698,28 @@ bool ConvertAttrToInput(const CNodePtr &kernel_node) { |
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node); |
|
|
|
MS_EXCEPTION_IF_NULL(primitive); |
|
|
|
|
|
|
|
std::ostringstream buf; |
|
|
|
for (auto &info : *infos) { |
|
|
|
buf << " (" << info.first << ", " << info.second << ")"; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Start converting attr to input for aicpu op[" << AnfUtils::GetCNodeName(kernel_node) |
|
|
|
<< "] with attr_name and input_index pairs:" << buf.str(); |
|
|
|
|
|
|
|
std::sort(infos->begin(), infos->end(), |
|
|
|
[](const std::pair<string, size_t> &a, const std::pair<string, size_t> &b) { return a.second < b.second; }); |
|
|
|
auto orig_inputs = kernel_node->inputs(); |
|
|
|
size_t orig_input_num = orig_inputs.size() - 1; |
|
|
|
size_t new_input_num = orig_input_num + input_attr_idx.size(); |
|
|
|
size_t new_input_num = orig_input_num + infos->size(); |
|
|
|
size_t orig_tmp_idx = 0; |
|
|
|
size_t attr_tmp_idx = 0; |
|
|
|
std::vector<AnfNodePtr> new_inputs = {orig_inputs[0]}; |
|
|
|
for (size_t idx = 0; idx < new_input_num; ++idx) { |
|
|
|
if (attr_tmp_idx < input_attr_idx.size() && idx == input_attr_idx[attr_tmp_idx]) { |
|
|
|
auto value = primitive->GetAttr(input_attr_name[attr_tmp_idx]); |
|
|
|
if (attr_tmp_idx < infos->size() && idx == infos->at(attr_tmp_idx).second) { |
|
|
|
auto attr_name = infos->at(attr_tmp_idx).first; |
|
|
|
auto value = primitive->GetAttr(attr_name); |
|
|
|
if (value == nullptr) { |
|
|
|
MS_LOG(INFO) << "Can not get attr[" << input_attr_name[attr_tmp_idx] << "]."; |
|
|
|
return false; |
|
|
|
MS_LOG(INFO) << "Can not get attr[" << attr_name << "]."; |
|
|
|
return; |
|
|
|
} |
|
|
|
tensor::TensorPtr tensor_ptr = nullptr; |
|
|
|
if (value->isa<tensor::Tensor>()) { |
|
|
|
@@ -727,13 +729,12 @@ bool ConvertAttrToInput(const CNodePtr &kernel_node) { |
|
|
|
} else if (value->isa<ValueTuple>()) { |
|
|
|
tensor_ptr = opt::CreateTupleTensor(value->cast<ValueTuplePtr>()); |
|
|
|
} else { |
|
|
|
MS_LOG(INFO) << "The value of attr[" << input_attr_name[attr_tmp_idx] |
|
|
|
<< "] should be a tensor or scalar or value tuple."; |
|
|
|
return false; |
|
|
|
MS_LOG(INFO) << "The value of attr[" << attr_name << "] should be a tensor or scalar or value tuple."; |
|
|
|
return; |
|
|
|
} |
|
|
|
if (tensor_ptr == nullptr) { |
|
|
|
MS_LOG(INFO) << "Convert attr[" << input_attr_name[attr_tmp_idx] << "] to tensor value failed."; |
|
|
|
return false; |
|
|
|
MS_LOG(INFO) << "Convert attr[" << attr_name << "] to tensor value failed."; |
|
|
|
return; |
|
|
|
} |
|
|
|
auto value_node = kernel_graph->NewValueNode(tensor_ptr); |
|
|
|
MS_EXCEPTION_IF_NULL(value_node); |
|
|
|
@@ -745,31 +746,6 @@ bool ConvertAttrToInput(const CNodePtr &kernel_node) { |
|
|
|
} |
|
|
|
} |
|
|
|
kernel_node->set_inputs(new_inputs); |
|
|
|
primitive->EraseAttr(kAttrInputToAttrIdx); |
|
|
|
primitive->EraseAttr(kAttrInputToAttrName); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
KernelSelectStatus AICPUSelectWithConvertAttrToInput(const CNodePtr &kernel_node) { |
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrInputToAttrIdx, kernel_node) || |
|
|
|
!AnfAlgo::HasNodeAttr(kAttrInputToAttrName, kernel_node)) { |
|
|
|
return kNoMatched; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "The node [" << kernel_node->fullname_with_scope() |
|
|
|
<< "] cannot find valid kernel info, try to convert attr to input and re-find in ai_cpu kernel info"; |
|
|
|
|
|
|
|
auto orig_inputs = kernel_node->inputs(); |
|
|
|
bool convert_succ = ConvertAttrToInput(kernel_node); |
|
|
|
if (!convert_succ) { |
|
|
|
return kNoMatched; |
|
|
|
} |
|
|
|
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list; |
|
|
|
kernel::AICPUQuery(kernel_node, &aicpu_kernel_info_list); |
|
|
|
auto select_status = SetMatchedKernelInfo(kernel_node, aicpu_kernel_info_list); |
|
|
|
if (select_status == kNoMatched) { |
|
|
|
kernel_node->set_inputs(orig_inputs); |
|
|
|
} |
|
|
|
return select_status; |
|
|
|
} |
|
|
|
|
|
|
|
std::string KernelInfoCandidateList(const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &ai_core, |
|
|
|
@@ -799,7 +775,8 @@ std::string KernelInfoCandidateList(const std::vector<std::shared_ptr<kernel::Ke |
|
|
|
|
|
|
|
void PrintNotMatchMessage(const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &ai_core, |
|
|
|
const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &ai_cpu, |
|
|
|
const std::ostringstream &buffer, const CNodePtr &kernel_node) { |
|
|
|
const std::ostringstream &aicore_info, const std::ostringstream &aicpu_info, |
|
|
|
const CNodePtr &kernel_node) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node); |
|
|
|
auto full_name = kernel_node->fullname_with_scope(); |
|
|
|
if (ai_core.empty() && ai_cpu.empty()) { |
|
|
|
@@ -810,8 +787,8 @@ void PrintNotMatchMessage(const std::vector<std::shared_ptr<kernel::KernelBuildI |
|
|
|
auto candidates = KernelInfoCandidateList(ai_core, ai_cpu); |
|
|
|
MS_EXCEPTION(TypeError) << "Can not select a valid kernel info for [" << full_name |
|
|
|
<< "] in AI CORE or AI CPU kernel info candidates list: " << candidates |
|
|
|
<< "Please check the given data type or shape:\n" |
|
|
|
<< buffer.str() |
|
|
|
<< "Please check the given data type or shape:" |
|
|
|
<< "\nAI CORE: " << aicore_info.str() << "\nAI CPU: " << aicpu_info.str() |
|
|
|
<< "\nFor more details, please refer to 'Kernel Select Failed' at " |
|
|
|
"https://www.mindspore.cn" |
|
|
|
<< trace::DumpSourceLines(kernel_node); |
|
|
|
@@ -821,6 +798,7 @@ void PrintNotMatchMessage(const std::vector<std::shared_ptr<kernel::KernelBuildI |
|
|
|
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) { |
|
|
|
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; |
|
|
|
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list; |
|
|
|
std::ostringstream aicore_in_out_info, aicpu_in_out_info; |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node); |
|
|
|
if (AnfAlgo::IsGraphKernel(kernel_node)) { |
|
|
|
auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex)); |
|
|
|
@@ -848,21 +826,21 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kern |
|
|
|
} |
|
|
|
// If node can't find valid ai_core kernel info, re-find in ai_cpu kernel info |
|
|
|
if (select_status == kNoMatched) { |
|
|
|
GatherInputAndOutputInferType(aicore_in_out_info, kernel_node); |
|
|
|
MS_LOG(DEBUG) << "The node [" << kernel_node->fullname_with_scope() |
|
|
|
<< "] cannot find valid TBE kernel info, try to get ai_cpu kernel info"; |
|
|
|
std::vector<std::pair<string, size_t>> attr_to_input_infos; |
|
|
|
if (kernel::GetAicpuOpAttrToInputInfo(kernel_node, &attr_to_input_infos) && !AnfAlgo::IsDynamicShape(kernel_node)) { |
|
|
|
ConvertAttrToInput(kernel_node, &attr_to_input_infos); |
|
|
|
} |
|
|
|
kernel::AICPUQuery(kernel_node, &aicpu_kernel_info_list); |
|
|
|
select_status = SetMatchedKernelInfo(kernel_node, aicpu_kernel_info_list); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node); |
|
|
|
} |
|
|
|
// If node can't find valid kernel info, try to convert attr to input and re-find in ai_cpu kernel info |
|
|
|
if (select_status == kNoMatched) { |
|
|
|
select_status = AICPUSelectWithConvertAttrToInput(kernel_node); |
|
|
|
} |
|
|
|
// The kernel info can not find in ai_cpu kernel lists and ai_core kernel lists |
|
|
|
if (select_status == kNoMatched) { |
|
|
|
std::ostringstream buffer; |
|
|
|
GatherInputAndOutputInferType(buffer, kernel_node); |
|
|
|
PrintNotMatchMessage(kernel_info_list, aicpu_kernel_info_list, buffer, kernel_node); |
|
|
|
GatherInputAndOutputInferType(aicpu_in_out_info, kernel_node); |
|
|
|
PrintNotMatchMessage(kernel_info_list, aicpu_kernel_info_list, aicore_in_out_info, aicpu_in_out_info, kernel_node); |
|
|
|
} |
|
|
|
return select_status; |
|
|
|
} |
|
|
|
|