@@ -132,14 +132,14 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
 bool ArithmeticCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                  const std::vector<kernel::AddressPtr> & /*workspace*/,
                                  const std::vector<kernel::AddressPtr> &outputs) {
-  if (dtype_ == kNumberTypeInt32) {
+  if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16) {
     LaunchKernel<int>(inputs, outputs);
-  } else if (dtype_ == kNumberTypeFloat32) {
+  } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16 || dtype_ == kNumberTypeFloat64) {
     LaunchKernel<float>(inputs, outputs);
   } else if (dtype_ == kNumberTypeInt64) {
     LaunchKernel<int64_t>(inputs, outputs);
   } else {
-    MS_LOG(EXCEPTION) << "Only support int32, float32, but actual data type is " << TypeIdLabel(dtype_);
+    MS_LOG(EXCEPTION) << "Data type " << TypeIdLabel(dtype_) << " is not supported.";
   }
   return true;
 }
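Note for readers: with this dispatch, int16 inputs now run through the int (int32) kernel instantiation and float16/float64 inputs through the float (float32) one. A minimal sketch of the kind of templated kernel body the dispatch resolves to, assuming a plain element-wise Add (`LaunchAddSketch` is an illustrative stand-in, not the MindSpore `LaunchKernel`):

```cpp
#include <cstddef>

// Illustrative stand-in for the templated LaunchKernel<T> selected above:
// the int instantiation now serves int16/int32, the float one serves fp16/fp32/fp64.
template <typename T>
void LaunchAddSketch(const T *in0, const T *in1, T *out, size_t len) {
  for (size_t i = 0; i < len; ++i) {
    out[i] = in0[i] + in1[i];
  }
}
```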
@@ -56,7 +56,7 @@ bool ArithmeticSelfCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inpu
   } else if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16 || dtype_ == kNumberTypeInt64) {
     LaunchKernel<int>(inputs, outputs);
   } else {
-    MS_LOG(EXCEPTION) << "Only support float32, int32, but actual data type is " << TypeIdLabel(dtype_);
+    MS_LOG(EXCEPTION) << "Data type " << TypeIdLabel(dtype_) << " is not supported.";
   }
   return true;
 }
@@ -20,8 +20,9 @@ namespace kernel {
 void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-  size_t type_size = sizeof(float);
   for (size_t input_index = 0; input_index < input_num; ++input_index) {
+    TypeId type_id = AnfAlgo::GetInputDeviceDataType(kernel_node, input_index);
+    size_t type_size = GetTypeByte(TypeIdToType(type_id));
     std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, input_index);
     size_t tensor_size =
       shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
@@ -29,6 +30,8 @@ void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
   }
   size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
   for (size_t output_index = 0; output_index < output_num; ++output_index) {
+    TypeId type_id = AnfAlgo::GetOutputDeviceDataType(kernel_node, output_index);
+    size_t type_size = GetTypeByte(TypeIdToType(type_id));
     std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(kernel_node, output_index);
     size_t tensor_size =
       shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
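Both loops now derive the buffer size per index from that index's own device dtype instead of a hard-coded sizeof(float). A standalone sketch of the byte-size formula used above (`TensorBytesSketch` is an illustrative name):

```cpp
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// bytes = type_size * product(shape); an empty shape is treated as a scalar,
// exactly as in the shape.empty() branch above.
size_t TensorBytesSketch(const std::vector<size_t> &shape, size_t type_size) {
  return shape.empty() ? type_size
                       : std::accumulate(shape.begin(), shape.end(), type_size,
                                         std::multiplies<size_t>());
}
```

For example, shape {2, 3} with a 4-byte dtype yields 24 bytes.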
@@ -123,14 +123,14 @@ void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
 bool EltWiseGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                   const std::vector<kernel::AddressPtr> & /*workspace*/,
                                   const std::vector<kernel::AddressPtr> &outputs) {
-  if (dtype_ == kNumberTypeInt32) {
+  if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16) {
     LaunchKernel<int>(inputs, outputs);
-  } else if (dtype_ == kNumberTypeFloat32) {
+  } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16 || dtype_ == kNumberTypeFloat64) {
     LaunchKernel<float>(inputs, outputs);
   } else if (dtype_ == kNumberTypeInt64) {
     LaunchKernel<int64_t>(inputs, outputs);
   } else {
-    MS_LOG(EXCEPTION) << "Only support int32, float32, but actual data type is " << TypeIdLabel(dtype_);
+    MS_LOG(EXCEPTION) << "Data type " << TypeIdLabel(dtype_) << " is not supported.";
   }
   return true;
 }
@@ -54,7 +54,10 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph
   }
   auto tensor = node_value->cast<TensorPtr>();
   MS_EXCEPTION_IF_NULL(tensor);
-  TypeId output_type_id = AnfAlgo::GetOutputInferDataType(item_node, 0);
+  TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item_node, 0);
+  if (output_type_id == kTypeUnknown) {
+    output_type_id = AnfAlgo::GetOutputInferDataType(item_node, 0);
+  }
   size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
   ShapeVector data_shape = tensor->shape();
   size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>());
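The new lookup prefers the device dtype assigned during kernel selection and only falls back to the inferred dtype when none has been set. A minimal sketch of that fallback as a helper (the helper name is illustrative; the two AnfAlgo getters are the ones used in the hunk above):

```cpp
// Illustrative wrapper around the fallback introduced above.
TypeId ResolveValueNodeOutputType(const AnfNodePtr &item_node) {
  TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item_node, 0);
  if (output_type_id == kTypeUnknown) {
    output_type_id = AnfAlgo::GetOutputInferDataType(item_node, 0);
  }
  return output_type_id;
}
```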
@@ -205,6 +205,37 @@ void SetKernelBuildInfo(const std::vector<std::string> &input_formats, const std
   builder->SetOutputsDeviceType(output_types);
   AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node);
 }
+
+void KernelNotSupportException(const AnfNodePtr &kernel_node, const std::vector<TypeId> &input_types,
+                               const std::vector<TypeId> &infer_output_types) {
+  std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
+  std::stringstream operator_info;
+  operator_info << "Operator[" << kernel_name << "] ";
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  if (input_num > 0) {
+    operator_info << "input(";
+    for (size_t i = 0; i < input_num; ++i) {
+      operator_info << TypeIdLabel(input_types[i]);
+      if (i != input_num - 1) {
+        operator_info << ",";
+      }
+    }
+    operator_info << ") ";
+  }
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  if (output_num > 0) {
+    operator_info << "output(";
+    for (size_t i = 0; i < output_num; ++i) {
+      operator_info << TypeIdLabel(infer_output_types[i]);
+      if (i != output_num - 1) {
+        operator_info << ",";
+      }
+    }
+    operator_info << ") ";
+  }
+  operator_info << "is not supported.";
+  MS_LOG(EXCEPTION) << operator_info.str();
+}
 }  // namespace
 bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
                   const std::vector<KernelAttr> &kernel_attrs, const std::vector<std::string> &input_formats,
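With KernelNotSupportException in place, an unsupported node reports its operator name together with its input dtypes and inferred output dtypes in a single line, roughly of the form "Operator[Cast] input(<src dtype>) output(<dst dtype>) is not supported.", where the placeholders stand in for whatever TypeIdLabel prints.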
@@ -275,10 +306,17 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
   std::pair<bool, bool> matched = std::make_pair(false, false);
   if (!SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types,
                     input_not_cnode_indexes, infer_output_formats, infer_output_types, &matched, true)) {
+    if (AnfAlgo::GetCNodeName(kernel_node) == "Cast") {
+      KernelNotSupportException(kernel_node, input_types, infer_output_types);
+    }
+    matched = std::make_pair(false, false);
+    SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types, input_not_cnode_indexes,
+                 infer_output_formats, infer_output_types, &matched, false);
+    if (!matched.first) {
+      KernelNotSupportException(kernel_node, input_types, infer_output_types);
+    }
   }
   if (selected_kernel_attr.GetInputSize() > 0 &&
       (matched.first || input_types.size() == input_not_cnode_indexes.size())) {
     MS_LOG(INFO) << "Input format and dtype is matched";
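Net effect of this last hunk: the strict dtype/format pass (SelectKernel(..., true)) runs first; a Cast that fails it raises KernelNotSupportException immediately, any other operator gets the relaxed pass (SelectKernel(..., false)) as a fallback, and only if matched.first is still false afterwards is the exception raised.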