|
|
|
@@ -106,7 +106,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) { |
|
|
|
bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best_item) { |
|
|
|
MS_EXCEPTION_IF_NULL(best_item); |
|
|
|
if (cur_item.size() != best_item->size()) { |
|
|
|
MS_LOG(ERROR) << "item size should be same!"; |
|
|
|
MS_LOG(ERROR) << "Item size should be same!"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
// Update the best_item by comparing the cur_item and best_item |
|
|
|
@@ -280,8 +280,12 @@ bool RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_dat |
|
|
|
|
|
|
|
bool CanDataTypeReduce(const std::vector<int> &datatype_indexes, int check_index, |
|
|
|
const std::vector<int> &node_mix_precision_datatype_index) { |
|
|
|
return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex && |
|
|
|
datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index]; |
|
|
|
auto check_index_tmp = IntToSize(check_index); |
|
|
|
if (check_index_tmp < datatype_indexes.size() && check_index_tmp < node_mix_precision_datatype_index.size()) { |
|
|
|
return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex && |
|
|
|
datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index]; |
|
|
|
} |
|
|
|
MS_LOG(EXCEPTION) << "Check index " << check_index << "is outof range"; |
|
|
|
} |
|
|
|
|
|
|
|
bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index, |
|
|
|
@@ -300,10 +304,10 @@ bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_preci |
|
|
|
if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) { |
|
|
|
auto find_iter = kernel_support_datatypes.find(iter->first); |
|
|
|
if (find_iter == kernel_support_datatypes.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "kernel datatype index:%lu can not be found " << iter->first; |
|
|
|
MS_LOG(EXCEPTION) << "Kernel datatype index:%lu can not be found " << iter->first; |
|
|
|
} |
|
|
|
if (i >= find_iter->second.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "node index " << i << " >= kernel datatype size " << find_iter->second.size(); |
|
|
|
MS_LOG(EXCEPTION) << "Node index " << i << " >= kernel datatype size " << find_iter->second.size(); |
|
|
|
} |
|
|
|
if (node_mix_precision_datatype[i] != find_iter->second[i]) { |
|
|
|
iter = kernel_match_datatype_idx->erase(iter); |
|
|
|
@@ -314,7 +318,7 @@ bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_preci |
|
|
|
} |
|
|
|
auto datatype_indexes = iter->second; |
|
|
|
if (i >= datatype_indexes.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "index " << i << "> kernel datatype indexes size " << datatype_indexes.size(); |
|
|
|
MS_LOG(EXCEPTION) << "Index " << i << "> kernel datatype indexes size " << datatype_indexes.size(); |
|
|
|
} |
|
|
|
if (!CanDataTypeReduce(datatype_indexes, i, node_mix_precision_datatype_index)) { |
|
|
|
iter = kernel_match_datatype_idx->erase(iter); |
|
|
|
@@ -384,9 +388,9 @@ void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode, |
|
|
|
std::ostringstream buffer; |
|
|
|
buffer << cnode->DebugString(); |
|
|
|
if (precision_reduce) { |
|
|
|
buffer << " reduce precision, node datatype: \n"; |
|
|
|
buffer << " Reduce precision, node datatype: \n"; |
|
|
|
} else { |
|
|
|
buffer << " raise precision, node datatype: \n"; |
|
|
|
buffer << " Raise precision, node datatype: \n"; |
|
|
|
} |
|
|
|
PrintInputAndOutputInferType(buffer, cnode); |
|
|
|
buffer << ", select kernel:" << selected_kernel_build_info->ToString(); |
|
|
|
@@ -554,12 +558,12 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kern |
|
|
|
if (select_status == kNoMatched) { |
|
|
|
std::ostringstream buffer; |
|
|
|
PrintInputAndOutputInferType(buffer, kernel_node); |
|
|
|
MS_LOG(WARNING) << ">>> candidates kernel info list:"; |
|
|
|
MS_LOG(WARNING) << ">>> Candidates kernel info list:"; |
|
|
|
for (size_t index = 0; index < kernel_info_list.size(); ++index) { |
|
|
|
MS_LOG(WARNING) << "kernel [" << index << "] :" << kernel_info_list[index]->ToString(); |
|
|
|
MS_LOG(WARNING) << "Kernel [" << index << "] :" << kernel_info_list[index]->ToString(); |
|
|
|
} |
|
|
|
for (size_t index = 0; index < aicpu_kernel_info_list.size(); ++index) { |
|
|
|
MS_LOG(WARNING) << "kernel [" << (kernel_info_list.size() + index) |
|
|
|
MS_LOG(WARNING) << "Kernel [" << (kernel_info_list.size() + index) |
|
|
|
<< "] :" << aicpu_kernel_info_list[index]->ToString(); |
|
|
|
} |
|
|
|
MS_LOG(WARNING) << " <<<"; |
|
|
|
|