Browse Source

support dynamic inputs & outputs

tags/v1.4.0
wilfChen 4 years ago
parent
commit
d6fffdad6e
8 changed files with 500 additions and 1 deletion
  1. +1
    -0
      mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt
  2. +6
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.cc
  3. +295
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/pyfunc/py_func_cpu_kernel.cc
  4. +76
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/pyfunc/py_func_cpu_kernel.h
  5. +36
    -0
      mindspore/ccsrc/runtime/device/cpu/cpu_common.h
  6. +73
    -1
      mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc
  7. +2
    -0
      mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.h
  8. +11
    -0
      tests/ut/cpp/stub/runtime/kernel_select_cpu.cc

+ 1
- 0
mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt View File

@@ -35,6 +35,7 @@ if(ENABLE_CPU)
"cpu/fl/*.cc"
"cpu/ps/*.cc"
"cpu/quantum/*.cc"
"cpu/pyfunc/*.cc"
)

if(NOT ENABLE_MPI)


+ 6
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.cc View File

@@ -21,6 +21,7 @@
#include <string>

#include "runtime/device/kernel_info.h"
#include "runtime/device/cpu/kernel_select_cpu.h"

namespace mindspore {
namespace kernel {
@@ -111,6 +112,11 @@ std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &
MS_LOG(INFO) << "Not registered CPU kernel: op[" << kernel_name << "]!";
return std::make_pair(false, 0);
}

if (device::cpu::IsDynamicParamKernel(kernel_name)) {
return std::make_pair(true, 0);
}

auto kernel_attrs = GetSupportedKernelAttrList(kernel_name);
if (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0) {
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(kernel_name, kernel::OpImplyType::kCPU);


+ 295
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/pyfunc/py_func_cpu_kernel.cc View File

@@ -0,0 +1,295 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/pyfunc/py_func_cpu_kernel.h"

#include <memory>
#include <vector>
#include "Eigen/Core"
#include "Eigen/src/Core/arch/CUDA/Half.h"
#include "abstract/utils.h"
#include "runtime/device/cpu/cpu_common.h"
#include "pybind_api/ir/tensor_py.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
namespace {
// Wrap a single scalar stored in raw host memory into the corresponding
// Python object (bool_/int_/float_), driven by the MindSpore type id.
// Throws via MS_LOG(EXCEPTION) for unsupported type ids.
py::object RawMemoryToScalar(const void *data, const TypePtr &type) {
  // Reinterpret the buffer as T and wrap the value into a Python int.
  const auto to_py_int = [data](auto sample) {
    using T = decltype(sample);
    return py::int_(*static_cast<const T *>(data));
  };
  switch (type->type_id()) {
    case kNumberTypeBool:
      return py::bool_(*static_cast<const bool *>(data));
    case kNumberTypeInt8:
      return to_py_int(int8_t{});
    case kNumberTypeUInt8:
      return to_py_int(uint8_t{});
    case kNumberTypeInt16:
      return to_py_int(int16_t{});
    case kNumberTypeUInt16:
      return to_py_int(uint16_t{});
    case kNumberTypeInt32:
      return to_py_int(int32_t{});
    case kNumberTypeUInt32:
      return to_py_int(uint32_t{});
    case kNumberTypeInt64:
      return to_py_int(int64_t{});
    case kNumberTypeUInt64:
      return to_py_int(uint64_t{});
    case kNumberTypeFloat16: {
      // float16 is stored as its raw 16-bit pattern; widen via Eigen half.
      const Eigen::half_impl::__half_raw data_half(*static_cast<const uint16_t *>(data));
      return py::float_(Eigen::half_impl::half_to_float(data_half));
    }
    case kNumberTypeFloat32:
      return py::float_(*static_cast<const float *>(data));
    case kNumberTypeFloat64:
      return py::float_(*static_cast<const double *>(data));
    default:
      MS_LOG(EXCEPTION) << "Type: " << type->type_id() << " not supported.";
  }
}

// Convert a Python scalar `obj` to the MindSpore data type `type` and copy the
// resulting bytes into the kernel output `address`.
// Fixes: the UInt16 branch previously cast to uint8_t (truncating the value and
// copying the wrong width); the Float64 branch stored py::cast<double> into a
// `float` local and then copied sizeof(double) bytes from it (out-of-bounds
// read and wrong value). Both now use the correct C++ type.
void ScalarToRawMemory(const py::object &obj, const TypePtr &type, const AddressPtr &address) {
  switch (type->type_id()) {
    case kNumberTypeBool: {
      bool data = py::cast<bool>(obj);
      CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, &data, sizeof(bool)), EOK, "memcpy failed.");
      return;
    }
    // ref: pybind11-src/include/pybind11/pytypes.h
    // py::int_ convert py::object to `long`, `unsigned long`, `long long`, `unsigned long long` with Python API
    // according to typename T, and then convert to target data type with C style cast.
    case kNumberTypeInt8: {
      int8_t data = py::cast<int8_t>(obj);
      CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, &data, sizeof(int8_t)), EOK, "memcpy failed.");
      return;
    }
    case kNumberTypeUInt8: {
      uint8_t data = py::cast<uint8_t>(obj);
      CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, &data, sizeof(uint8_t)), EOK, "memcpy failed.");
      return;
    }
    case kNumberTypeInt16: {
      int16_t data = py::cast<int16_t>(obj);
      CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, &data, sizeof(int16_t)), EOK, "memcpy failed.");
      return;
    }
    case kNumberTypeUInt16: {
      // Fix: was uint8_t, which truncated values > 255 and copied 1 byte.
      uint16_t data = py::cast<uint16_t>(obj);
      CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, &data, sizeof(uint16_t)), EOK, "memcpy failed.");
      return;
    }
    case kNumberTypeInt32: {
      int32_t data = py::cast<int32_t>(obj);
      CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, &data, sizeof(int32_t)), EOK, "memcpy failed.");
      return;
    }
    case kNumberTypeUInt32: {
      uint32_t data = py::cast<uint32_t>(obj);
      CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, &data, sizeof(uint32_t)), EOK, "memcpy failed.");
      return;
    }
    case kNumberTypeInt64: {
      int64_t data = py::cast<int64_t>(obj);
      CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, &data, sizeof(int64_t)), EOK, "memcpy failed.");
      return;
    }
    case kNumberTypeUInt64: {
      uint64_t data = py::cast<uint64_t>(obj);
      CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, &data, sizeof(uint64_t)), EOK, "memcpy failed.");
      return;
    }
    case kNumberTypeFloat16: {
      // Round-to-nearest-even conversion to half, then copy the raw 16 bits.
      float data = py::cast<float>(obj);
      Eigen::half_impl::__half_raw data_half = Eigen::half_impl::float_to_half_rtne(data);
      CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, &data_half.x, sizeof(data_half.x)), EOK,
                            "memcpy failed.");
      return;
    }
    case kNumberTypeFloat32: {
      float data = py::cast<float>(obj);
      CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, &data, sizeof(float)), EOK, "memcpy failed.");
      return;
    }
    case kNumberTypeFloat64: {
      // Fix: was `float data = py::cast<double>(obj);` which truncated the
      // value and copied sizeof(double) bytes from a 4-byte local.
      double data = py::cast<double>(obj);
      CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, &data, sizeof(double)), EOK, "memcpy failed.");
      return;
    }
    default:
      MS_LOG(EXCEPTION) << "Type: " << type->type_id() << " not supported.";
  }
}

