|
|
|
@@ -54,6 +54,9 @@ void CPUKernelFactory::SetKernelAttrs(const std::shared_ptr<kernel::OpInfo> op_i |
|
|
|
std::vector<KernelAttr> *kernel_attrs) { |
|
|
|
auto inputs_ptr = op_info->inputs_ptr(); |
|
|
|
auto outputs_ptr = op_info->outputs_ptr(); |
|
|
|
if (inputs_ptr.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "op " << op_info->op_name() << " input size is zero."; |
|
|
|
} |
|
|
|
auto first_input_dtypes = inputs_ptr[0]->dtypes(); |
|
|
|
auto input_formats = inputs_ptr[0]->formats(); |
|
|
|
|
|
|
|
@@ -82,8 +85,7 @@ void CPUKernelFactory::UpdateKernelAttrs(const std::string &kernel_name, const s |
|
|
|
std::vector<std::pair<KernelAttr, CPUKernelCreator>> attr_creators(attr_size); |
|
|
|
auto iter = name_to_attr_creator_.find(kernel_name); |
|
|
|
if (iter == name_to_attr_creator_.end()) { |
|
|
|
MS_LOG(ERROR) << "CPUKernelFactory has not registered operator: " << kernel_name; |
|
|
|
return; |
|
|
|
MS_LOG(EXCEPTION) << "CPUKernelFactory has not registered operator: " << kernel_name; |
|
|
|
} |
|
|
|
|
|
|
|
if (attr_size <= iter->second.size()) { |
|
|
|
@@ -113,7 +115,7 @@ std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string & |
|
|
|
if (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0) { |
|
|
|
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(kernel_name, kernel::OpImplyType::kCPU); |
|
|
|
if (op_info_ptr == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Not find op[" << kernel_name << "] in cpu"; |
|
|
|
MS_LOG(EXCEPTION) << "Not find op[" << kernel_name << "] in cpu"; |
|
|
|
} |
|
|
|
kernel_attrs.clear(); |
|
|
|
SetKernelAttrs(op_info_ptr, &kernel_attrs); |
|
|
|
@@ -152,8 +154,7 @@ std::vector<KernelAttr> CPUKernelFactory::GetSupportedKernelAttrList(const std:: |
|
|
|
std::vector<KernelAttr> result; |
|
|
|
auto iter = name_to_attr_creator_.find(kernel_name); |
|
|
|
if (iter == name_to_attr_creator_.end()) { |
|
|
|
MS_LOG(WARNING) << "Not registered CPU kernel: op[" << kernel_name << "]!"; |
|
|
|
return result; |
|
|
|
MS_LOG(EXCEPTION) << "Not registered CPU kernel: op[" << kernel_name << "]!"; |
|
|
|
} |
|
|
|
auto creators = iter->second; |
|
|
|
for (size_t index = 0; index < creators.size(); ++index) { |
|
|
|
|