|
|
@@ -40,6 +40,7 @@ enum MatchCountPriority : int { |
|
|
MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN, |
|
|
MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN, |
|
|
MATCH_FORMAT_COUNT, |
|
|
MATCH_FORMAT_COUNT, |
|
|
MATCH_SPECIAL_FORMAT_COUNT, |
|
|
MATCH_SPECIAL_FORMAT_COUNT, |
|
|
|
|
|
MATCH_DEFAULT_FORMAT_COUNT, |
|
|
MATCH_OUTPUT_DTYPE_COUNT, |
|
|
MATCH_OUTPUT_DTYPE_COUNT, |
|
|
MATCH_COUNT_PRIORITY_END |
|
|
MATCH_COUNT_PRIORITY_END |
|
|
}; |
|
|
}; |
|
|
@@ -73,7 +74,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) { |
|
|
auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index); |
|
|
auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index); |
|
|
if (AnfAlgo::IsFeatureMapInput(cnode, index) && |
|
|
if (AnfAlgo::IsFeatureMapInput(cnode, index) && |
|
|
kNeedTransFormatSet.find(pre_output_format) != kNeedTransFormatSet.end()) { |
|
|
kNeedTransFormatSet.find(pre_output_format) != kNeedTransFormatSet.end()) { |
|
|
priority_matched_format = !is_init ? priority_matched_format : pre_output_format; |
|
|
|
|
|
|
|
|
priority_matched_format = !is_init ? pre_output_format : priority_matched_format; |
|
|
is_init = true; |
|
|
is_init = true; |
|
|
} |
|
|
} |
|
|
// feature map has two or more special format; |
|
|
// feature map has two or more special format; |
|
|
@@ -83,7 +84,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) { |
|
|
auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size(); |
|
|
auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size(); |
|
|
need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1)); |
|
|
need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1)); |
|
|
} |
|
|
} |
|
|
if (need_change_nd) { |
|
|
|
|
|
|
|
|
if (need_change_nd && priority_matched_format != kOpFormat_FRAC_NZ) { |
|
|
priority_matched_format = kOpFormat_DEFAULT; |
|
|
priority_matched_format = kOpFormat_DEFAULT; |
|
|
} |
|
|
} |
|
|
AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode); |
|
|
AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode); |
|
|
@@ -134,6 +135,9 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons |
|
|
if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) { |
|
|
if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) { |
|
|
(*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score; |
|
|
(*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score; |
|
|
} |
|
|
} |
|
|
|
|
|
if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT) { |
|
|
|
|
|
(*cur_kernelinfo_match_counts)[MATCH_DEFAULT_FORMAT_COUNT] += base_score; |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { |
|
|
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { |
|
|
@@ -410,10 +414,10 @@ std::shared_ptr<kernel::KernelBuildInfo> ChooseMatchedKernelInfo( |
|
|
if (kernel_info_list.empty()) { |
|
|
if (kernel_info_list.empty()) { |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
std::vector<int> most_match_counts = {-1, -1, -1, -1}; |
|
|
|
|
|
|
|
|
std::vector<int> most_match_counts = {-1, -1, -1, -1, -1}; |
|
|
size_t selected_index = 0; |
|
|
size_t selected_index = 0; |
|
|
for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { |
|
|
for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { |
|
|
std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0}; |
|
|
|
|
|
|
|
|
std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0, 0}; |
|
|
auto kernel_build_info = *(kernel_info_list[info_index]); |
|
|
auto kernel_build_info = *(kernel_info_list[info_index]); |
|
|
std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index]; |
|
|
std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index]; |
|
|
UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts); |
|
|
UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts); |
|
|
|