// Copy the contents of a numpy array into the raw kernel output `address`.
// The fast path copies a C-contiguous array directly; otherwise the buffer is
// first materialized as a row-major contiguous copy.
// Fix: py::buffer_info::size is the ELEMENT count, not a byte count; the
// original fast path copied only `size` bytes, truncating any array whose
// itemsize > 1. Copy size * itemsize bytes instead.
void ArrayToRawMemory(const py::array &array, const AddressPtr &address) {
  if (static_cast<unsigned int>(array.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_) {
    const py::buffer_info &buf_info = array.request();
    const size_t nbytes = static_cast<size_t>(buf_info.size) * static_cast<size_t>(buf_info.itemsize);
    CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, buf_info.ptr, nbytes), EOK, "memcpy failed.");
  } else {
    // Transform numpy array to row major buffer.
    Py_buffer pybuf;
    if (PyObject_GetBuffer(array.ptr(), &pybuf, PyBUF_ANY_CONTIGUOUS)) {
      MS_LOG(EXCEPTION) << "Failed to get buffer from the input!";
    }

    auto buffer = std::make_unique<char[]>(pybuf.len);
    if (PyBuffer_ToContiguous(buffer.get(), &pybuf, pybuf.len, 'C')) {
      PyBuffer_Release(&pybuf);
      MS_LOG(EXCEPTION) << "Can't copy numpy.ndarray to a contiguous buffer.";
    }
    // pybuf.len is already in bytes for the flat contiguous copy.
    const Py_ssize_t len = pybuf.len;
    PyBuffer_Release(&pybuf);
    CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, buffer.get(), len), EOK, "memcpy failed.");
  }
}

