| @@ -31,12 +31,13 @@ namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| namespace { | |||
// Score added to a match counter for a matching input that is NOT a feature map
// (selected in UpdateCurMatchCounts when AnfAlgo::IsFeatureMapInput is false).
// NOTE(review): "Wegiht" looks like a misspelling of "Weight"; a rename must cover every use.
| const float kWegihtBaseScore = 1; | | |
// Score added to a match counter for a matching feature-map input, so such a
// match contributes ten times as much as a non-feature-map match.
// NOTE(review): the counters appear to be int vectors, so these float scores are
// implicitly converted when added - confirm this is intended.
| const float kFeatureMapBaseScore = 10; | | |
// Indices into the per-candidate match-count vector filled by UpdateCurMatchCounts.
// The vector must hold at least MATCH_COUNT_PRIORITY_END entries (UpdateCurMatchCounts
// raises an exception otherwise), so keep the vector initializers in sync with this enum.
// NOTE(review): presumably a lower enumerator value means a higher comparison priority -
// confirm against the comparison routine.
| enum MatchCountPriority : int { | | |
| MATCH_COUNT_PRIORITY_BEGIN = 0, | | |
// Inputs whose kernel device type equals the producer's output device type.
| MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN, | | |
// Inputs whose kernel format equals the producer's output format.
| MATCH_FORMAT_COUNT, | | |
// Inputs counted for matching a special (needs-transform / preferred) format.
| MATCH_SPECIAL_FORMAT_COUNT, | | |
// Inputs whose kernel format is kOpFormat_NC1HWC0.
| MATCH_5D_FORMAT_COUNT, | | |
// Outputs whose kernel device type equals the inferred output data type.
| MATCH_OUTPUT_DTYPE_COUNT, | | |
// Number of counters; not a counter itself.
| MATCH_COUNT_PRIORITY_END | | |
| }; | | |
| @@ -82,13 +83,6 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel:: | |||
| } | |||
| return true; | |||
| }; | |||
| if (AnfAlgo::GetCNodeName(kernel_node) == "Adam") { | |||
| auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_num - 1) != | |||
| kernel_build_info.GetInputFormat(input_num - 1)) { | |||
| return false; | |||
| } | |||
| } | |||
| if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) { | |||
| return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) && | |||
| AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0); | |||
| @@ -112,21 +106,7 @@ bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildIn | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| // Check input data type | |||
| for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { | |||
| AnfNodePtr cur_input = AnfAlgo::GetInputNode(cnode, input_index); | |||
| MS_EXCEPTION_IF_NULL(cur_input); | |||
| TypeId input_origin_type; | |||
| if (cur_input->isa<Parameter>() && AnfAlgo::IsParameterWeight(cur_input->cast<ParameterPtr>())) { | |||
| // weight | |||
| input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0); | |||
| } else if (cur_input->isa<ValueNode>()) { | |||
| input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0); | |||
| } else { | |||
| // feature map | |||
| input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); | |||
| } | |||
| if (input_origin_type == kTypeUnknown) { | |||
| continue; | |||
| } | |||
| TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); | |||
| if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) { | |||
| return false; | |||
| } | |||
| @@ -140,6 +120,29 @@ bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildIn | |||
| return true; | |||
| } | |||
| string GetPriorityMatchFormat(const CNodePtr &cnode) { | |||
| string priority_matched_format = kOpFormat_NC1HWC0; | |||
| bool is_init = false; | |||
| bool need_change_nd = false; | |||
| for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) { | |||
| auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index); | |||
| if (AnfAlgo::IsFeatureMapInput(cnode, index) && | |||
| kNeedTransFormatSet.find(pre_output_format) != kNeedTransFormatSet.end()) { | |||
| priority_matched_format = !is_init ? priority_matched_format : pre_output_format; | |||
| is_init = true; | |||
| } | |||
| // feature map has two or more special format; | |||
| if (priority_matched_format != pre_output_format && pre_output_format != kOpFormat_DEFAULT) { | |||
| priority_matched_format = kOpFormat_DEFAULT; | |||
| } | |||
| auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size(); | |||
| need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1)); | |||
| } | |||
| if (need_change_nd) { | |||
| priority_matched_format = kOpFormat_DEFAULT; | |||
| } | |||
| return priority_matched_format; | |||
| } | |||
| /** | |||
| * Compare two vectors by priority and select the better one, like comparing two numbers digit by digit: compare the | | |
| * highest-priority position first and, if the values are equal, move on to the next position. | | |
| @@ -172,34 +175,18 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons | |||
| if (cur_kernelinfo_match_counts->size() < MATCH_COUNT_PRIORITY_END) { | |||
| MS_LOG(EXCEPTION) << "Out of range cur_kernelinfo_match_counts " << MATCH_COUNT_PRIORITY_END; | |||
| } | |||
| auto pri_match_format = GetPriorityMatchFormat(kernel_node); | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | |||
| AnfNodePtr input_anf_node = AnfAlgo::GetInputNode(kernel_node, input_index); | |||
| MS_EXCEPTION_IF_NULL(input_anf_node); | |||
| // if a input parameter is a weight with default format, the input shouldn't participate the judge | |||
| if (input_anf_node->isa<Parameter>()) { | |||
| auto para = input_anf_node->cast<ParameterPtr>(); | |||
| if (AnfAlgo::IsParameterWeight(para) && AnfAlgo::GetOutputDeviceDataType(para, 0) == kTypeUnknown) { | |||
| continue; | |||
| } | |||
| } | |||
| auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWegihtBaseScore; | |||
| if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) { | |||
| if (AnfAlgo::IsFeatureMapInput(kernel_node, input_index) && | |||
| kNeedTransFormatSet.find(kernel_build_info.GetInputFormat(input_index)) != kNeedTransFormatSet.end()) { | |||
| (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT]++; | |||
| } | |||
| (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++; | |||
| (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score; | |||
| } | |||
| if (kernel_build_info.GetInputDeviceType(input_index) == | |||
| AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)) { | |||
| (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT]++; | |||
| (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score; | |||
| } | |||
| if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_NC1HWC0) { | |||
| // input is from a feature map & this input's shape is not 4d | |||
| if (AnfAlgo::IsFeatureMapInput(kernel_node, input_index) && | |||
| AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_index).size() != kShape4dDims) { | |||
| continue; | |||
| } | |||
| (*cur_kernelinfo_match_counts)[MATCH_5D_FORMAT_COUNT]++; | |||
| if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) { | |||
| (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score; | |||
| } | |||
| } | |||
| @@ -207,7 +194,7 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons | |||
| // cal count of same output dtype between abstract and kernel info | |||
| if (kernel_build_info.GetOutputDeviceType(output_index) == | |||
| AnfAlgo::GetOutputInferDataType(kernel_node, output_index)) { | |||
| (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++; | |||
| (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT] += 1; | |||
| } | |||
| } | |||
| } | |||
| @@ -517,7 +504,7 @@ void SelectKernelInfo(const CNodePtr &kernel_node) { | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| kernel::KernelQuery(kernel_node, &kernel_info_list); | |||
| std::vector<int> most_match_counts = {-1, -1, -1, -1, -1}; | |||
| std::vector<int> most_match_counts = {-1, -1, -1, -1}; | |||
| int selected_index = -1; | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| @@ -527,7 +514,7 @@ void SelectKernelInfo(const CNodePtr &kernel_node) { | |||
| std::vector<int> node_mix_precision_datatype_index; | |||
| std::vector<TypeId> node_mix_precision_datatype; | |||
| for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { | |||
| std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0, 0}; | |||
| std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0}; | |||
| auto kernel_build_info = *(kernel_info_list[info_index]); | |||
| if (!IsValidKernelInfo(kernel_node, kernel_build_info)) { | |||
| continue; | |||