|
|
|
@@ -342,7 +342,7 @@ void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelB |
|
|
|
std::vector<int> *node_mix_precision_datatype_index) { |
|
|
|
MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); |
|
|
|
bool add_node_datatype_flag = false; |
|
|
|
if (node_mix_precision_datatype->size() == 0) { |
|
|
|
if (node_mix_precision_datatype->empty()) { |
|
|
|
add_node_datatype_flag = true; |
|
|
|
} |
|
|
|
for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { |
|
|
|
@@ -464,8 +464,9 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
void SelectKernelInfo(const CNodePtr &kernel_node) { |
|
|
|
int SelectKernelInfo(const CNodePtr &kernel_node) { |
|
|
|
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; |
|
|
|
int status = kStatusAllMatched; |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node); |
|
|
|
bool precision_reduce = false; |
|
|
|
std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr; |
|
|
|
@@ -486,11 +487,13 @@ void SelectKernelInfo(const CNodePtr &kernel_node) { |
|
|
|
<< "] cannot find valid kernel info, not supported the type" << buffer.str(); |
|
|
|
} else { |
|
|
|
PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce); |
|
|
|
status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision; |
|
|
|
} |
|
|
|
} |
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get()); |
|
|
|
// Set format and data type for input tensor. |
|
|
|
SetTensorDeviceInfo(*selected_kernel_info, kernel_node); |
|
|
|
return status; |
|
|
|
} |
|
|
|
|
|
|
|
bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, |
|
|
|
|