// Dispatch a Python object to the proper raw-memory writer based on whether it
// is a scalar or a numpy array. Unsupported object types raise an exception.
void ObjectToRawMemory(const py::object &object, const PythonOjectType &object_type, const TypePtr &data_type,
                       const AddressPtr &address) {
  if (object_type == PythonOjectType::kScalar) {
    ScalarToRawMemory(object, data_type, address);
  } else if (object_type == PythonOjectType::kNumpyArray) {
    ArrayToRawMemory(object.cast<py::array>(), address);
  } else {
    MS_LOG(EXCEPTION) << "python object not supported. type: " << object_type;
  }
}

// Build the Python argument tuple for the user function from the kernel's raw
// input buffers. Scalars are wrapped directly; array inputs are staged through
// the pre-allocated host tensors and exposed to Python as numpy views.
py::tuple RawMemoryToPyObjects(const std::vector<AddressPtr> &inputs, const PyFuncArgumentInfo &input_infos,
                               const std::vector<tensor::TensorPtr> &input_tensors) {
  py::tuple args(inputs.size());
  for (size_t idx = 0; idx < inputs.size(); idx++) {
    const auto &obj_type = input_infos.object_types[idx];
    if (obj_type == PythonOjectType::kScalar) {
      args[idx] = RawMemoryToScalar(inputs[idx]->addr, input_infos.dtypes[idx]);
    } else if (obj_type == PythonOjectType::kNumpyArray) {
      // Copy the raw input into the staging tensor, then hand Python a numpy
      // view of that tensor (avoids allocating fresh host memory per launch).
      const tensor::TensorPtr &tensor = input_tensors[idx];
      CHECK_RET_WITH_EXCEPT(memcpy_s(tensor->data_c(), tensor->Size(), inputs[idx]->addr, inputs[idx]->size), EOK,
                            "memcpy failed.");
      args[idx] = tensor::TensorPy::AsNumpy(*tensor);
    } else {
      MS_LOG(EXCEPTION) << "Python args not support. Index: " << idx << ", type" << obj_type;
    }
  }
  return args;
}

// Write the Python function's result back into the kernel output buffers.
// A tuple result maps element-wise onto the outputs; any other object is
// treated as the single output. Count mismatches raise an exception.
void PyObjectToRawMemorys(const py::object &object, const PyFuncArgumentInfo &output_infos,
                          const std::vector<AddressPtr> &outputs) {
  if (py::isinstance<py::tuple>(object)) {
    // Multiple outputs: one tuple element per output buffer.
    auto result_tuple = object.cast<py::tuple>();
    if (result_tuple.size() != outputs.size()) {
      MS_LOG(EXCEPTION) << "The output num is: " << result_tuple.size() << ", with " << outputs.size() << " expect.";
    }
    for (size_t idx = 0; idx < outputs.size(); idx++) {
      ObjectToRawMemory(result_tuple[idx], output_infos.object_types[idx], output_infos.dtypes[idx], outputs[idx]);
    }
    return;
  }

  // Single (non-tuple) output.
  if (outputs.size() != 1) {
    MS_LOG(EXCEPTION) << "The output num is 1, with " << outputs.size() << " expect.";
  }
  ObjectToRawMemory(object, output_infos.object_types[0], output_infos.dtypes[0], outputs[0]);
}
} // namespace

// Read the integer anchor id ("fn_id") that identifies the user's Python
// function on the ME side, then derive the input/output argument specs from
// the node's attributes. The callable itself is resolved lazily in Launch().
void PyFuncCpuKernel::InitKernel(const CNodePtr &kernel_node) {
  func_id_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "fn_id");
  BuildFuncInfo(kernel_node);
}

// Resolve the Python callable on first launch, then execute it with the given
// raw input/output buffers. Workspace is unused.
bool PyFuncCpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                             const std::vector<AddressPtr> &outputs) {
  // Lazy resolution: init_ is set only after GetPythonFunc succeeds, so a
  // failed lookup is retried on the next launch.
  // NOTE(review): this lazy init is not thread-safe; presumably each kernel
  // instance is launched from a single thread — confirm with the CPU runtime.
  if (!init_) {
    py_func_ = GetPythonFunc(func_id_);
    init_ = true;
  }

  return ExecuteKernel(inputs, outputs);
}

