|
|
|
@@ -205,6 +205,37 @@ void SetKernelBuildInfo(const std::vector<std::string> &input_formats, const std |
|
|
|
builder->SetOutputsDeviceType(output_types); |
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node); |
|
|
|
} |
|
|
|
|
|
|
|
void KernelNotSupportException(const AnfNodePtr &kernel_node, const std::vector<TypeId> &input_types, |
|
|
|
const std::vector<TypeId> &infer_output_types) { |
|
|
|
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); |
|
|
|
std::stringstream operator_info; |
|
|
|
operator_info << "Operator[" << kernel_name << "] "; |
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); |
|
|
|
if (input_num > 0) { |
|
|
|
operator_info << " input("; |
|
|
|
for (size_t i = 0; i < input_num; ++i) { |
|
|
|
operator_info << TypeIdLabel(input_types[i]); |
|
|
|
if (i != input_num - 1) { |
|
|
|
operator_info << ","; |
|
|
|
} |
|
|
|
} |
|
|
|
operator_info << ") "; |
|
|
|
} |
|
|
|
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); |
|
|
|
if (output_num > 0) { |
|
|
|
operator_info << "output("; |
|
|
|
for (size_t i = 0; i < output_num; ++i) { |
|
|
|
operator_info << TypeIdLabel(infer_output_types[i]); |
|
|
|
if (i != output_num - 1) { |
|
|
|
operator_info << ","; |
|
|
|
} |
|
|
|
} |
|
|
|
operator_info << ") "; |
|
|
|
} |
|
|
|
operator_info << "is not support."; |
|
|
|
MS_LOG(EXCEPTION) << operator_info.str(); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr, |
|
|
|
const std::vector<KernelAttr> &kernel_attrs, const std::vector<std::string> &input_formats, |
|
|
|
@@ -275,10 +306,17 @@ void SetKernelInfo(const CNodePtr &kernel_node) { |
|
|
|
std::pair<bool, bool> matched = std::make_pair(false, false); |
|
|
|
if (!SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types, |
|
|
|
input_not_cnode_indexes, infer_output_formats, infer_output_types, &matched, true)) { |
|
|
|
if (AnfAlgo::GetCNodeName(kernel_node) == "Cast") { |
|
|
|
KernelNotSupportException(kernel_node, input_types, infer_output_types); |
|
|
|
} |
|
|
|
matched = std::make_pair(false, false); |
|
|
|
SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types, input_not_cnode_indexes, |
|
|
|
infer_output_formats, infer_output_types, &matched, false); |
|
|
|
if (!matched.first) { |
|
|
|
KernelNotSupportException(kernel_node, input_types, infer_output_types); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (selected_kernel_attr.GetInputSize() > 0 && |
|
|
|
(matched.first || input_types.size() == input_not_cnode_indexes.size())) { |
|
|
|
MS_LOG(INFO) << "Input format and dtype is matched"; |
|
|
|
|