|
|
|
@@ -71,9 +71,6 @@ void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::stri |
|
|
|
void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &kernel_attr, |
|
|
|
std::vector<std::string> *output_formats, std::vector<TypeId> *output_types) { |
|
|
|
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); |
|
|
|
if (kernel_attr.GetOutputSize() != output_num) { |
|
|
|
MS_LOG(EXCEPTION) << "Output num is not equal!"; |
|
|
|
} |
|
|
|
for (size_t output_index = 0; output_index < output_num; ++output_index) { |
|
|
|
output_formats->emplace_back(kernel_attr.GetOutputAttr(output_index).second); |
|
|
|
auto dtype = kernel_attr.GetOutputAttr(output_index).first; |
|
|
|
@@ -145,6 +142,11 @@ void SetKernelInfo(const CNodePtr &kernel_node) { |
|
|
|
ExpandKernelAttr(kernel_node, &kernel_attr); |
|
|
|
} |
|
|
|
if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) { |
|
|
|
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); |
|
|
|
if (kernel_attr.GetOutputSize() != output_num) { |
|
|
|
MS_LOG(DEBUG) << "Output num is not equal!"; |
|
|
|
continue; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Input format and dtype is matched, index: " << index; |
|
|
|
GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types); |
|
|
|
UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node); |
|
|
|
|