void PyFuncCpuKernel::BuildFuncInfo(const CNodePtr &kernel_node) {
const auto &in_shapes = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "in_shapes");
const auto &in_types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "in_types");
const auto &out_shapes = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "out_shapes");
const auto &out_types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "out_types");

input_infos_.dtypes = in_types;
input_infos_.shapes = in_shapes;
for (size_t i = 0; i < in_shapes.size(); i++) {
auto tensor = std::make_shared<tensor::Tensor>(in_types[i]->type_id(), in_shapes[i]);
input_tensors_.push_back(tensor);

const auto &object_type = in_shapes[i].empty() ? PythonOjectType::kScalar : PythonOjectType::kNumpyArray;
input_infos_.object_types.emplace_back(object_type);
}

output_infos_.dtypes = out_types;
output_infos_.shapes = out_shapes;
for (size_t j = 0; j < out_shapes.size(); j++) {
const auto &object_type = out_shapes[j].empty() ? PythonOjectType::kScalar : PythonOjectType::kNumpyArray;
output_infos_.object_types.emplace_back(object_type);
}
}

bool PyFuncCpuKernel::ExecuteKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
if (Py_IsInitialized() != true) {
MS_LOG(ERROR) << "Py_IsInitialized failed.";
return false;
}

py::gil_scoped_acquire gil_acquire;
py::object result;
if (inputs.size()) {
py::tuple args = RawMemoryToPyObjects(inputs, input_infos_, input_tensors_);
result = py_func_(*args);
} else {
result = py_func_();
}

if (output_infos_.shapes.empty()) {
return true;
}

PyObjectToRawMemorys(result, output_infos_, outputs);
return true;
}

// Look up the user's Python function by its integer anchor id. ME stores the
// callable in a dict inside mindspore.ops.operations.other_ops; `get_pyfunc`
// retrieves it by key. Raises if the module helper or the function is missing.
// Fix: the "cannot find" log message lacked spaces around the function and
// module names ("named get_pyfuncin modulemindspore...").
py::function PyFuncCpuKernel::GetPythonFunc(const int64_t &func_id) {
  py::gil_scoped_acquire gil_acquire;
  static const std::string module_name = "mindspore.ops.operations.other_ops";
  static const std::string func_name = "get_pyfunc";
  py::module module = py::module::import(module_name.c_str());
  py::object get_pyfunc_obj = module.attr(func_name.c_str());
  if (get_pyfunc_obj.is_none()) {
    MS_LOG(EXCEPTION) << "Cannot find a python function named " << func_name << " in module " << module_name;
  }

  py::function get_pyfunc = get_pyfunc_obj.cast<py::function>();
  py::object py_func_obj = get_pyfunc(py::int_(func_id));
  if (py_func_obj.is_none()) {
    MS_LOG(EXCEPTION) << "Cannot find python func with id: " << func_id;
  }

  return py_func_obj.cast<py::function>();
}

MS_REG_CPU_KERNEL(PyFunc, KernelAttr(), PyFuncCpuKernel)
} // namespace kernel
} // namespace mindspore

+ 76
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/pyfunc/py_func_cpu_kernel.h View File

@@ -0,0 +1,76 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PYFUNC_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PYFUNC_KERNEL_H_

#include <memory>
#include <string>
#include <vector>
#include <Python.h>
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include "backend/kernel_compiler/cpu/cpu_kernel.h"

namespace py = pybind11;
namespace mindspore {
namespace kernel {
// Indicate Python object type. The input/output of PyFunc should be either Scalar or Numpy Array.
// NOTE(review): "Oject" is a typo of "Object"; the name is kept because this
// header is included by the implementation and renaming would break callers.
enum class PythonOjectType : char { kScalar, kNumpyArray };
// Indicate PyFunc input/output information. The three vectors are
// index-aligned: entry i of each describes argument i.
struct PyFuncArgumentInfo {
  // Shape of each argument. Empty vector indicates the Python object is Scalar and non-empty means Numpy Array.
  std::vector<std::vector<int64_t>> shapes;
  // Element data type of each argument, as int, float, bool.
  std::vector<TypePtr> dtypes;
  // Python-side representation (scalar vs numpy array) of each argument.
  std::vector<PythonOjectType> object_types;
};

// CPU kernel that bridges graph execution to an arbitrary user Python
// function: raw input buffers are marshalled to Python objects, the function
// is invoked under the GIL, and its result is copied back to output buffers.
class PyFuncCpuKernel : public CPUKernel {
 public:
  PyFuncCpuKernel() : init_(false), func_id_(-1) {}
  ~PyFuncCpuKernel() = default;

