From f4446c30c810e9ac0381e5d10b3ddab0c6fbcd19 Mon Sep 17 00:00:00 2001 From: liubuyu Date: Thu, 22 Oct 2020 17:52:26 +0800 Subject: [PATCH] raise reduce precision bug fix --- .../device/ascend/kernel_select_ascend.cc | 29 ++++++++++++++----- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc index faf4115a92..fe6056ef0d 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc @@ -224,10 +224,10 @@ std::vector> FilteredKernelInfoByDtype( bool TagRaiseReduce(const std::shared_ptr &kernel_build_info, const CNodePtr &cnode, const std::map &type_map) { + // filte kernel info that unsupported raise or reduce datatype MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(kernel_build_info); - size_t flag_in = 0; - size_t flag_out = 0; + bool flag = false; for (size_t input_index = 0; input_index < kernel_build_info->GetInputNum(); ++input_index) { auto in_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); auto device_dtype = kernel_build_info->GetInputDeviceType(input_index); @@ -235,11 +235,17 @@ bool TagRaiseReduce(const std::shared_ptr &kernel_build device_dtype = kNumberTypeFloat32; } auto iter = type_map.find(in_dtype); + // if infer dtype node in type_map and the infer dtype not equal kernel info dtype, return false + if (iter == type_map.end() && in_dtype != device_dtype) { + return false; + } + // infer dtype in type_map, but can not find dst dtype that supported raise or reduce, + // or infer dtype not equal kernel info dtype, return false if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) { return false; } - if (iter == type_map.end() && in_dtype != device_dtype) { - flag_in += 1; + if (in_dtype == kNumberTypeInt64 && device_dtype == kNumberTypeInt32) { + flag = true; } } @@ -250,15 +256,22 @@ bool TagRaiseReduce(const std::shared_ptr &kernel_build device_dtype = kNumberTypeFloat32; } auto iter = type_map.find(in_dtype); + // if infer dtype node in type_map and the infer dtype not equal kernel info dtype, return false + if (iter == type_map.end() && in_dtype != device_dtype) { + return false; + } + // infer dtype in type_map, but can not find dst dtype that supported raise or reduce, + // or infer dtype not equal kernel info dtype, return false if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) { return false; } - if (iter == type_map.end() && in_dtype != device_dtype) { - flag_out += 1; + if (in_dtype == kNumberTypeInt64 && device_dtype == kNumberTypeInt32) { + flag = true; } } - if (flag_in == kernel_build_info->GetInputNum() || flag_out == kernel_build_info->GetOutputNum()) { - return false; + if (flag) { + auto node_name = AnfAlgo::GetCNodeName(cnode); + MS_LOG(WARNING) << "node:[" << node_name << "]reduce precision from int64 to int32"; } return true; }