You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.h 8.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_IR_TENSOR_H_
  17. #define MINDSPORE_CCSRC_IR_TENSOR_H_
  18. #include <memory>
  19. #include <string>
  20. #include <vector>
  21. #include "pybind11/numpy.h"
  22. #include "pybind11/pybind11.h"
  23. #include "Eigen/Core"
  24. #include "device/device_address.h"
  25. #include "ir/meta_tensor.h"
  26. #include "utils/log_adapter.h"
  27. namespace py = pybind11;
  28. using float16 = Eigen::half;
  29. namespace pybind11 {
  30. namespace detail {
  // NumPy type number of float16, mirroring the enums in `pybind11/numpy.h`.
  // The value was obtained by running:
  //   python3 -c 'import numpy as np; print(np.dtype(np.float16).num)'
  constexpr int NPY_FLOAT16{23};
  34. template <typename T>
  35. struct npy_scalar_caster {
  36. PYBIND11_TYPE_CASTER(T, _("PleaseOverride"));
  37. using Array = array_t<T>;
  38. bool load(handle src, bool convert) {
  39. // Taken from Eigen casters. Permits either scalar dtype or scalar array.
  40. handle type = dtype::of<T>().attr("type");
  41. if (!convert && !isinstance<Array>(src) && !isinstance(src, type)) return false;
  42. Array tmp = Array::ensure(src);
  43. if (tmp && tmp.size() == 1 && tmp.ndim() == 0) {
  44. this->value = *tmp.data();
  45. return true;
  46. }
  47. return false;
  48. }
  49. static handle cast(T src, return_value_policy, handle) {
  50. Array tmp({1});
  51. tmp.mutable_at(0) = src;
  52. tmp.resize({});
  53. // You could also just return the array if you want a scalar array.
  54. object scalar = tmp[tuple()];
  55. return scalar.release();
  56. }
  57. };
  58. template <>
  59. struct npy_format_descriptor<float16> {
  60. static constexpr auto name = "float16";
  61. static pybind11::dtype dtype() {
  62. handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16);
  63. return reinterpret_borrow<pybind11::dtype>(ptr);
  64. }
  65. virtual ~npy_format_descriptor<float16>() {}
  66. };
  67. template <>
  68. struct type_caster<float16> : public npy_scalar_caster<float16> {
  69. static constexpr auto name = "float16";
  70. };
  71. } // namespace detail
  72. } // namespace pybind11
  73. using mindspore::device::DeviceAddress;
  74. using DeviceAddressPtr = std::shared_ptr<mindspore::device::DeviceAddress>;
  // brief mindspore namespace.
  //
  // The mindspore namespace is the top-level namespace of the MindSpore project.
  // Other namespaces should be sub-namespaces of the mindspore namespace in the ME project.
  79. namespace mindspore {
  // brief mindspore::tensor namespace.
  //
  // A sub-namespace in ME that holds tensor-related definitions.
  83. namespace tensor {
  84. // Tensor entity class
  85. class Tensor : public MetaTensor {
  86. public:
  87. Tensor() = default;
  88. abstract::AbstractBasePtr ToAbstract() override;
  89. // brief Constructor for Python.
  90. //
  91. // param type_ptr [TypePty] Data type of the tensor.
  92. // param py_shape [py::tuple] The shape represented by py::tuple of the tensor.
  93. Tensor(const TypePtr &type_ptr, const py::tuple &shape);
  94. // brief Constructor for C++.
  95. //
  96. // param data_type [TypeId] Data type of the tensor.
  97. // param shape The shape represented by std::vector<int> of the tensor.
  98. Tensor(TypeId data_type, const std::vector<int> &shape);
  99. // brief Constructor for Python.
  100. //
  101. // param input [py::array] Data value of the tensor.
  102. // param data_type [TypeId] Data type of the tensor.
  103. explicit Tensor(const py::array &input, const TypePtr &data_type = nullptr);
  104. // brief Constructor
  105. //
  106. // param input [py::list] the data for tensor
  107. // param data_type [TypeId] data type
  108. explicit Tensor(const py::list &input, const TypePtr &data_type = nullptr);
  109. // brief Constructor
  110. //
  111. // param input [py::tuple] the data for tensor
  112. // param data_type [TypeId] data type
  113. explicit Tensor(const py::tuple &input, const TypePtr &data_type = nullptr);
  114. // brief Constructor
  115. //
  116. // param input [py::float_] the data for tensor
  117. // param data_type [TypeId] data type
  118. explicit Tensor(const py::float_ &input, const TypePtr &data_type = nullptr);
  119. // brief Constructor
  120. //
  121. // param input [py::int_] the data for tensor
  122. // param data_type [TypeId] data type
  123. explicit Tensor(const py::int_ &input, const TypePtr &data_type = nullptr);
  124. // brief Constructor
  125. //
  126. // param input [Tensor] the data for tensor
  127. // param data_type [TypeId] data type
  128. Tensor(const Tensor &tensor, const TypePtr &data_type = nullptr);
  129. ~Tensor() override = default;
  130. MS_DECLARE_PARENT(Tensor, MetaTensor);
  131. // brief Overloads operator = for Tensor.
  132. //
  133. // The constructed Tensor object has the same type and shape with tensor.
  134. //
  135. // param tensor An existing Tensor object.
  136. Tensor &operator=(const Tensor &tensor);
  137. // brief Compares two Tensor objects.
  138. //
  139. // Compare two tensor objects to see if they have same data type, shape and
  140. // data value.
  141. //
  142. // param tensor The Tensor object to be compared.
  143. // return true: If having same type, shape and data, return true, or return false.
  144. bool operator==(const Tensor &tensor) const;
  145. // It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
  146. bool ValueEqual(const Tensor &other) const;
  147. bool operator==(const Value &other) const override {
  148. if (other.isa<Tensor>()) {
  149. auto other_ = static_cast<const Tensor &>(other);
  150. return *this == other_;
  151. } else {
  152. return false;
  153. }
  154. }
  155. py::tuple GetPyTupleShape() const;
  156. // brief Gets tensor's dimension
  157. //
  158. // return The number of dimensions of the tensor data.
  159. int DataDim() const;
  160. // brief Getting tensor data size
  161. //
  162. // return The total number of elements of the tensor data.
  163. int DataSize() const;
  164. // brief Tensor's data value.
  165. //
  166. // return [py::array] The tensor's data in py::array.
  167. py::array data() const;
  168. // brief Get the data type fo the tensor for C++
  169. //
  170. // return [int] The tensor's data type will be cast to int to return.
  171. int data_type_c() const;
  172. // brief Get the tensor's shape for C++
  173. //
  174. // return [std::vector<int>]
  175. std::vector<int> shape_c(void) const;
  176. // brief Get Tensor data pointer for c++ type
  177. //
  178. // param writable true if writable, false if read only
  179. // return The pointer to the object
  180. void *data_c(bool writable = false);
  181. // brief Get data type from tensor data.
  182. //
  183. // param buf The buffer info of the py::array data.
  184. // return The [TypeId] of the tensor data.
  185. TypeId GetDataType(const py::buffer_info &buf) const;
  186. // brief Sets the data type of a tensor.
  187. //
  188. // param data_type The data type of the tensor to be set.
  189. //
  190. TypeId set_data_type(const TypeId data_type) override;
  191. TypePtr SetDtype(const TypePtr type_ptr) override;
  192. std::string GetShapeAndDataTypeInfo() const;
  193. std::string ToString() const override;
  194. std::string ToStringRepr() const;
  195. py::array data_; // < Tensor's data value
  196. const bool parse_info_ = true;
  197. bool is_init();
  198. void set_init_flag(bool flag);
  199. private:
  200. // brief init tensor
  201. //
  202. // param input [py::array] the data for tensor
  203. // param data_type [TypeId] data type
  204. // return true if succeed, false if failed.
  205. void init(const py::array &input, const TypeId &data_type);
  206. void init(const py::array &input, const TypePtr &type_ptr);
  207. bool init_flag_{false};
  208. // brief init tensor attribute
  209. //
  210. // param data_type [TypeId] Data type of the tensor.
  211. // param shape [py::array] The shape of the tensor.
  212. // return true if succeed, false if failed.
  213. void init(TypeId data_type, const std::vector<int> &shape, py::array *data);
  214. bool convert_data(const py::array &in, const TypeId in_data_type, py::array *out, const TypeId out_data_type);
  215. public:
  216. bool is_dirty() const { return dirty_; }
  217. void set_dirty(const bool dirty) { dirty_ = dirty; }
  218. DeviceAddressPtr device_address() const { return device_address_; }
  219. void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; }
  220. py::array data_sync();
  221. std::string id() const { return id_; }
  222. private:
  223. bool dirty_{true};
  224. std::string id_{""};
  225. DeviceAddressPtr device_address_{nullptr};
  226. };
  227. using TensorPtr = std::shared_ptr<Tensor>;
  228. using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
  229. } // namespace tensor
  230. } // namespace mindspore
  231. #endif // MINDSPORE_CCSRC_IR_TENSOR_H_