Browse Source

fix cpu kernel select

tags/v1.1.0
baihuawei 5 years ago
parent
commit
1ca333a3e5
6 changed files with 53 additions and 9 deletions
  1. +3
    -3
      mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc
  2. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
  3. +4
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc
  4. +3
    -3
      mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc
  5. +4
    -1
      mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
  6. +38
    -0
      mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc

+ 3
- 3
mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc View File

@@ -132,14 +132,14 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
bool ArithmeticCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, bool ArithmeticCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/, const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) { const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeInt32) {
if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16) {
LaunchKernel<int>(inputs, outputs); LaunchKernel<int>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
} else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16 || dtype_ == kNumberTypeFloat64) {
LaunchKernel<float>(inputs, outputs); LaunchKernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt64) { } else if (dtype_ == kNumberTypeInt64) {
LaunchKernel<int64_t>(inputs, outputs); LaunchKernel<int64_t>(inputs, outputs);
} else { } else {
MS_LOG(EXCEPTION) << "Only support int32, float32, but actual data type is " << TypeIdLabel(dtype_);
MS_LOG(EXCEPTION) << "Data type is " << TypeIdLabel(dtype_) << "is not support.";
} }
return true; return true;
} }


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc View File

@@ -56,7 +56,7 @@ bool ArithmeticSelfCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inpu
} else if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16 || dtype_ == kNumberTypeInt64) { } else if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16 || dtype_ == kNumberTypeInt64) {
LaunchKernel<int>(inputs, outputs); LaunchKernel<int>(inputs, outputs);
} else { } else {
MS_LOG(EXCEPTION) << "Only support float32, int32, but actual data type is " << TypeIdLabel(dtype_);
MS_LOG(EXCEPTION) << "Data type is " << TypeIdLabel(dtype_) << "is not support.";
} }
return true; return true;
} }


+ 4
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc View File

@@ -20,8 +20,9 @@ namespace kernel {
void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
size_t type_size = sizeof(float);
for (size_t input_index = 0; input_index < input_num; ++input_index) { for (size_t input_index = 0; input_index < input_num; ++input_index) {
TypeId type_id = AnfAlgo::GetInputDeviceDataType(kernel_node, input_index);
size_t type_size = GetTypeByte(TypeIdToType(type_id));
std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, input_index); std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, input_index);
size_t tensor_size = size_t tensor_size =
shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>()); shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
@@ -29,6 +30,8 @@ void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
} }
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) { for (size_t output_index = 0; output_index < output_num; ++output_index) {
TypeId type_id = AnfAlgo::GetOutputDeviceDataType(kernel_node, output_index);
size_t type_size = GetTypeByte(TypeIdToType(type_id));
std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(kernel_node, output_index); std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(kernel_node, output_index);
size_t tensor_size = size_t tensor_size =
shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>()); shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());


+ 3
- 3
mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc View File

@@ -123,14 +123,14 @@ void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
bool EltWiseGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, bool EltWiseGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/, const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) { const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeInt32) {
if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16) {
LaunchKernel<int>(inputs, outputs); LaunchKernel<int>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
} else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16 || dtype_ == kNumberTypeFloat64) {
LaunchKernel<float>(inputs, outputs); LaunchKernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt64) { } else if (dtype_ == kNumberTypeInt64) {
LaunchKernel<int64_t>(inputs, outputs); LaunchKernel<int64_t>(inputs, outputs);
} else { } else {
MS_LOG(EXCEPTION) << "Only support int32, float32, but actual data type is " << TypeIdLabel(dtype_);
MS_LOG(EXCEPTION) << "Data type is " << TypeIdLabel(dtype_) << "is not support.";
} }
return true; return true;
} }


+ 4
- 1
mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc View File

@@ -54,7 +54,10 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph
} }
auto tensor = node_value->cast<TensorPtr>(); auto tensor = node_value->cast<TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor); MS_EXCEPTION_IF_NULL(tensor);
TypeId output_type_id = AnfAlgo::GetOutputInferDataType(item_node, 0);
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item_node, 0);
if (output_type_id == kTypeUnknown) {
output_type_id = AnfAlgo::GetOutputInferDataType(item_node, 0);
}
size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
ShapeVector data_shape = tensor->shape(); ShapeVector data_shape = tensor->shape();
size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>()); size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>());


+ 38
- 0
mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc View File

@@ -205,6 +205,37 @@ void SetKernelBuildInfo(const std::vector<std::string> &input_formats, const std
builder->SetOutputsDeviceType(output_types); builder->SetOutputsDeviceType(output_types);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node);
} }

void KernelNotSupportException(const AnfNodePtr &kernel_node, const std::vector<TypeId> &input_types,
const std::vector<TypeId> &infer_output_types) {
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
std::stringstream operator_info;
operator_info << "Operator[" << kernel_name << "] ";
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num > 0) {
operator_info << " input(";
for (size_t i = 0; i < input_num; ++i) {
operator_info << TypeIdLabel(input_types[i]);
if (i != input_num - 1) {
operator_info << ",";
}
}
operator_info << ") ";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num > 0) {
operator_info << "output(";
for (size_t i = 0; i < output_num; ++i) {
operator_info << TypeIdLabel(infer_output_types[i]);
if (i != output_num - 1) {
operator_info << ",";
}
}
operator_info << ") ";
}
operator_info << "is not support.";
MS_LOG(EXCEPTION) << operator_info.str();
}
} // namespace } // namespace
bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr, bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
const std::vector<KernelAttr> &kernel_attrs, const std::vector<std::string> &input_formats, const std::vector<KernelAttr> &kernel_attrs, const std::vector<std::string> &input_formats,
@@ -275,10 +306,17 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
std::pair<bool, bool> matched = std::make_pair(false, false); std::pair<bool, bool> matched = std::make_pair(false, false);
if (!SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types, if (!SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types,
input_not_cnode_indexes, infer_output_formats, infer_output_types, &matched, true)) { input_not_cnode_indexes, infer_output_formats, infer_output_types, &matched, true)) {
if (AnfAlgo::GetCNodeName(kernel_node) == "Cast") {
KernelNotSupportException(kernel_node, input_types, infer_output_types);
}
matched = std::make_pair(false, false); matched = std::make_pair(false, false);
SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types, input_not_cnode_indexes, SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types, input_not_cnode_indexes,
infer_output_formats, infer_output_types, &matched, false); infer_output_formats, infer_output_types, &matched, false);
if (!matched.first) {
KernelNotSupportException(kernel_node, input_types, infer_output_types);
}
} }

if (selected_kernel_attr.GetInputSize() > 0 && if (selected_kernel_attr.GetInputSize() > 0 &&
(matched.first || input_types.size() == input_not_cnode_indexes.size())) { (matched.first || input_types.size() == input_not_cnode_indexes.size())) {
MS_LOG(INFO) << "Input format and dtype is matched"; MS_LOG(INFO) << "Input format and dtype is matched";


Loading…
Cancel
Save