You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

tensor_py.cc 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/tensor_py.h"
  17. #include <functional>
  18. #include <numeric>
  19. #include <vector>
  20. #include <sstream>
  21. #include <string>
  22. #include "pybind_api/api_register.h"
  23. #include "pybind_api/export_flags.h"
  24. #include "abstract/abstract_value.h"
  25. namespace mindspore {
  26. namespace tensor {
  27. static TypeId GetDataType(const py::buffer_info &buf) {
  28. if (buf.format.size() == 1) {
  29. switch (buf.format.front()) {
  30. case 'e':
  31. case 'f':
  32. case 'd':
  33. switch (buf.itemsize) {
  34. case 2:
  35. return TypeId::kNumberTypeFloat16;
  36. case 4:
  37. return TypeId::kNumberTypeFloat32;
  38. case 8:
  39. return TypeId::kNumberTypeFloat64;
  40. }
  41. break;
  42. case 'b':
  43. case 'h':
  44. case 'i':
  45. case 'l':
  46. case 'q':
  47. switch (buf.itemsize) {
  48. case 1:
  49. return TypeId::kNumberTypeInt8;
  50. case 2:
  51. return TypeId::kNumberTypeInt16;
  52. case 4:
  53. return TypeId::kNumberTypeInt32;
  54. case 8:
  55. return TypeId::kNumberTypeInt64;
  56. }
  57. break;
  58. case 'B':
  59. case 'H':
  60. case 'I':
  61. case 'L':
  62. case 'Q':
  63. switch (buf.itemsize) {
  64. case 1:
  65. return TypeId::kNumberTypeUInt8;
  66. case 2:
  67. return TypeId::kNumberTypeUInt16;
  68. case 4:
  69. return TypeId::kNumberTypeUInt32;
  70. case 8:
  71. return TypeId::kNumberTypeUInt64;
  72. }
  73. break;
  74. case '?':
  75. return TypeId::kNumberTypeBool;
  76. }
  77. }
  78. MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << " item size " << buf.itemsize;
  79. return TypeId::kTypeUnknown;
  80. }
  81. static std::string GetPyTypeFormat(TypeId data_type) {
  82. switch (data_type) {
  83. case TypeId::kNumberTypeFloat16:
  84. return "e";
  85. case TypeId::kNumberTypeFloat32:
  86. return py::format_descriptor<float>::format();
  87. case TypeId::kNumberTypeFloat64:
  88. return py::format_descriptor<double>::format();
  89. case TypeId::kNumberTypeUInt8:
  90. return py::format_descriptor<uint8_t>::format();
  91. case TypeId::kNumberTypeUInt16:
  92. return py::format_descriptor<uint16_t>::format();
  93. case TypeId::kNumberTypeUInt32:
  94. return py::format_descriptor<uint32_t>::format();
  95. case TypeId::kNumberTypeUInt64:
  96. return py::format_descriptor<uint64_t>::format();
  97. case TypeId::kNumberTypeInt8:
  98. return py::format_descriptor<int8_t>::format();
  99. case TypeId::kNumberTypeInt16:
  100. return py::format_descriptor<int16_t>::format();
  101. case TypeId::kNumberTypeInt32:
  102. return py::format_descriptor<int32_t>::format();
  103. case TypeId::kNumberTypeInt64:
  104. return py::format_descriptor<int64_t>::format();
  105. case TypeId::kNumberTypeBool:
  106. return py::format_descriptor<bool>::format();
  107. default:
  108. MS_LOG(WARNING) << "Unsupported DataType " << data_type << ".";
  109. return "";
  110. }
  111. }
  112. static bool IsCContiguous(const py::array &input) {
  113. auto flags = static_cast<unsigned int>(input.flags());
  114. return (flags & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_) != 0;
  115. }
  116. TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) {
  117. // Get input buffer info.
  118. py::buffer_info buf = input.request();
  119. // Check data types.
  120. auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kTypeUnknown;
  121. auto buf_type = GetDataType(buf);
  122. if (buf_type == TypeId::kTypeUnknown && data_type == TypeId::kTypeUnknown) {
  123. MS_LOG(EXCEPTION) << "Unsupported tensor type!";
  124. }
  125. // Use buf type as data type if type_ptr not set.
  126. if (data_type == TypeId::kTypeUnknown) {
  127. data_type = buf_type;
  128. }
  129. // Convert input array to C contiguous if need.
  130. std::unique_ptr<char[]> tmp_buf;
  131. if (!IsCContiguous(input)) {
  132. Py_buffer pybuf;
  133. if (PyObject_GetBuffer(input.ptr(), &pybuf, PyBUF_ANY_CONTIGUOUS)) {
  134. MS_LOG(EXCEPTION) << "Failed to get buffer from the input!";
  135. }
  136. tmp_buf = std::make_unique<char[]>(pybuf.len);
  137. if (PyBuffer_ToContiguous(tmp_buf.get(), &pybuf, pybuf.len, 'C')) {
  138. MS_LOG(EXCEPTION) << "Can't copy numpy.ndarray to a contiguous buffer.";
  139. }
  140. PyBuffer_Release(&pybuf);
  141. buf.ptr = tmp_buf.get();
  142. }
  143. // Get tensor shape.
  144. std::vector<int> shape(buf.shape.begin(), buf.shape.end());
  145. if (data_type == buf_type) {
  146. // Use memory copy if input data type is same as the required type.
  147. return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf.size * buf.itemsize);
  148. }
  149. // Create tensor with data type converted.
  150. return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf_type);
  151. }
  // Compute byte strides for a densely packed row-major (C-order) array.
  //
  // shape:     extent of every dimension, outermost first.
  // item_size: size of one element in bytes.
  //
  // Returns one stride per dimension: the last stride equals item_size and every earlier
  // stride is the next stride multiplied by the next extent. An empty shape yields an
  // empty result.
  static std::vector<ssize_t> GetStrides(const std::vector<ssize_t> &shape, ssize_t item_size) {
    std::vector<ssize_t> strides(shape.size());
    if (strides.empty()) {
      return strides;
    }
    // Walk from the innermost dimension outwards so each stride reuses the previous one
    // instead of recomputing the whole trailing product for every dimension.
    strides.back() = item_size;
    for (size_t i = shape.size() - 1; i > 0; --i) {
      strides[i - 1] = strides[i] * shape[i];
    }
    return strides;
  }
  165. static py::buffer_info GetPyBufferInfo(const Tensor &tensor) {
  166. std::vector<ssize_t> shape(tensor.shape().begin(), tensor.shape().end());
  167. std::vector<ssize_t> strides = GetStrides(shape, tensor.data().itemsize());
  168. return py::buffer_info{
  169. tensor.data_c(), tensor.data().itemsize(), GetPyTypeFormat(tensor.data_type()), tensor.DataDim(), shape, strides};
  170. }
  171. py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) {
  172. auto &shape = tensor.shape();
  173. py::tuple dims(shape.size());
  174. for (size_t i = 0; i < dims.size(); ++i) {
  175. dims[i] = py::int_(shape[i]);
  176. }
  177. return dims;
  178. }
  179. py::array TensorPy::SyncAsNumpy(const Tensor &tensor) {
  180. tensor.data_sync();
  181. auto info = GetPyBufferInfo(tensor);
  182. py::object self = py::cast(&tensor);
  183. return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self);
  184. }
  185. py::array TensorPy::AsNumpy(const Tensor &tensor) {
  186. auto info = GetPyBufferInfo(tensor);
  187. py::object self = py::cast(&tensor);
  188. return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self);
  189. }
  190. static std::vector<int> GetShapeFromTuple(const py::tuple &tuple) {
  191. std::vector<int> shape;
  192. const size_t size = tuple.size();
  193. shape.reserve(tuple.size());
  194. for (size_t i = 0; i < size; ++i) {
  195. shape.push_back(py::int_(tuple[i]));
  196. }
  197. return shape;
  198. }
  199. REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
  200. // Define python MetaTensor class.
  201. (void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
  202. .def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape"))
  203. .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
  204. .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.")
  205. .def(py::pickle(
  206. [](const MetaTensor &t) { // __getstate__
  207. /* Return a tuple that fully encodes the state of the object */
  208. return py::make_tuple(static_cast<int>(t.data_type()), t.shape());
  209. },
  210. [](const py::tuple &t) { // __setstate__
  211. if (t.size() != 2) {
  212. throw std::runtime_error("Invalid state!");
  213. }
  214. /* Create a new C++ instance */
  215. MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<std::vector<int>>());
  216. return tensor;
  217. }));
  218. // Define python Tensor class.
  219. // dtype should define before Tensor, because Tensor init depend dtype
  220. (void)py::class_<Tensor, MetaTensor, std::shared_ptr<Tensor>>(*m, "Tensor")
  221. .def(py::init([](const Tensor &tensor) { return std::make_shared<Tensor>(tensor); }),
  222. py::arg("input"))
  223. .def(py::init([](const Tensor &tensor, const TypePtr &type_ptr) {
  224. TypeId data_type = type_ptr ? type_ptr->type_id() : kTypeUnknown;
  225. if (data_type == kTypeUnknown || tensor.data_type() == data_type) {
  226. return std::make_shared<Tensor>(tensor);
  227. }
  228. return std::make_shared<Tensor>(tensor, data_type);
  229. }),
  230. py::arg("input"), py::arg("dtype"))
  231. .def(py::init([](const TypePtr &type_ptr, const py::tuple &shape) {
  232. auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kNumberTypeFloat64;
  233. return std::make_shared<Tensor>(data_type, GetShapeFromTuple(shape));
  234. }),
  235. py::arg("dtype"), py::arg("shape"))
  236. .def(py::init([](const py::array &input, const TypePtr &type_ptr) {
  237. return TensorPy::MakeTensor(input, type_ptr);
  238. }),
  239. py::arg("input"), py::arg("dtype") = nullptr)
  240. .def(py::init([](py::float_ input, const TypePtr &type_ptr) {
  241. return TensorPy::MakeTensor(py::array(input), type_ptr);
  242. }),
  243. py::arg("input"), py::arg("dtype") = nullptr)
  244. .def(py::init([](py::int_ input, const TypePtr &type_ptr) {
  245. return TensorPy::MakeTensor(py::array(input), type_ptr);
  246. }),
  247. py::arg("input"), py::arg("dtype") = nullptr)
  248. .def(py::init([](py::list input, const TypePtr &type_ptr) {
  249. return TensorPy::MakeTensor(py::array(input), type_ptr);
  250. }),
  251. py::arg("input"), py::arg("dtype") = nullptr)
  252. .def(py::init([](py::tuple input, const TypePtr &type_ptr) {
  253. return TensorPy::MakeTensor(py::array(input), type_ptr);
  254. }),
  255. py::arg("input"), py::arg("dtype") = nullptr)
  256. .def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag)
  257. .def_property_readonly("dtype", &Tensor::Dtype, R"mydelimiter(
  258. Get the tensor's data type.
  259. Returns:
  260. type, the data type of tensor.
  261. Examples:
  262. >>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
  263. >>> data.dtype
  264. Int32
  265. )mydelimiter")
  266. .def_property_readonly("shape", TensorPy::GetPyTupleShape, R"mydelimiter(
  267. Get the tensor's shape.
  268. Returns:
  269. tuple[int], the shape of tensor.
  270. Examples:
  271. >>> data = mindspore.Tensor(np.ones((3, 3)))
  272. >>> data.shape()
  273. (3, 3)
  274. )mydelimiter")
  275. .def("asnumpy", TensorPy::SyncAsNumpy, R"mydelimiter(
  276. Convert tensor to numpy.ndarray.
  277. Returns:
  278. numpy.ndarray.
  279. Examples:
  280. >>> data = mindspore.Tensor(np.ones((2, 3)))
  281. >>> array = data.asnumpy()
  282. >>> array
  283. array([[1., 1., 1.],
  284. [1., 1., 1.]])
  285. )mydelimiter")
  286. .def("size", &Tensor::DataSize, R"mydelimiter(
  287. Get tensor's data size.
  288. Returns:
  289. int, the size of tensor.
  290. Examples:
  291. >>> data = mindspore.Tensor(np.ones((2, 3)))
  292. >>> data.size()
  293. 6
  294. )mydelimiter")
  295. .def("is_init", &Tensor::is_init, R"mydelimiter(
  296. Get tensor init_flag.
  297. Returns:
  298. bool, whether the tensor init.
  299. Examples:
  300. >>> data = mindspore.Tensor(np.ones((2, 3)))
  301. >>> data.is_init()
  302. False
  303. )mydelimiter")
  304. .def("set_init_flag", &Tensor::set_init_flag, R"mydelimiter(
  305. Set tensor init_flag.
  306. Examples:
  307. >>> data = mindspore.Tensor(np.ones((2, 3)))
  308. >>> data.set_init_flag(True)
  309. )mydelimiter")
  310. .def("dim", &Tensor::DataDim, R"mydelimiter(
  311. Get tensor's data dimension.
  312. Returns:
  313. int, the dimension of tensor.
  314. Examples:
  315. >>> data = mindspore.Tensor(np.ones((2, 3)))
  316. >>> data.dim()
  317. 2
  318. )mydelimiter")
  319. .def("assign_value", &Tensor::AssignValue, R"mydelimiter(
  320. Assign another tensor value to this.
  321. Arg:
  322. value (:class:`mindspore.tensor`): The value tensor.
  323. Examples:
  324. >>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
  325. >>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32))
  326. >>> data.assign_value(data2)
  327. >>> data.shape
  328. (2, 2)
  329. )mydelimiter")
  330. .def("set_dtype", &Tensor::SetDtype, R"mydelimiter(
  331. Set the tensor's data type.
  332. Arg:
  333. dtype (:class:`mindspore.dtype`): The type of output tensor.
  334. Examples:
  335. >>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
  336. >>> data.set_dtype(mindspore.int32)
  337. mindspore.int32
  338. )mydelimiter")
  339. .def("__str__", &Tensor::ToString)
  340. .def("__repr__", &Tensor::ToStringRepr)
  341. .def(py::pickle(
  342. [](const Tensor &t) { // __getstate__
  343. /* Return a tuple that fully encodes the state of the object */
  344. return py::make_tuple(TensorPy::SyncAsNumpy(t));
  345. },
  346. [](const py::tuple &t) { // __setstate__
  347. if (t.size() != 1) {
  348. throw std::runtime_error("Invalid state!");
  349. }
  350. /* Create a new C++ instance */
  351. return TensorPy::MakeTensor(t[0].cast<py::array>());
  352. }));
  353. }));
  354. } // namespace tensor
  355. } // namespace mindspore