  // Init kernel including analyse PyFunc input and output info.
  void InitKernel(const CNodePtr &kernel_node) override;
  // Construct arguments with raw memory, invoke Python function and then convert result to raw memory.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 protected:
  // Analyse PyFunc input/output spec.
  void BuildFuncInfo(const CNodePtr &kernel_node);
  // Get Python function from anchor.
  py::function GetPythonFunc(const int64_t &func_id);
  bool ExecuteKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

  // True once py_func_ has been resolved (done lazily on first Launch).
  bool init_;
  // The Python object is not acceptable for `Primitive` attribute. So we pass an unique key instead of Python
  // function. ME stores the Python function in a dict, and passes the key to the backend kernel. The kernel gets
  // the Python function by the key from the dict when the kernel is first invoked.
  // Fix: declared int64_t (was size_t) to match the "fn_id" attribute read in
  // InitKernel and the GetPythonFunc parameter; size_t silently wrapped the
  // -1 "unset" sentinel to SIZE_MAX.
  int64_t func_id_;
  py::function py_func_;
  // Input and output specifications.
  PyFuncArgumentInfo input_infos_;
  PyFuncArgumentInfo output_infos_;
  // The kernel holds the input tensors during execution to avoid dynamic malloc/free of host memory.
  std::vector<std::shared_ptr<tensor::Tensor>> input_tensors_;
};
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PYFUNC_KERNEL_H_

+ 36
- 0
mindspore/ccsrc/runtime/device/cpu/cpu_common.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_COMMON_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_COMMON_H_

#include "utils/log_adapter.h"

namespace mindspore {
namespace device {
namespace cpu {
// Evaluate `expression` once and raise via MS_LOG(EXCEPTION) when the result
// differs from `status`.
// Fix: wrapped in do { } while (0) so the macro expands to a single statement;
// the previous bare-brace form broke when used as the body of an if/else
// (`if (c) CHECK_RET_WITH_EXCEPT(...); else ...` failed to compile).
#define CHECK_RET_WITH_EXCEPT(expression, status, message) \
  do {                                                     \
    auto ret = (expression);                               \
    if (ret != (status)) {                                 \
      MS_LOG(EXCEPTION) << (message);                      \
    }                                                      \
  } while (0)
} // namespace cpu
} // namespace device
} // namespace mindspore

#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_CPU_COMMON_H_

+ 73
- 1
mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc View File

@@ -31,6 +31,8 @@ namespace cpu {
using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
using mindspore::kernel::KernelBuildInfo;
namespace {
constexpr auto kParamDynamic = "dynamic";

bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) {
auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first;
MS_EXCEPTION_IF_NULL(input_node);
@@ -66,6 +68,13 @@ void GetOutputDtypes(const CNodePtr &kernel_node, std::vector<TypeId> *output_ty
}
}

// CPU kernels only support the default format: append one kOpFormat_DEFAULT
// entry per output of the node.
void GetOutputFormat(const CNodePtr &kernel_node, std::vector<std::string> *output_formats) {
  const size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  output_formats->insert(output_formats->end(), output_num, kOpFormat_DEFAULT);
}

void GetInputDtypes(const CNodePtr &kernel_node, std::vector<TypeId> *input_types,
std::vector<size_t> *input_no_cnode_indexes) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
@@ -81,6 +90,13 @@ void GetInputDtypes(const CNodePtr &kernel_node, std::vector<TypeId> *input_type
}
}

// CPU kernels only support the default format: append one kOpFormat_DEFAULT
// entry per input of the node.
void GetInputFormat(const CNodePtr &kernel_node, std::vector<std::string> *input_formats) {
  const size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  input_formats->insert(input_formats->end(), input_num, kOpFormat_DEFAULT);
}

void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &kernel_attr,
std::vector<std::string> *output_formats, std::vector<TypeId> *output_types) {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
@@ -200,7 +216,57 @@ void KernelNotSupportException(const AnfNodePtr &kernel_node, const std::vector<
operator_info << "is not support.";
MS_EXCEPTION(TypeError) << operator_info.str() << " Trace: " << trace::DumpSourceLines(kernel_node);
}

// Build kernel info for operators whose input/output count and dtypes are only
// known per-node (dynamic-parameter kernels such as PyFunc): take the types
// directly from the node, use the default format everywhere, and register a
// matching KernelAttr with the factory so later attr checks accept the node.
// Fix: the output-type loop called AddInputAttr; outputs must be registered
// with AddOutputAttr or the generated attr mismatches the node's signature.
void UpdateDynamicKernelBuildInfoAndAttrs(const CNodePtr &kernel_node) {
  const std::string &op_name = AnfAlgo::GetCNodeName(kernel_node);
  MS_LOG(INFO) << "Operator name: " << op_name;
  // Set kernel build info from the node itself.
  std::vector<TypeId> input_types;
  std::vector<size_t> input_not_cnode_indexes;
  GetInputDtypes(kernel_node, &input_types, &input_not_cnode_indexes);
  std::vector<TypeId> output_types;
  GetOutputDtypes(kernel_node, &output_types);
  std::vector<std::string> input_formats;
  GetInputFormat(kernel_node, &input_formats);
  std::vector<std::string> output_formats;
  GetOutputFormat(kernel_node, &output_formats);
  SetKernelBuildInfo(input_formats, input_types, output_formats, output_types, kernel_node.get());

  // Set kernel attrs to match the build info above.
  KernelAttr attr;
  for (size_t i = 0; i < input_types.size(); i++) {
    attr.AddInputAttr(input_types[i]);
  }
  for (size_t j = 0; j < output_types.size(); j++) {
    attr.AddOutputAttr(output_types[j]);
  }
  std::vector<KernelAttr> kernel_attrs = kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(op_name);
  kernel_attrs.emplace_back(attr);
  kernel::CPUKernelFactory::GetInstance().UpdateKernelAttrs(op_name, kernel_attrs);
}
} // namespace

// A kernel is "dynamic parameter" when its registered op info declares exactly
// one input descriptor and one output descriptor, both with param_type
// "dynamic" — i.e. the real input/output counts vary per node.
bool IsDynamicParamKernel(const std::string &op_name) {
  const auto &op_info = kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kCPU);
  if (op_info == nullptr) {
    return false;
  }

