| @@ -23,9 +23,7 @@ | |||||
| #include "kernel/oplib/oplib.h" | #include "kernel/oplib/oplib.h" | ||||
| #include "kernel/kernel_query.h" | #include "kernel/kernel_query.h" | ||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| #include "kernel/kernel_build_info.h" | |||||
| #include "utils/context/ms_context.h" | #include "utils/context/ms_context.h" | ||||
| #include "operator/ops.h" | |||||
| #include "debug/anf_ir_dump.h" | #include "debug/anf_ir_dump.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -45,7 +43,6 @@ enum MatchCountPriority : int { | |||||
| MATCH_COUNT_PRIORITY_END | MATCH_COUNT_PRIORITY_END | ||||
| }; | }; | ||||
| const size_t kMaxCount = 0xffffffff; | |||||
| const int kUnSupportMixedDataTypeIndex = -1; | const int kUnSupportMixedDataTypeIndex = -1; | ||||
| bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) { | bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) { | ||||
| @@ -91,7 +88,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) { | |||||
| return priority_matched_format; | return priority_matched_format; | ||||
| } | } | ||||
| /** | /** | ||||
| * compare two vector by priority, select a better vector, like compare two num, first compare highest num location, | |||||
| * Compare two vector by priority, select a better vector, like compare two num, first compare highest num location, | |||||
| * if equal then next num location | * if equal then next num location | ||||
| * example:[3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3] | * example:[3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3] | ||||
| */ | */ | ||||
| @@ -167,8 +164,9 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co | |||||
| if (op_info != nullptr) { | if (op_info != nullptr) { | ||||
| is_ref = op_info->is_ref(); | is_ref = op_info->is_ref(); | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); | |||||
| if (MsContext::GetInstance()->execution_mode() == kPynativeMode && | |||||
| auto ms_context = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(ms_context); | |||||
| if (ms_context->execution_mode() == kPynativeMode && | |||||
| AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) { | AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -221,6 +219,7 @@ void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index, | |||||
| std::vector<TypeId> *node_mix_precision_datatype) { | std::vector<TypeId> *node_mix_precision_datatype) { | ||||
| AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index); | AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index); | ||||
| MS_EXCEPTION_IF_NULL(cur_input); | MS_EXCEPTION_IF_NULL(cur_input); | ||||
| MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); | |||||
| TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index); | TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index); | ||||
| AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index); | AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index); | ||||
| node_mix_precision_datatype->push_back(input_origin_type); | node_mix_precision_datatype->push_back(input_origin_type); | ||||
| @@ -229,6 +228,7 @@ void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index, | |||||
| void AddNodeOutputDataType(const CNodePtr &kernel_node, size_t output_index, | void AddNodeOutputDataType(const CNodePtr &kernel_node, size_t output_index, | ||||
| std::vector<int> *node_mix_precision_datatype_index, | std::vector<int> *node_mix_precision_datatype_index, | ||||
| std::vector<TypeId> *node_mix_precision_datatype) { | std::vector<TypeId> *node_mix_precision_datatype) { | ||||
| MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); | |||||
| auto output_origin_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_index); | auto output_origin_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_index); | ||||
| AddSupportMixedPrecisionDataTypeIndex(output_origin_type, node_mix_precision_datatype_index); | AddSupportMixedPrecisionDataTypeIndex(output_origin_type, node_mix_precision_datatype_index); | ||||
| node_mix_precision_datatype->push_back(output_origin_type); | node_mix_precision_datatype->push_back(output_origin_type); | ||||
| @@ -239,12 +239,12 @@ void CheckDataTypeInputs(const std::vector<int> &node_mix_precision_datatype_ind | |||||
| const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes, | const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes, | ||||
| std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) { | std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) { | ||||
| if (node_mix_precision_datatype_index.size() != node_mix_precision_datatype.size()) { | if (node_mix_precision_datatype_index.size() != node_mix_precision_datatype.size()) { | ||||
| MS_LOG(EXCEPTION) << "node datatype index size " << node_mix_precision_datatype_index.size() << " != datatype size " | |||||
| MS_LOG(EXCEPTION) << "Node datatype index size " << node_mix_precision_datatype_index.size() << " != datatype size " | |||||
| << node_mix_precision_datatype.size(); | << node_mix_precision_datatype.size(); | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); | MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); | ||||
| if (kernel_support_datatypes.size() != kernel_match_datatype_idx->size()) { | if (kernel_support_datatypes.size() != kernel_match_datatype_idx->size()) { | ||||
| MS_LOG(EXCEPTION) << "kernel datatype index size " << kernel_match_datatype_idx->size() << " != datatype size " | |||||
| MS_LOG(EXCEPTION) << "Kernel datatype index size " << kernel_match_datatype_idx->size() << " != datatype size " | |||||
| << kernel_support_datatypes.size(); | << kernel_support_datatypes.size(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -265,10 +265,10 @@ bool RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_dat | |||||
| if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) { | if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) { | ||||
| auto find_iter = kernel_support_datatypes.find(iter->first); | auto find_iter = kernel_support_datatypes.find(iter->first); | ||||
| if (find_iter == kernel_support_datatypes.end()) { | if (find_iter == kernel_support_datatypes.end()) { | ||||
| MS_LOG(EXCEPTION) << "kernel datatype index:%lu can not be found " << iter->first; | |||||
| MS_LOG(EXCEPTION) << "Kernel datatype index:%lu can not be found " << iter->first; | |||||
| } | } | ||||
| if (i >= find_iter->second.size()) { | if (i >= find_iter->second.size()) { | ||||
| MS_LOG(EXCEPTION) << "node index " << i << "kernel datatype size " << find_iter->second.size(); | |||||
| MS_LOG(EXCEPTION) << "Node index " << i << "kernel datatype size " << find_iter->second.size(); | |||||
| } | } | ||||
| if (node_mix_precision_datatype[i] != find_iter->second[i]) { | if (node_mix_precision_datatype[i] != find_iter->second[i]) { | ||||
| iter = kernel_match_datatype_idx->erase(iter); | iter = kernel_match_datatype_idx->erase(iter); | ||||
| @@ -279,7 +279,7 @@ bool RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_dat | |||||
| } | } | ||||
| auto datatype_indexes = iter->second; | auto datatype_indexes = iter->second; | ||||
| if (i >= datatype_indexes.size()) { | if (i >= datatype_indexes.size()) { | ||||
| MS_LOG(EXCEPTION) << "node datatype index: " << i << " kernel support size " << datatype_indexes.size(); | |||||
| MS_LOG(EXCEPTION) << "Node datatype index: " << i << " kernel support size " << datatype_indexes.size(); | |||||
| } | } | ||||
| if (datatype_indexes[i] < node_mix_precision_datatype_index[i]) { | if (datatype_indexes[i] < node_mix_precision_datatype_index[i]) { | ||||
| iter = kernel_match_datatype_idx->erase(iter); | iter = kernel_match_datatype_idx->erase(iter); | ||||
| @@ -415,8 +415,8 @@ std::shared_ptr<kernel::KernelBuildInfo> ChooseMatchedKernelInfo( | |||||
| size_t selected_index = 0; | size_t selected_index = 0; | ||||
| for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { | for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { | ||||
| std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0, 0}; | std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0, 0}; | ||||
| auto kernel_build_info = *(kernel_info_list[info_index]); | |||||
| std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index]; | |||||
| auto kernel_info_ptr = kernel_info_list[info_index]; | |||||
| MS_EXCEPTION_IF_NULL(kernel_info_ptr); | |||||
| UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts); | UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts); | ||||
| // Currently the selection policy is the match format count first, and then is datatype counts. | // Currently the selection policy is the match format count first, and then is datatype counts. | ||||
| if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) { | if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) { | ||||