| @@ -132,14 +132,14 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| bool ArithmeticCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | bool ArithmeticCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | ||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | const std::vector<kernel::AddressPtr> & /*workspace*/, | ||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| if (dtype_ == kNumberTypeInt32) { | |||||
| if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16) { | |||||
| LaunchKernel<int>(inputs, outputs); | LaunchKernel<int>(inputs, outputs); | ||||
| } else if (dtype_ == kNumberTypeFloat32) { | |||||
| } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16 || dtype_ == kNumberTypeFloat64) { | |||||
| LaunchKernel<float>(inputs, outputs); | LaunchKernel<float>(inputs, outputs); | ||||
| } else if (dtype_ == kNumberTypeInt64) { | } else if (dtype_ == kNumberTypeInt64) { | ||||
| LaunchKernel<int64_t>(inputs, outputs); | LaunchKernel<int64_t>(inputs, outputs); | ||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Only support int32, float32, but actual data type is " << TypeIdLabel(dtype_); | |||||
| MS_LOG(EXCEPTION) << "Data type is " << TypeIdLabel(dtype_) << "is not support."; | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -56,7 +56,7 @@ bool ArithmeticSelfCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inpu | |||||
| } else if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16 || dtype_ == kNumberTypeInt64) { | } else if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16 || dtype_ == kNumberTypeInt64) { | ||||
| LaunchKernel<int>(inputs, outputs); | LaunchKernel<int>(inputs, outputs); | ||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Only support float32, int32, but actual data type is " << TypeIdLabel(dtype_); | |||||
| MS_LOG(EXCEPTION) << "Data type is " << TypeIdLabel(dtype_) << "is not support."; | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -20,8 +20,9 @@ namespace kernel { | |||||
| void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { | void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | ||||
| size_t type_size = sizeof(float); | |||||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | for (size_t input_index = 0; input_index < input_num; ++input_index) { | ||||
| TypeId type_id = AnfAlgo::GetInputDeviceDataType(kernel_node, input_index); | |||||
| size_t type_size = GetTypeByte(TypeIdToType(type_id)); | |||||
| std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, input_index); | std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, input_index); | ||||
| size_t tensor_size = | size_t tensor_size = | ||||
| shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>()); | shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>()); | ||||
| @@ -29,6 +30,8 @@ void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { | |||||
| } | } | ||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | ||||
| for (size_t output_index = 0; output_index < output_num; ++output_index) { | for (size_t output_index = 0; output_index < output_num; ++output_index) { | ||||
| TypeId type_id = AnfAlgo::GetOutputDeviceDataType(kernel_node, output_index); | |||||
| size_t type_size = GetTypeByte(TypeIdToType(type_id)); | |||||
| std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(kernel_node, output_index); | std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(kernel_node, output_index); | ||||
| size_t tensor_size = | size_t tensor_size = | ||||
| shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>()); | shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>()); | ||||
| @@ -123,14 +123,14 @@ void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| bool EltWiseGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | bool EltWiseGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | ||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | const std::vector<kernel::AddressPtr> & /*workspace*/, | ||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| if (dtype_ == kNumberTypeInt32) { | |||||
| if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16) { | |||||
| LaunchKernel<int>(inputs, outputs); | LaunchKernel<int>(inputs, outputs); | ||||
| } else if (dtype_ == kNumberTypeFloat32) { | |||||
| } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16 || dtype_ == kNumberTypeFloat64) { | |||||
| LaunchKernel<float>(inputs, outputs); | LaunchKernel<float>(inputs, outputs); | ||||
| } else if (dtype_ == kNumberTypeInt64) { | } else if (dtype_ == kNumberTypeInt64) { | ||||
| LaunchKernel<int64_t>(inputs, outputs); | LaunchKernel<int64_t>(inputs, outputs); | ||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Only support int32, float32, but actual data type is " << TypeIdLabel(dtype_); | |||||
| MS_LOG(EXCEPTION) << "Data type is " << TypeIdLabel(dtype_) << "is not support."; | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -54,7 +54,10 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph | |||||
| } | } | ||||
| auto tensor = node_value->cast<TensorPtr>(); | auto tensor = node_value->cast<TensorPtr>(); | ||||
| MS_EXCEPTION_IF_NULL(tensor); | MS_EXCEPTION_IF_NULL(tensor); | ||||
| TypeId output_type_id = AnfAlgo::GetOutputInferDataType(item_node, 0); | |||||
| TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item_node, 0); | |||||
| if (output_type_id == kTypeUnknown) { | |||||
| output_type_id = AnfAlgo::GetOutputInferDataType(item_node, 0); | |||||
| } | |||||
| size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); | size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); | ||||
| ShapeVector data_shape = tensor->shape(); | ShapeVector data_shape = tensor->shape(); | ||||
| size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>()); | size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>()); | ||||
| @@ -205,6 +205,37 @@ void SetKernelBuildInfo(const std::vector<std::string> &input_formats, const std | |||||
| builder->SetOutputsDeviceType(output_types); | builder->SetOutputsDeviceType(output_types); | ||||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node); | AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node); | ||||
| } | } | ||||
| void KernelNotSupportException(const AnfNodePtr &kernel_node, const std::vector<TypeId> &input_types, | |||||
| const std::vector<TypeId> &infer_output_types) { | |||||
| std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||||
| std::stringstream operator_info; | |||||
| operator_info << "Operator[" << kernel_name << "] "; | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num > 0) { | |||||
| operator_info << " input("; | |||||
| for (size_t i = 0; i < input_num; ++i) { | |||||
| operator_info << TypeIdLabel(input_types[i]); | |||||
| if (i != input_num - 1) { | |||||
| operator_info << ","; | |||||
| } | |||||
| } | |||||
| operator_info << ") "; | |||||
| } | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||||
| if (output_num > 0) { | |||||
| operator_info << "output("; | |||||
| for (size_t i = 0; i < output_num; ++i) { | |||||
| operator_info << TypeIdLabel(infer_output_types[i]); | |||||
| if (i != output_num - 1) { | |||||
| operator_info << ","; | |||||
| } | |||||
| } | |||||
| operator_info << ") "; | |||||
| } | |||||
| operator_info << "is not support."; | |||||
| MS_LOG(EXCEPTION) << operator_info.str(); | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr, | bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr, | ||||
| const std::vector<KernelAttr> &kernel_attrs, const std::vector<std::string> &input_formats, | const std::vector<KernelAttr> &kernel_attrs, const std::vector<std::string> &input_formats, | ||||
| @@ -275,10 +306,17 @@ void SetKernelInfo(const CNodePtr &kernel_node) { | |||||
| std::pair<bool, bool> matched = std::make_pair(false, false); | std::pair<bool, bool> matched = std::make_pair(false, false); | ||||
| if (!SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types, | if (!SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types, | ||||
| input_not_cnode_indexes, infer_output_formats, infer_output_types, &matched, true)) { | input_not_cnode_indexes, infer_output_formats, infer_output_types, &matched, true)) { | ||||
| if (AnfAlgo::GetCNodeName(kernel_node) == "Cast") { | |||||
| KernelNotSupportException(kernel_node, input_types, infer_output_types); | |||||
| } | |||||
| matched = std::make_pair(false, false); | matched = std::make_pair(false, false); | ||||
| SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types, input_not_cnode_indexes, | SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types, input_not_cnode_indexes, | ||||
| infer_output_formats, infer_output_types, &matched, false); | infer_output_formats, infer_output_types, &matched, false); | ||||
| if (!matched.first) { | |||||
| KernelNotSupportException(kernel_node, input_types, infer_output_types); | |||||
| } | |||||
| } | } | ||||
| if (selected_kernel_attr.GetInputSize() > 0 && | if (selected_kernel_attr.GetInputSize() > 0 && | ||||
| (matched.first || input_types.size() == input_not_cnode_indexes.size())) { | (matched.first || input_types.size() == input_not_cnode_indexes.size())) { | ||||
| MS_LOG(INFO) << "Input format and dtype is matched"; | MS_LOG(INFO) << "Input format and dtype is matched"; | ||||