|
|
|
@@ -59,6 +59,7 @@ void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::stri |
|
|
|
TypeId dtype = kTypeUnknown; |
|
|
|
if (IsInputNotCNode(kernel_node, input_index)) { |
|
|
|
input_no_cnode_indexes->emplace_back(input_index); |
|
|
|
dtype = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index); |
|
|
|
} else { |
|
|
|
dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index); |
|
|
|
} |
|
|
|
@@ -84,22 +85,25 @@ bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector< |
|
|
|
const std::vector<TypeId> &input_types, |
|
|
|
const std::vector<size_t> &input_not_cnode_indexes) { |
|
|
|
if (kernel_attr.GetInputSize() != input_types.size()) { |
|
|
|
MS_LOG(ERROR) << "Output num is not equal!"; |
|
|
|
MS_LOG(ERROR) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto input_num = input_types.size(); |
|
|
|
for (size_t i = 0; i < input_num; ++i) { |
|
|
|
bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(), |
|
|
|
[i](size_t index) { return index == i; }); |
|
|
|
if (is_not_cnode_idx) { |
|
|
|
bool have_cnode_input = (input_types.size() != input_not_cnode_indexes.size()); |
|
|
|
if (have_cnode_input && is_not_cnode_idx) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (kernel_attr.GetInputAttr(i).first != input_types[i]) { |
|
|
|
MS_LOG(ERROR) << "reg dtype=" << kernel_attr.GetInputAttr(i).first << ", input dtype=" << input_types[i]; |
|
|
|
MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first |
|
|
|
<< ", actual input dtype:" << input_types[i]; |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (kernel_attr.GetInputAttr(i).second != input_formats[i]) { |
|
|
|
MS_LOG(ERROR) << "reg format=" << kernel_attr.GetInputAttr(i).second << ", input format=" << input_formats[i]; |
|
|
|
MS_LOG(DEBUG) << "required format:" << kernel_attr.GetInputAttr(i).second |
|
|
|
<< ", actual input format:" << input_formats[i]; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -114,17 +118,19 @@ void SetKernelInfo(const CNodePtr &kernel_node) { |
|
|
|
std::vector<std::string> output_formats; |
|
|
|
std::vector<TypeId> output_types; |
|
|
|
|
|
|
|
MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << AnfAlgo::GetCNodeName(kernel_node); |
|
|
|
GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes); |
|
|
|
|
|
|
|
auto kernel_attrs = |
|
|
|
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node)); |
|
|
|
|
|
|
|
for (auto &kernel_attr : kernel_attrs) { |
|
|
|
if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) { |
|
|
|
GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types); |
|
|
|
UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node); |
|
|
|
for (size_t index = 0; index < kernel_attrs.size(); ++index) { |
|
|
|
if (IsInputFormatDtypeMatched(kernel_attrs[index], input_formats, input_types, input_not_cnode_indexes)) { |
|
|
|
MS_LOG(INFO) << "Input format and dtype is matched, index: " << index; |
|
|
|
GetOutputFormatsAndDtypes(kernel_node, kernel_attrs[index], &output_formats, &output_types); |
|
|
|
UpdatePrevNotCNodeFormatDtype(kernel_attrs[index], input_not_cnode_indexes, kernel_node); |
|
|
|
for (auto &input_index : input_not_cnode_indexes) { |
|
|
|
input_types[input_index] = kernel_attr.GetInputAttr(input_index).first; |
|
|
|
input_types[input_index] = kernel_attrs[index].GetInputAttr(input_index).first; |
|
|
|
} |
|
|
|
break; |
|
|
|
} |
|
|
|
|