| @@ -78,33 +78,40 @@ void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &ke | |||||
| } | } | ||||
| } | } | ||||
| bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector<std::string> &input_formats, | |||||
| const std::vector<TypeId> &input_types, | |||||
| const std::vector<size_t> &input_not_cnode_indexes) { | |||||
| std::pair<int, int> GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr, | |||||
| const std::vector<std::string> &input_formats, | |||||
| const std::vector<TypeId> &input_types, | |||||
| const std::vector<size_t> &input_not_cnode_indexes) { | |||||
| if (kernel_attr.GetInputSize() != input_types.size()) { | if (kernel_attr.GetInputSize() != input_types.size()) { | ||||
| MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size(); | MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size(); | ||||
| return false; | |||||
| return std::make_pair(0, 0); | |||||
| } | } | ||||
| int data_type_matched_num = 0; | |||||
| int format_matched_num = 0; | |||||
| auto input_num = input_types.size(); | auto input_num = input_types.size(); | ||||
| for (size_t i = 0; i < input_num; ++i) { | for (size_t i = 0; i < input_num; ++i) { | ||||
| bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(), | bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(), | ||||
| [i](size_t index) { return index == i; }); | [i](size_t index) { return index == i; }); | ||||
| bool have_cnode_input = (input_types.size() != input_not_cnode_indexes.size()); | bool have_cnode_input = (input_types.size() != input_not_cnode_indexes.size()); | ||||
| if (have_cnode_input && is_not_cnode_idx) { | if (have_cnode_input && is_not_cnode_idx) { | ||||
| data_type_matched_num++; | |||||
| format_matched_num++; | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (kernel_attr.GetInputAttr(i).first != input_types[i]) { | if (kernel_attr.GetInputAttr(i).first != input_types[i]) { | ||||
| MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first | MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first | ||||
| << ", actual input dtype:" << input_types[i]; | << ", actual input dtype:" << input_types[i]; | ||||
| return false; | |||||
| } else { | |||||
| data_type_matched_num++; | |||||
| } | } | ||||
| if (kernel_attr.GetInputAttr(i).second != input_formats[i]) { | if (kernel_attr.GetInputAttr(i).second != input_formats[i]) { | ||||
| MS_LOG(DEBUG) << "required format:" << kernel_attr.GetInputAttr(i).second | MS_LOG(DEBUG) << "required format:" << kernel_attr.GetInputAttr(i).second | ||||
| << ", actual input format:" << input_formats[i]; | << ", actual input format:" << input_formats[i]; | ||||
| return false; | |||||
| } else { | |||||
| format_matched_num++; | |||||
| } | } | ||||
| } | } | ||||
| return true; | |||||
| return std::make_pair(data_type_matched_num, format_matched_num); | |||||
| } | } | ||||
| void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) { | void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) { | ||||
| @@ -121,6 +128,18 @@ void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) { | |||||
| kernel_attr->AddOutputAttr(output_dtype); | kernel_attr->AddOutputAttr(output_dtype); | ||||
| } | } | ||||
| } | } | ||||
| void SetKernelBuildInfo(const std::vector<std::string> &input_formats, const std::vector<TypeId> &input_types, | |||||
| const std::vector<std::string> &output_formats, const std::vector<TypeId> &output_types, | |||||
| AnfNode *kernel_node) { | |||||
| auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||||
| MS_EXCEPTION_IF_NULL(builder); | |||||
| builder->SetInputsFormat(input_formats); | |||||
| builder->SetInputsDeviceType(input_types); | |||||
| builder->SetOutputsFormat(output_formats); | |||||
| builder->SetOutputsDeviceType(output_types); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node); | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| void SetKernelInfo(const CNodePtr &kernel_node) { | void SetKernelInfo(const CNodePtr &kernel_node) { | ||||
| @@ -136,38 +155,49 @@ void SetKernelInfo(const CNodePtr &kernel_node) { | |||||
| auto kernel_attrs = | auto kernel_attrs = | ||||
| kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node)); | kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node)); | ||||
| for (size_t index = 0; index < kernel_attrs.size(); ++index) { | |||||
| auto kernel_attr = kernel_attrs[index]; | |||||
| int max_type_matched_num = -1; | |||||
| int max_format_matched_num = -1; | |||||
| KernelAttr selected_kernel_attr; | |||||
| for (auto kernel_attr : kernel_attrs) { | |||||
| if (kernel_attr.GetAllSame()) { | if (kernel_attr.GetAllSame()) { | ||||
| ExpandKernelAttr(kernel_node, &kernel_attr); | ExpandKernelAttr(kernel_node, &kernel_attr); | ||||
| } | } | ||||
| bool ignore_check = false; | |||||
| if (index == kernel_attrs.size() - 1 && input_types.size() == input_not_cnode_indexes.size()) { | |||||
| ignore_check = true; | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||||
| if (kernel_attr.GetOutputSize() != output_num) { | |||||
| MS_LOG(DEBUG) << "Output num is not equal!"; | |||||
| continue; | |||||
| } | |||||
| std::pair<int, int> input_type_format_matched_num = | |||||
| GetInputDtypeFormatMatchedNum(kernel_attr, input_formats, input_types, input_not_cnode_indexes); | |||||
| // Data type first | |||||
| if (input_type_format_matched_num.first > max_type_matched_num) { | |||||
| max_type_matched_num = input_type_format_matched_num.first; | |||||
| max_format_matched_num = input_type_format_matched_num.second; | |||||
| selected_kernel_attr = kernel_attr; | |||||
| } else if (input_type_format_matched_num.first == max_type_matched_num && | |||||
| input_type_format_matched_num.second > max_format_matched_num) { | |||||
| max_format_matched_num = input_type_format_matched_num.second; | |||||
| selected_kernel_attr = kernel_attr; | |||||
| } | } | ||||
| if (ignore_check || IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) { | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||||
| if (kernel_attr.GetOutputSize() != output_num) { | |||||
| MS_LOG(DEBUG) << "Output num is not equal!"; | |||||
| continue; | |||||
| } | |||||
| MS_LOG(INFO) << "Input format and dtype is matched, index: " << index; | |||||
| GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types); | |||||
| UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node); | |||||
| for (auto &input_index : input_not_cnode_indexes) { | |||||
| input_types[input_index] = kernel_attr.GetInputAttr(input_index).first; | |||||
| } | |||||
| // All formats and data types matched | |||||
| if (max_type_matched_num == SizeToInt(input_types.size()) && | |||||
| max_format_matched_num == SizeToInt(input_types.size())) { | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||||
| MS_EXCEPTION_IF_NULL(builder); | |||||
| builder->SetInputsFormat(input_formats); | |||||
| builder->SetInputsDeviceType(input_types); | |||||
| builder->SetOutputsFormat(output_formats); | |||||
| builder->SetOutputsDeviceType(output_types); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get()); | |||||
| if ((max_type_matched_num == SizeToInt(input_types.size()) && | |||||
| max_format_matched_num == SizeToInt(input_types.size())) || | |||||
| input_types.size() == input_not_cnode_indexes.size()) { | |||||
| MS_LOG(INFO) << "Input format and dtype is matched, max_type_matched_num: " << max_type_matched_num | |||||
| << ", max_format_matched_num: " << max_format_matched_num; | |||||
| GetOutputFormatsAndDtypes(kernel_node, selected_kernel_attr, &output_formats, &output_types); | |||||
| UpdatePrevNotCNodeFormatDtype(selected_kernel_attr, input_not_cnode_indexes, kernel_node); | |||||
| for (auto &input_index : input_not_cnode_indexes) { | |||||
| input_types[input_index] = selected_kernel_attr.GetInputAttr(input_index).first; | |||||
| } | |||||
| } | |||||
| SetKernelBuildInfo(input_formats, input_types, output_formats, output_types, kernel_node.get()); | |||||
| } | } | ||||
| } // namespace cpu | } // namespace cpu | ||||
| } // namespace device | } // namespace device | ||||