  const auto &inputs = op_info->inputs_ptr();
  if (inputs.size() != 1 || inputs[0]->param_type() != kParamDynamic) {
    return false;
  }

  const auto &outputs = op_info->outputs_ptr();
  return outputs.size() == 1 && outputs[0]->param_type() == kParamDynamic;
}

bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
const std::vector<KernelAttr> &kernel_attrs, const std::vector<TypeId> &input_types,
const std::vector<size_t> &input_not_cnode_indexes, const std::vector<TypeId> &output_types,
@@ -229,7 +295,14 @@ bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
}
return false;
}

void SetKernelInfo(const CNodePtr &kernel_node) {
// Select for dynamic kernel(both the number and data type are undetermined).
const std::string &op_name = AnfAlgo::GetCNodeName(kernel_node);
if (IsDynamicParamKernel(op_name)) {
return UpdateDynamicKernelBuildInfoAndAttrs(kernel_node);
}

std::vector<std::string> input_formats;
std::vector<TypeId> input_types;
std::vector<size_t> input_not_cnode_indexes;
@@ -241,7 +314,6 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
if (kernel_attrs.empty() || (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0)) {
MS_LOG(DEBUG) << "Operator[" << AnfAlgo::GetCNodeName(kernel_node) << "] will get ops attr info.";
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kCPU);
if (op_info_ptr == nullptr) {
MS_LOG(EXCEPTION) << "Not find op[" << op_name << "] in cpu";


+ 2
- 0
mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.h View File

@@ -29,6 +29,8 @@ namespace mindspore {
namespace device {
namespace cpu {
void SetKernelInfo(const CNodePtr &apply_kernel_ptr);
// Indicate whether the kernel input/output number are variable.
bool IsDynamicParamKernel(const std::string &op_name);

class KernelAttr {
public:


+ 11
- 0
tests/ut/cpp/stub/runtime/kernel_select_cpu.cc View File

@@ -0,0 +1,11 @@
#include <string>

namespace mindspore {
namespace device {
namespace cpu {
// UT stub: unit tests never exercise dynamic-parameter kernel selection, so
// every operator is reported as non-dynamic.
bool IsDynamicParamKernel(const std::string &op_name) { return false; }
}  // namespace cpu
}  // namespace device
}  // namespace mindspore

Loading…
Cancel
Save