You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

meta_tensor.h 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_IR_META_TENSOR_H_
  17. #define MINDSPORE_CCSRC_IR_META_TENSOR_H_
  18. #include <utility>
  19. #include <vector>
  20. #include <memory>
  21. #include <string>
  22. #include "device/device_address.h"
  23. #include "pybind11/numpy.h"
  24. #include "pybind11/pybind11.h"
  25. #include "Eigen/Core"
  26. #include "ir/base.h"
  27. #include "ir/dtype.h"
  28. #include "utils/log_adapter.h"
  29. #include "utils/convert_utils.h"
  30. #include "utils/hashing.h"
  31. namespace py = pybind11;
  32. using float16 = Eigen::half;
  33. namespace pybind11 {
  34. namespace detail {
  35. // Similar to enums in `pybind11/numpy.h`. Determined by doing:
  36. // python3 -c 'import numpy as np; print(np.dtype(np.float16).num)'
  37. constexpr int NPY_FLOAT16 = 23;
  38. template <typename T>
  39. struct npy_scalar_caster {
  40. PYBIND11_TYPE_CASTER(T, _("PleaseOverride"));
  41. using Array = array_t<T>;
  42. bool load(handle src, bool convert) {
  43. // Taken from Eigen casters. Permits either scalar dtype or scalar array.
  44. handle type = dtype::of<T>().attr("type");
  45. if (!convert && !isinstance<Array>(src) && !isinstance(src, type)) return false;
  46. Array tmp = Array::ensure(src);
  47. if (tmp && tmp.size() == 1 && tmp.ndim() == 0) {
  48. this->value = *tmp.data();
  49. return true;
  50. }
  51. return false;
  52. }
  53. static handle cast(T src, return_value_policy, handle) {
  54. Array tmp({1});
  55. tmp.mutable_at(0) = src;
  56. tmp.resize({});
  57. // You could also just return the array if you want a scalar array.
  58. object scalar = tmp[tuple()];
  59. return scalar.release();
  60. }
  61. };
  62. template <>
  63. struct npy_format_descriptor<float16> {
  64. static constexpr auto name = "float16";
  65. static pybind11::dtype dtype() {
  66. handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16);
  67. return reinterpret_borrow<pybind11::dtype>(ptr);
  68. }
  69. virtual ~npy_format_descriptor<float16>() {}
  70. };
  71. template <>
  72. struct type_caster<float16> : public npy_scalar_caster<float16> {
  73. static constexpr auto name = "float16";
  74. };
  75. } // namespace detail
  76. } // namespace pybind11
  77. using mindspore::device::DeviceAddress;
  78. using DeviceAddressPtr = std::shared_ptr<mindspore::device::DeviceAddress>;
  79. // brief mindspore namespace.
  80. //
  81. // mindspore is the top-level namespace of the MindSpore project.
  82. // All other namespaces in the ME project should be nested inside the mindspore namespace.
  83. namespace mindspore {
  84. // brief mindspore::tensor namespace
  85. //
  86. // A sub namespace in ME to support tensor related definition.
  87. namespace tensor {
  88. // brief Device info of Tensor
  89. //
  90. // Includes the format and data type of a tensor.
  91. struct DeviceInfo {
  92. explicit DeviceInfo(std::string format = "DefaultFormat", TypePtr data_type = nullptr)
  93. : format_(std::move(format)), data_type_(std::move(data_type)) {}
  94. std::string format_ = "DefaultFormat";
  95. TypePtr data_type_ = nullptr;
  96. };
  97. // brief Metadata of Tensor
  98. //
  99. // Includes the metadata information of a tensor, such as data type, shape
  100. // and so on. But it does not contain values of a tensor.
  101. class MetaTensor : public Value {
  102. public:
  103. // Construction
  104. MetaTensor();
  105. // brief Constructs a meta tensor of a tensor having data_type data and shape.
  106. //
  107. // The constructed MetaTensor is not a Tensor, but it has the data type and shape
  108. // information of a Tensor. The following codes will create a 2x3 float
  109. // param data_type The data type of the tensor.
  110. // param shape The shape of the tensor.
  111. MetaTensor(const TypeId data_type, const std::vector<int> &shape);
  112. MetaTensor(const TypePtr &type_ptr, const py::tuple &shape);
  113. // brief Constructs a MetaTensor object from an existing MetaTensor instance.
  114. //
  115. // The constructed MetaTensor object will have the same data type and shape as the
  116. // meta_tensor.
  117. //
  118. // param meta_tensor An existing MetaTensor object.
  119. MetaTensor(const MetaTensor &meta_tensor);
  120. ~MetaTensor() override = default;
  121. MS_DECLARE_PARENT(MetaTensor, Value)
  122. // brief Overloads operator = for MetaTensor.
  123. //
  124. // The constructed MetaTensor object has the same type and shape with meta_tensor.
  125. //
  126. // param meta_tensor An existing MetaTensor object.
  127. virtual MetaTensor &operator=(const MetaTensor &meta_tensor);
  128. // brief Compares two MetaTensor objects.
  129. //
  130. // The constructed MetaTensor object has the same type and shape with meta_tensor.
  131. //
  132. // param meta_tensor The MetaTensor object to be compared.
  133. // return true: If having same type and shape, return true, or return false.
  134. virtual bool operator==(const MetaTensor &meta_tensor) const;
  135. // brief Returns the data type of the tensor in its MetaTensor.
  136. //
  137. // All the types are defined in "ir/dtype.h".
  138. TypePtr Dtype() const;
  139. TypeId data_type() const { return data_type_; }
  140. std::string ToString() const override;
  141. std::string DumpText() const override;
  142. // brief Sets the data type of a tensor in its MetaTensor.
  143. //
  144. // param data_type The data type of the tensor to be set.
  145. virtual TypeId set_data_type(const TypeId data_type) {
  146. data_type_ = data_type;
  147. return data_type_;
  148. }
  149. virtual TypePtr SetDtype(const TypePtr type_ptr);
  150. // brief Get tensor's shape.
  151. //
  152. // The shape of a tensor is stored in a vector<int>. Each
  153. // element of the vector represents the size of a dimension of the tensor.
  154. // The order of each element in the vector is as same as the the dimension's
  155. // order it represents.
  156. //
  157. // return A const vector<int> which represents the shape of the tensor.
  158. std::vector<int> shape() const { return shape_; }
  159. // brief Sets the shape of a tensor.
  160. //
  161. // The shape of a tensor is stored in a vector<int>. Each
  162. // element of the vector represents the size of a dimension of the tensor.
  163. // The order of each element in the vector is as same as the the dimension's
  164. // order it represents.
  165. //
  166. // param shape The shape of the tensor.
  167. // return The shape's size.
  168. size_t set_shape(const std::vector<int> &shape) {
  169. this->shape_ = shape;
  170. return shape_.size();
  171. }
  172. // Get tensor's device info.
  173. DeviceInfo device_info() const { return device_info_; }
  174. // Set tensor's device info.
  175. void set_device_info(const DeviceInfo &device_info) { device_info_ = device_info; }
  176. void SetDeviceInfo(const std::string &format, const TypePtr &data_type);
  177. // Get the size of a given dimension by its index number.
  178. int DimensionSize(size_t index) const;
  179. // Get total number of elements in a tensor.
  180. int ElementsNum() const;
  181. std::size_t hash() const override {
  182. std::size_t hash_value = std::hash<int>{}(SizeToInt(data_type_));
  183. hash_value = hash_combine(hash_value, std::hash<size_t>{}(shape_.size()));
  184. // hash all elements may costly, so only take at most 4 elements into account based on
  185. // some experiments.
  186. for (size_t i = 0; (i < shape_.size()) && (i < 4); ++i) {
  187. hash_value = hash_combine(hash_value, (std::hash<int>{}(shape_[i])));
  188. }
  189. return hash_value;
  190. }
  191. bool operator==(const Value &other) const override {
  192. if (other.isa<MetaTensor>()) {
  193. auto other_ = static_cast<const MetaTensor &>(other);
  194. return *this == other_;
  195. } else {
  196. return false;
  197. }
  198. }
  199. protected:
  200. // brief Data type of the tensor.
  201. //
  202. // All support data type is in Number Types of [TypeId],
  203. // including [kNumberTypeBool], [kNumberTypeInt],
  204. // [kNumberTypeUInt32], [kNumberTypeFloat32] and [kNumberTypeFloat64].
  205. TypeId data_type_;
  206. // brief Shape of the tensor.
  207. //
  208. // A std::vector<int> container is used to store the shape of a tensor.
  209. // Each element of the vector represents the size of a dimension of the tensor.
  210. // The order of each element in the vector is as same as the the dimension's
  211. // order it represents. If the dimension size is not set, its value will be -1.
  212. std::vector<int> shape_;
  213. // brief Device info of Tensor
  214. //
  215. // Includes the format and data type of a tensor on device.
  216. DeviceInfo device_info_;
  217. };
  218. // Tensor entity class
  219. class Tensor : public MetaTensor {
  220. public:
  221. Tensor() = default;
  222. abstract::AbstractBasePtr ToAbstract() override;
  223. // brief Constructor for Python.
  224. //
  225. // param type_ptr [TypePty] Data type of the tensor.
  226. // param py_shape [py::tuple] The shape represented by py::tuple of the tensor.
  227. Tensor(const TypePtr &type_ptr, const py::tuple &shape);
  228. // brief Constructor for C++.
  229. //
  230. // param data_type [TypeId] Data type of the tensor.
  231. // param shape The shape represented by std::vector<int> of the tensor.
  232. Tensor(TypeId data_type, const std::vector<int> &shape);
  233. // brief Constructor for Python.
  234. //
  235. // param input [py::array] Data value of the tensor.
  236. // param data_type [TypeId] Data type of the tensor.
  237. explicit Tensor(const py::array &input, const TypePtr &data_type = nullptr);
  238. // brief Constructor
  239. //
  240. // param input [py::list] the data for tensor
  241. // param data_type [TypeId] data type
  242. explicit Tensor(const py::list &input, const TypePtr &data_type = nullptr);
  243. // brief Constructor
  244. //
  245. // param input [py::tuple] the data for tensor
  246. // param data_type [TypeId] data type
  247. explicit Tensor(const py::tuple &input, const TypePtr &data_type = nullptr);
  248. // brief Constructor
  249. //
  250. // param input [py::float_] the data for tensor
  251. // param data_type [TypeId] data type
  252. explicit Tensor(const py::float_ &input, const TypePtr &data_type = nullptr);
  253. // brief Constructor
  254. //
  255. // param input [py::int_] the data for tensor
  256. // param data_type [TypeId] data type
  257. explicit Tensor(const py::int_ &input, const TypePtr &data_type = nullptr);
  258. // brief Constructor
  259. //
  260. // param input [Tensor] the data for tensor
  261. // param data_type [TypeId] data type
  262. Tensor(const Tensor &tensor, const TypePtr &data_type = nullptr);
  263. ~Tensor() override = default;
  264. MS_DECLARE_PARENT(Tensor, MetaTensor);
  265. // brief Overloads operator = for Tensor.
  266. //
  267. // The constructed Tensor object has the same type and shape with tensor.
  268. //
  269. // param tensor An existing Tensor object.
  270. Tensor &operator=(const Tensor &tensor);
  271. // brief Compares two Tensor objects.
  272. //
  273. // Compare two tensor objects to see if they have same data type, shape and
  274. // data value.
  275. //
  276. // param tensor The Tensor object to be compared.
  277. // return true: If having same type, shape and data, return true, or return false.
  278. bool operator==(const Tensor &tensor) const;
  279. // It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
  280. bool ValueEqual(const Tensor &other) const;
  281. bool operator==(const Value &other) const override {
  282. if (other.isa<Tensor>()) {
  283. auto other_ = static_cast<const Tensor &>(other);
  284. return *this == other_;
  285. } else {
  286. return false;
  287. }
  288. }
  289. // brief Gets tensor's dimension
  290. //
  291. // return The number of dimensions of the tensor data.
  292. int DataDim() const;
  293. // brief Getting tensor data size
  294. //
  295. // return The total number of elements of the tensor data.
  296. int DataSize() const;
  297. // brief Get tensor's shape
  298. //
  299. // return [py::tuple] The tensor's shape
  300. py::tuple GetPyTupleShape() const;
  301. // brief Tensor's data value.
  302. //
  303. // return [py::array] The tensor's data in py::array.
  304. py::array data() const;
  305. // brief Get the data type fo the tensor for C++
  306. //
  307. // return [int] The tensor's data type will be cast to int to return.
  308. int data_type_c() const;
  309. // brief Get the tensor's shape for C++
  310. //
  311. // return [std::vector<int>]
  312. std::vector<int> shape_c(void) const;
  313. // brief Get Tensor data pointer for c++ type
  314. //
  315. // param writable true if writable, false if read only
  316. // return The pointer to the object
  317. void *data_c(bool writable = false);
  318. // brief Get data type from tensor data.
  319. //
  320. // param buf The buffer info of the py::array data.
  321. // return The [TypeId] of the tensor data.
  322. TypeId GetDataType(const py::buffer_info &buf) const;
  323. // brief Sets the data type of a tensor.
  324. //
  325. // param data_type The data type of the tensor to be set.
  326. //
  327. TypeId set_data_type(const TypeId data_type) override;
  328. TypePtr SetDtype(const TypePtr type_ptr) override;
  329. std::string GetShapeAndDataTypeInfo() const;
  330. std::string ToString() const override;
  331. std::string ToStringRepr() const;
  332. py::array data_; // < Tensor's data value
  333. const bool parse_info_ = true;
  334. private:
  335. // brief init tensor
  336. //
  337. // param input [py::array] the data for tensor
  338. // param data_type [TypeId] data type
  339. // return true if succeed, false if failed.
  340. void init(const py::array &input, const TypeId &data_type);
  341. void init(const py::array &input, const TypePtr &type_ptr);
  342. // brief init tensor attribute
  343. //
  344. // param data_type [TypeId] Data type of the tensor.
  345. // param shape [py::array] The shape of the tensor.
  346. // return true if succeed, false if failed.
  347. void init(TypeId data_type, const std::vector<int> &shape, py::array *data);
  348. bool convert_data(const py::array &in, const TypeId in_data_type, py::array *out, const TypeId out_data_type);
  349. public:
  350. bool is_dirty() const { return dirty_; }
  351. void set_dirty(const bool dirty) { dirty_ = dirty; }
  352. DeviceAddressPtr device_address() const { return device_address_; }
  353. void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; }
  354. py::array data_sync();
  355. private:
  356. bool dirty_{true};
  357. DeviceAddressPtr device_address_{nullptr};
  358. };
  359. using TensorPtr = std::shared_ptr<Tensor>;
  360. using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
  361. } // namespace tensor
  362. } // namespace mindspore
  363. #endif // MINDSPORE_CCSRC_IR_META_TENSOR_H_