|
|
|
@@ -224,10 +224,10 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype( |
|
|
|
|
|
|
|
bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info, const CNodePtr &cnode, |
|
|
|
const std::map<TypeId, TypeId> &type_map) { |
|
|
|
// filte kernel info that unsupported raise or reduce datatype |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_build_info); |
|
|
|
size_t flag_in = 0; |
|
|
|
size_t flag_out = 0; |
|
|
|
bool flag = false; |
|
|
|
for (size_t input_index = 0; input_index < kernel_build_info->GetInputNum(); ++input_index) { |
|
|
|
auto in_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); |
|
|
|
auto device_dtype = kernel_build_info->GetInputDeviceType(input_index); |
|
|
|
@@ -235,11 +235,17 @@ bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build |
|
|
|
device_dtype = kNumberTypeFloat32; |
|
|
|
} |
|
|
|
auto iter = type_map.find(in_dtype); |
|
|
|
// if infer dtype node in type_map and the infer dtype not equal kernel info dtype, return false |
|
|
|
if (iter == type_map.end() && in_dtype != device_dtype) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
// infer dtype in type_map, but can not find dst dtype that supported raise or reduce, |
|
|
|
// or infer dtype not equal kernel info dtype, return false |
|
|
|
if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (iter == type_map.end() && in_dtype != device_dtype) { |
|
|
|
flag_in += 1; |
|
|
|
if (in_dtype == kNumberTypeInt64 && device_dtype == kNumberTypeInt32) { |
|
|
|
flag = true; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -250,15 +256,22 @@ bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build |
|
|
|
device_dtype = kNumberTypeFloat32; |
|
|
|
} |
|
|
|
auto iter = type_map.find(in_dtype); |
|
|
|
// if infer dtype node in type_map and the infer dtype not equal kernel info dtype, return false |
|
|
|
if (iter == type_map.end() && in_dtype != device_dtype) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
// infer dtype in type_map, but can not find dst dtype that supported raise or reduce, |
|
|
|
// or infer dtype not equal kernel info dtype, return false |
|
|
|
if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (iter == type_map.end() && in_dtype != device_dtype) { |
|
|
|
flag_out += 1; |
|
|
|
if (in_dtype == kNumberTypeInt64 && device_dtype == kNumberTypeInt32) { |
|
|
|
flag = true; |
|
|
|
} |
|
|
|
} |
|
|
|
if (flag_in == kernel_build_info->GetInputNum() || flag_out == kernel_build_info->GetOutputNum()) { |
|
|
|
return false; |
|
|
|
if (flag) { |
|
|
|
auto node_name = AnfAlgo::GetCNodeName(cnode); |
|
|
|
MS_LOG(WARNING) << "node:[" << node_name << "]reduce precision from int64 to int32"; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|