| @@ -16,23 +16,23 @@ | |||||
| #include "runtime/device/ascend/kernel_select_ascend.h" | #include "runtime/device/ascend/kernel_select_ascend.h" | ||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <utility> | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <map> | #include <map> | ||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include "utils/ms_utils.h" | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "backend/kernel_compiler/kernel_build_info.h" | |||||
| #include "backend/kernel_compiler/kernel_query.h" | |||||
| #include "backend/kernel_compiler/oplib/oplib.h" | |||||
| #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" | #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" | ||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| #include "debug/anf_ir_dump.h" | #include "debug/anf_ir_dump.h" | ||||
| #include "frontend/operator/ops.h" | #include "frontend/operator/ops.h" | ||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| #include "backend/kernel_compiler/kernel_query.h" | |||||
| #include "backend/kernel_compiler/oplib/oplib.h" | |||||
| #include "backend/kernel_compiler/kernel_build_info.h" | |||||
| #include "utils/ms_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| @@ -172,218 +172,6 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons | |||||
| } | } | ||||
| } | } | ||||
// Appends the mixed-precision rank of data_type to *support_index.
// Rank is 0 for kNumberTypeFloat16, 1 for kNumberTypeFloat32 / kNumberTypeFloat,
// and kUnSupportMixedDataTypeIndex for every other type.
// support_index is guarded by MS_EXCEPTION_IF_NULL.
| void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *support_index) { | |||||
| MS_EXCEPTION_IF_NULL(support_index); | |||||
// Start from the "not part of mixed precision" marker; only the float cases below override it.
| int index = kUnSupportMixedDataTypeIndex; | |||||
| switch (data_type) { | |||||
| case kNumberTypeFloat16: | |||||
| index = 0; | |||||
| break; | |||||
// Both float32 spellings share the same rank.
| case kNumberTypeFloat32: | |||||
| case kNumberTypeFloat: | |||||
| index = 1; | |||||
| break; | |||||
| default: | |||||
| break; | |||||
| } | |||||
// One entry is appended per call, whatever the type was.
| support_index->push_back(index); | |||||
| } | |||||
// Records what a kernel candidate offers for one input slot:
// the device type of input input_index is appended to *support_datatype and its
// mixed-precision rank is appended to *support_datatype_index.
// Only support_datatype is null-checked here; support_datatype_index is checked
// inside AddSupportMixedPrecisionDataTypeIndex.
| void AddKernelInputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t input_index, | |||||
| std::vector<int> *support_datatype_index, std::vector<TypeId> *support_datatype) { | |||||
| MS_EXCEPTION_IF_NULL(support_datatype); | |||||
| auto data_type = kernel_build_info.GetInputDeviceType(input_index); | |||||
| support_datatype->push_back(data_type); | |||||
| AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index); | |||||
| } | |||||
// Records what a kernel candidate offers for one output slot:
// the device type of output output_index is appended to *support_datatype and its
// mixed-precision rank is appended to *support_datatype_index.
// Only support_datatype is null-checked here; support_datatype_index is checked
// inside AddSupportMixedPrecisionDataTypeIndex.
| void AddKernelOutputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t output_index, | |||||
| std::vector<int> *support_datatype_index, std::vector<TypeId> *support_datatype) { | |||||
| MS_EXCEPTION_IF_NULL(support_datatype); | |||||
| auto data_type = kernel_build_info.GetOutputDeviceType(output_index); | |||||
| support_datatype->push_back(data_type); | |||||
| AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index); | |||||
| } | |||||
// Records the type the graph node itself has on one input slot:
// the inferred data type feeding input input_index of kernel_node is appended to
// *node_mix_precision_datatype and its mixed-precision rank is appended to
// *node_mix_precision_datatype_index.
| void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index, | |||||
| std::vector<int> *node_mix_precision_datatype_index, | |||||
| std::vector<TypeId> *node_mix_precision_datatype) { | |||||
// cur_input is fetched only so that a missing input node is caught by the null check below;
// it is not used afterwards.
| AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index); | |||||
| MS_EXCEPTION_IF_NULL(cur_input); | |||||
| MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); | |||||
| TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index); | |||||
| AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index); | |||||
| node_mix_precision_datatype->push_back(input_origin_type); | |||||
| } | |||||
// Records the type the graph node itself has on one output slot:
// the inferred data type of output output_index of kernel_node is appended to
// *node_mix_precision_datatype and its mixed-precision rank is appended to
// *node_mix_precision_datatype_index.
| void AddNodeOutputDataType(const CNodePtr &kernel_node, size_t output_index, | |||||
| std::vector<int> *node_mix_precision_datatype_index, | |||||
| std::vector<TypeId> *node_mix_precision_datatype) { | |||||
| MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); | |||||
| auto output_origin_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_index); | |||||
| AddSupportMixedPrecisionDataTypeIndex(output_origin_type, node_mix_precision_datatype_index); | |||||
| node_mix_precision_datatype->push_back(output_origin_type); | |||||
| } | |||||
// Sanity check shared by the precision-selection passes.
// Reports through MS_LOG(EXCEPTION) when
//  - the node rank vector and the node type vector differ in length, or
//  - the per-kernel type map and the per-kernel rank map differ in entry count.
// Nothing is modified; kernel_match_datatype_idx is only read.
| void CheckDataTypeInputs(const std::vector<int> &node_mix_precision_datatype_index, | |||||
| const std::vector<TypeId> &node_mix_precision_datatype, | |||||
| const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes, | |||||
| std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) { | |||||
| if (node_mix_precision_datatype_index.size() != node_mix_precision_datatype.size()) { | |||||
| MS_LOG(EXCEPTION) << "Node datatype index size " << node_mix_precision_datatype_index.size() << " != datatype size " | |||||
| << node_mix_precision_datatype.size(); | |||||
| } | |||||
// The pointer is dereferenced just below, so it is checked only at this point.
| MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); | |||||
| if (kernel_support_datatypes.size() != kernel_match_datatype_idx->size()) { | |||||
| MS_LOG(EXCEPTION) << "Kernel datatype index size " << kernel_match_datatype_idx->size() << " != datatype size " | |||||
| << kernel_support_datatypes.size(); | |||||
| } | |||||
| } | |||||
// Narrows *kernel_match_datatype_idx (kernel index -> per-slot ranks) to the candidates
// that can serve the node without lowering precision on any slot.
// For every node slot i whose type is not kTypeUnknown:
//  - if the node rank is kUnSupportMixedDataTypeIndex, a candidate survives only when its
//    type for slot i (looked up in kernel_support_datatypes) equals the node type exactly;
//  - otherwise a candidate is erased when its rank for slot i is lower than the node rank.
// Returns true when at least one candidate is left. The map is edited in place.
| bool RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index, | |||||
| const std::vector<TypeId> &node_mix_precision_datatype, | |||||
| const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes, | |||||
| std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); | |||||
| CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes, | |||||
| kernel_match_datatype_idx); | |||||
| for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) { | |||||
// Slots with an unknown node type place no constraint on the candidates.
| if (node_mix_precision_datatype[i] == kTypeUnknown) { | |||||
| continue; | |||||
| } | |||||
// Manual iterator loop: erase() hands back the next valid iterator, so the iterator is
// advanced explicitly only when the entry is kept.
| auto iter = kernel_match_datatype_idx->begin(); | |||||
| while (iter != kernel_match_datatype_idx->end()) { | |||||
| if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) { | |||||
| auto find_iter = kernel_support_datatypes.find(iter->first); | |||||
| if (find_iter == kernel_support_datatypes.end()) { | |||||
// NOTE(review): "%lu" is emitted literally by the << chain; looks like a leftover printf specifier.
| MS_LOG(EXCEPTION) << "Kernel datatype index:%lu can not be found " << iter->first; | |||||
| } | |||||
| if (i >= find_iter->second.size()) { | |||||
// NOTE(review): no separator between the index and "kernel datatype size" in this message.
| MS_LOG(EXCEPTION) << "Node index " << i << "kernel datatype size " << find_iter->second.size(); | |||||
| } | |||||
// Types outside the float16/float32 ranking must match exactly.
| if (node_mix_precision_datatype[i] != find_iter->second[i]) { | |||||
| iter = kernel_match_datatype_idx->erase(iter); | |||||
| } else { | |||||
| ++iter; | |||||
| } | |||||
| continue; | |||||
| } | |||||
// NOTE(review): this copies the rank vector on every iteration; a const reference would read the same data.
| auto datatype_indexes = iter->second; | |||||
| if (i >= datatype_indexes.size()) { | |||||
| MS_LOG(EXCEPTION) << "Node datatype index: " << i << " kernel support size " << datatype_indexes.size(); | |||||
| } | |||||
// A lower kernel rank than the node rank would mean losing precision, so drop the candidate.
// NOTE(review): whether a kernel slot ranked kUnSupportMixedDataTypeIndex is dropped here depends on
// that constant's value, which is defined outside this view - confirm.
| if (datatype_indexes[i] < node_mix_precision_datatype_index[i]) { | |||||
| iter = kernel_match_datatype_idx->erase(iter); | |||||
| } else { | |||||
| ++iter; | |||||
| } | |||||
| } | |||||
| } | |||||
| return !kernel_match_datatype_idx->empty(); | |||||
| } | |||||
// Tells whether a kernel candidate is acceptable on slot check_index when lowering precision is allowed:
// true when the kernel rank at that slot is not kUnSupportMixedDataTypeIndex and is not higher than
// the node rank at the same slot.
// An index outside either vector is reported through MS_LOG(EXCEPTION).
| bool CanDataTypeReduce(const std::vector<int> &datatype_indexes, int check_index, | |||||
| const std::vector<int> &node_mix_precision_datatype_index) { | |||||
// The bounds test is done on the size_t conversion of the signed index.
// NOTE(review): behaviour of IntToSize for negative input is defined elsewhere - confirm.
| auto check_index_tmp = IntToSize(check_index); | |||||
| if (check_index_tmp < datatype_indexes.size() && check_index_tmp < node_mix_precision_datatype_index.size()) { | |||||
| return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex && | |||||
| datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index]; | |||||
| } | |||||
// No return follows: presumably MS_LOG(EXCEPTION) does not return - confirm.
// NOTE(review): the message lacks a space after the index and spells "out of" as "outof".
| MS_LOG(EXCEPTION) << "Check index " << check_index << "is outof range"; | |||||
| } | |||||
// Narrows *kernel_match_datatype_idx (kernel index -> per-slot ranks) using the rule applied when
// precision reduction is permitted.
// For every node slot i whose type is not kTypeUnknown:
//  - if the node rank is kUnSupportMixedDataTypeIndex, a candidate survives only when its
//    type for slot i (looked up in kernel_support_datatypes) equals the node type exactly;
//  - otherwise a candidate survives only when CanDataTypeReduce accepts slot i.
// Returns true when at least one candidate is left. The map is edited in place.
| bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index, | |||||
| const std::vector<TypeId> &node_mix_precision_datatype, | |||||
| const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes, | |||||
| std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); | |||||
| CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes, | |||||
| kernel_match_datatype_idx); | |||||
| for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) { | |||||
// Slots with an unknown node type place no constraint on the candidates.
| if (node_mix_precision_datatype[i] == kTypeUnknown) { | |||||
| continue; | |||||
| } | |||||
// erase() hands back the next valid iterator, so the iterator is advanced explicitly only on keep.
| auto iter = kernel_match_datatype_idx->begin(); | |||||
| while (iter != kernel_match_datatype_idx->end()) { | |||||
| if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) { | |||||
| auto find_iter = kernel_support_datatypes.find(iter->first); | |||||
| if (find_iter == kernel_support_datatypes.end()) { | |||||
// NOTE(review): "%lu" is emitted literally by the << chain; looks like a leftover printf specifier.
| MS_LOG(EXCEPTION) << "Kernel datatype index:%lu can not be found " << iter->first; | |||||
| } | |||||
| if (i >= find_iter->second.size()) { | |||||
| MS_LOG(EXCEPTION) << "Node index " << i << " >= kernel datatype size " << find_iter->second.size(); | |||||
| } | |||||
// Types outside the float16/float32 ranking must match exactly.
| if (node_mix_precision_datatype[i] != find_iter->second[i]) { | |||||
| iter = kernel_match_datatype_idx->erase(iter); | |||||
| } else { | |||||
| ++iter; | |||||
| } | |||||
| continue; | |||||
| } | |||||
// NOTE(review): this copies the rank vector on every iteration; a const reference would read the same data.
| auto datatype_indexes = iter->second; | |||||
| if (i >= datatype_indexes.size()) { | |||||
| MS_LOG(EXCEPTION) << "Index " << i << "> kernel datatype indexes size " << datatype_indexes.size(); | |||||
| } | |||||
// i is a size_t passed to an int parameter (implicit conversion).
| if (!CanDataTypeReduce(datatype_indexes, i, node_mix_precision_datatype_index)) { | |||||
| iter = kernel_match_datatype_idx->erase(iter); | |||||
| } else { | |||||
| ++iter; | |||||
| } | |||||
| } | |||||
| } | |||||
| return !kernel_match_datatype_idx->empty(); | |||||
| } | |||||
// Collects, for one kernel candidate, the types and ranks of all its inputs followed by all its outputs
// into *support_datatypes / *support_indexes.
// The node-side vectors (*node_mix_precision_datatype and *node_mix_precision_datatype_index) are filled
// in the same slot order, but only when *node_mix_precision_datatype is empty on entry, so a caller that
// reuses them across candidates gets them populated by the first call only.
| void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelBuildInfo &kernel_build_info, | |||||
| std::vector<int> *support_indexes, std::vector<TypeId> *node_mix_precision_datatype, | |||||
| std::vector<TypeId> *support_datatypes, | |||||
| std::vector<int> *node_mix_precision_datatype_index) { | |||||
| MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); | |||||
// Decided once, before any push_back, so it stays true for the whole call even as the vector grows.
| bool add_node_datatype_flag = false; | |||||
| if (node_mix_precision_datatype->empty()) { | |||||
| add_node_datatype_flag = true; | |||||
| } | |||||
// Input slots come first; the slot count is taken from the kernel candidate.
| for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { | |||||
| AddKernelInputSupportDataType(kernel_build_info, input_index, support_indexes, support_datatypes); | |||||
| if (add_node_datatype_flag) { | |||||
| AddNodeInputDataType(kernel_node, input_index, node_mix_precision_datatype_index, node_mix_precision_datatype); | |||||
| } | |||||
| } | |||||
| // Check output data type | |||||
| for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) { | |||||
| AddKernelOutputSupportDataType(kernel_build_info, output_index, support_indexes, support_datatypes); | |||||
| if (add_node_datatype_flag) { | |||||
| AddNodeOutputDataType(kernel_node, output_index, node_mix_precision_datatype_index, node_mix_precision_datatype); | |||||
| } | |||||
| } | |||||
| } | |||||
// Runs the two selection passes over *kernel_match_datatype_idx.
//  1. Raise pass (RaiseDataTypePrecisionSelect). If any candidate survives, *precision_reduce is set
//     to false and the narrowed map is the result.
//  2. Otherwise, if MS_CTX_ENABLE_REDUCE_PRECISION is set in the context, the reduce pass
//     (RaiseOrReduceDataTypePrecisionSelect) runs on a pristine copy. If any candidate survives,
//     *precision_reduce is set to true and the copy replaces *kernel_match_datatype_idx.
// If neither pass succeeds, *precision_reduce is left untouched and *kernel_match_datatype_idx stays as
// the raise pass left it (empty, since that pass returns false only for an empty map).
| void PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index, | |||||
| const std::vector<TypeId> &node_mix_precision_datatype, | |||||
| const std::map<size_t, std::vector<TypeId>> &kernel_support_datatype, | |||||
| std::map<size_t, std::vector<int>> *kernel_match_datatype_idx, bool *precision_reduce) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| MS_EXCEPTION_IF_NULL(precision_reduce); | |||||
// Snapshot taken before the raise pass, because that pass erases entries in place.
| std::map<size_t, std::vector<int>> kernel_match_datatype_idx_copy = *kernel_match_datatype_idx; | |||||
| // raise precision | |||||
| bool selected_ret = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, | |||||
| kernel_support_datatype, kernel_match_datatype_idx); | |||||
| if (selected_ret) { | |||||
| *precision_reduce = false; | |||||
| return; | |||||
| } | |||||
// selected_ret is false here; it can only become true through the reduce pass below.
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_REDUCE_PRECISION)) { | |||||
| selected_ret = RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, | |||||
| kernel_support_datatype, &kernel_match_datatype_idx_copy); | |||||
| } | |||||
| if (selected_ret) { | |||||
| *precision_reduce = true; | |||||
| *kernel_match_datatype_idx = kernel_match_datatype_idx_copy; | |||||
| } | |||||
| } | |||||
| void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode, | void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode, | ||||
| const std::shared_ptr<kernel::KernelBuildInfo> &selected_kernel_build_info, | const std::shared_ptr<kernel::KernelBuildInfo> &selected_kernel_build_info, | ||||
| bool precision_reduce) { | bool precision_reduce) { | ||||
| @@ -434,30 +222,82 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype( | |||||
| return result; | return result; | ||||
| } | } | ||||
// Decides whether a kernel candidate is acceptable for cnode under the conversions listed in type_map
// (node type -> permitted device type).
// Per input and per output slot, with the device type normalised from kNumberTypeFloat to
// kNumberTypeFloat32:
//  - node type has a type_map entry: the slot is rejected (function returns false) when the device
//    type is neither the mapped type nor the node type itself;
//  - node type has no type_map entry and differs from the device type: the slot is counted as a
//    mismatch but does not reject on its own.
// After both loops the candidate is rejected when every input slot, or every output slot, was counted
// as such a mismatch. Otherwise returns true.
| bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info, const CNodePtr &cnode, | |||||
| const std::map<TypeId, TypeId> &type_map) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| MS_EXCEPTION_IF_NULL(kernel_build_info); | |||||
// Counters of unmapped, mismatching slots on the input and output side.
| size_t flag_in = 0; | |||||
| size_t flag_out = 0; | |||||
| for (size_t input_index = 0; input_index < kernel_build_info->GetInputNum(); ++input_index) { | |||||
| auto in_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); | |||||
| auto device_dtype = kernel_build_info->GetInputDeviceType(input_index); | |||||
// NOTE(review): the kNumberTypeFloat32 half of this test is redundant (it assigns the same value).
// in_dtype is not normalised the same way - confirm that is intended.
| if (device_dtype == kNumberTypeFloat || device_dtype == kNumberTypeFloat32) { | |||||
| device_dtype = kNumberTypeFloat32; | |||||
| } | |||||
| auto iter = type_map.find(in_dtype); | |||||
| if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) { | |||||
| return false; | |||||
| } | |||||
| if (iter == type_map.end() && in_dtype != device_dtype) { | |||||
| flag_in += 1; | |||||
| } | |||||
| } | |||||
| for (size_t output_index = 0; output_index < kernel_build_info->GetOutputNum(); ++output_index) { | |||||
// in_dtype here holds the node's inferred OUTPUT type; the name is reused from the input loop.
| auto in_dtype = AnfAlgo::GetOutputInferDataType(cnode, output_index); | |||||
| auto device_dtype = kernel_build_info->GetOutputDeviceType(output_index); | |||||
| if (device_dtype == kNumberTypeFloat || device_dtype == kNumberTypeFloat32) { | |||||
| device_dtype = kNumberTypeFloat32; | |||||
| } | |||||
| auto iter = type_map.find(in_dtype); | |||||
| if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) { | |||||
| return false; | |||||
| } | |||||
| if (iter == type_map.end() && in_dtype != device_dtype) { | |||||
| flag_out += 1; | |||||
| } | |||||
| } | |||||
// NOTE(review): when GetInputNum() or GetOutputNum() is 0 the matching counter is also 0, so this
// condition holds and the candidate is always rejected - confirm that is intended.
| if (flag_in == kernel_build_info->GetInputNum() || flag_out == kernel_build_info->GetOutputNum()) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
// Picks, from kernel_info_list, the candidates that can serve cnode once precision is raised or,
// when enabled in the context, reduced; *precision_reduce reports which of the two was used.
// NOTE(review): the rows of this body that appear in one column only seem to come from two different
// revisions (one built on AddNodeAndKernelDataType / PrecisionReduce and index maps, one built on
// TagRaiseReduce with raise_map / reduce_map). The comments below describe the rows as written;
// confirm which revision is live before relying on them.
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecisionMatchedKernelInfo( | std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecisionMatchedKernelInfo( | ||||
| const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list, | const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list, | ||||
| bool *precision_reduce) { | bool *precision_reduce) { | ||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_kernel_info_list; | std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_kernel_info_list; | ||||
// Index-map bookkeeping: kernel index -> per-slot ranks / types, plus the node's own ranks / types.
| std::map<size_t, std::vector<int>> kernel_match_datatype_idx; | |||||
| std::map<size_t, std::vector<TypeId>> kernel_support_datatype; | |||||
| std::vector<int> node_mix_precision_datatype_index; | |||||
| std::vector<TypeId> node_mix_precision_datatype; | |||||
// Conversion tables handed to TagRaiseReduce: float16 may be raised to float32; int64 may be reduced
// to int32 and float / float32 to float16.
| const std::map<TypeId, TypeId> raise_map = {{kNumberTypeFloat16, kNumberTypeFloat32}}; | |||||
| const std::map<TypeId, TypeId> reduce_map = {{kNumberTypeInt64, kNumberTypeInt32}, | |||||
| {kNumberTypeFloat, kNumberTypeFloat16}, | |||||
| {kNumberTypeFloat32, kNumberTypeFloat16}}; | |||||
| // raise precision | |||||
| for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { | for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { | ||||
| std::vector<int> support_indexes; | |||||
| std::vector<TypeId> support_datatypes; | |||||
| MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]); | MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]); | ||||
// Index-map path: record this candidate's ranks and types under its position in kernel_info_list.
| AddNodeAndKernelDataType(cnode, *kernel_info_list[info_index], &support_indexes, &node_mix_precision_datatype, | |||||
| &support_datatypes, &node_mix_precision_datatype_index); | |||||
| kernel_match_datatype_idx[info_index] = support_indexes; | |||||
| kernel_support_datatype[info_index] = support_datatypes; | |||||
// TagRaiseReduce path: keep the candidate if it is acceptable under raise_map.
| if (TagRaiseReduce(kernel_info_list[info_index], cnode, raise_map)) { | |||||
| filtered_kernel_info_list.push_back(kernel_info_list[info_index]); | |||||
| } | |||||
| } | |||||
// Any raise-compatible candidate wins outright; reduction is not attempted.
// NOTE(review): precision_reduce is dereferenced here without a null check in this function.
| if (!filtered_kernel_info_list.empty()) { | |||||
| *precision_reduce = false; | |||||
| return filtered_kernel_info_list; | |||||
| } | |||||
| // reduce precision | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
// The reduce pass only runs when MS_CTX_ENABLE_REDUCE_PRECISION is set in the context.
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_REDUCE_PRECISION)) { | |||||
| for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]); | |||||
| if (TagRaiseReduce(kernel_info_list[info_index], cnode, reduce_map)) { | |||||
| filtered_kernel_info_list.push_back(kernel_info_list[info_index]); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (!filtered_kernel_info_list.empty()) { | |||||
| *precision_reduce = true; | |||||
| } | } | ||||
// Index-map path: narrow kernel_match_datatype_idx, then append the surviving candidates by index.
// NOTE(review): read together with the TagRaiseReduce rows above, this would append to a list that may
// already hold reduce-pass results - presumably these rows belong to the other revision; confirm.
| PrecisionReduce(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatype, | |||||
| &kernel_match_datatype_idx, precision_reduce); | |||||
| std::transform( | |||||
| kernel_match_datatype_idx.begin(), kernel_match_datatype_idx.end(), std::back_inserter(filtered_kernel_info_list), | |||||
| [&](const std::pair<size_t, std::vector<int>> &matched_idx) -> std::shared_ptr<kernel::KernelBuildInfo> { | |||||
| return kernel_info_list[matched_idx.first]; | |||||
| }); | |||||
| return filtered_kernel_info_list; | return filtered_kernel_info_list; | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||