|
|
|
@@ -78,8 +78,11 @@ static TypeId GetDataType(const py::buffer_info &buf) { |
|
|
|
case '?': |
|
|
|
return TypeId::kNumberTypeBool; |
|
|
|
} |
|
|
|
} else if (buf.format.size() >= 2 && buf.format.back() == 'w') { |
|
|
|
// Support np.str_ dtype, format: {x}w. {x} is a number that means the maximum length of the string items. |
|
|
|
return TypeId::kObjectTypeString; |
|
|
|
} |
|
|
|
MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << " item size " << buf.itemsize; |
|
|
|
MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << ", item size " << buf.itemsize; |
|
|
|
return TypeId::kTypeUnknown; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -109,6 +112,8 @@ static std::string GetPyTypeFormat(TypeId data_type) { |
|
|
|
return py::format_descriptor<int64_t>::format(); |
|
|
|
case TypeId::kNumberTypeBool: |
|
|
|
return py::format_descriptor<bool>::format(); |
|
|
|
case TypeId::kObjectTypeString: |
|
|
|
return py::format_descriptor<uint8_t>::format(); |
|
|
|
default: |
|
|
|
MS_LOG(WARNING) << "Unsupported DataType " << data_type << "."; |
|
|
|
return ""; |
|
|
|
@@ -181,6 +186,10 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) |
|
|
|
if (buf_type == TypeId::kTypeUnknown && data_type == TypeId::kTypeUnknown) { |
|
|
|
MS_LOG(EXCEPTION) << "Unsupported tensor type!"; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "data_type: " << data_type << ", buf_type: " << buf_type; |
|
|
|
if (data_type == TypeId::kObjectTypeString || buf_type == TypeId::kObjectTypeString) { |
|
|
|
return TensorPy::MakeTensorOfNumpy(input); |
|
|
|
} |
|
|
|
// Use buf type as data type if type_ptr not set. |
|
|
|
if (data_type == TypeId::kTypeUnknown) { |
|
|
|
data_type = buf_type; |
|
|
|
@@ -210,7 +219,7 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) |
|
|
|
} |
|
|
|
|
|
|
|
/// Creates a Tensor from a numpy array without copy |
|
|
|
TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) { |
|
|
|
TensorPtr TensorPy::MakeTensorOfNumpy(const py::array &input) { |
|
|
|
// Check format. |
|
|
|
if (!IsCContiguous(input)) { |
|
|
|
MS_LOG(EXCEPTION) << "Array should be C contiguous."; |
|
|
|
@@ -504,7 +513,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { |
|
|
|
>>> data.strides |
|
|
|
(4, 4) |
|
|
|
)mydelimiter") |
|
|
|
.def("from_numpy", TensorPy::MakeTensorNoCopy, R"mydelimiter( |
|
|
|
.def("from_numpy", TensorPy::MakeTensorOfNumpy, R"mydelimiter( |
|
|
|
Creates a Tensor from a numpy.ndarray without copy. |
|
|
|
|
|
|
|
Arg: |
|
|
|
|