You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

meta_tensor.h 6.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_IR_META_TENSOR_H_
  17. #define MINDSPORE_CCSRC_IR_META_TENSOR_H_
  18. #include <utility>
  19. #include <vector>
  20. #include <memory>
  21. #include <string>
  22. #include "ir/base.h"
  23. #include "ir/dtype.h"
  24. #include "utils/convert_utils.h"
  25. #include "utils/hashing.h"
  26. // brief mindspore namespace.
  27. //
  28. // mindspore namespace is the top-level namespace of the MindSpore project.
  29. // Other namespaces should be sub-namespaces of the mindspore namespace in the ME project.
  30. namespace mindspore {
  31. // brief mindspore::tensor namespace
  32. //
  33. // A sub namespace in ME to support tensor related definition.
  34. namespace tensor {
  35. // brief Device info of Tensor
  36. //
  37. // Includes the format and data type of a tensor.
  38. struct DeviceInfo {
  39. explicit DeviceInfo(std::string format = "DefaultFormat", TypePtr data_type = nullptr)
  40. : format_(std::move(format)), data_type_(std::move(data_type)) {}
  41. std::string format_ = "DefaultFormat";
  42. TypePtr data_type_ = nullptr;
  43. };
  44. // brief Metadata of Tensor
  45. //
  46. // Includes the metadata information of a tensor, such as data type, shape
  47. // and so on. But it does not contain values of a tensor.
  48. class MetaTensor : public Value {
  49. public:
  50. // Construction
  51. MetaTensor();
  52. // brief Constructs a meta tensor of a tensor having data_type data and shape.
  53. //
  54. // The constructed MetaTensor is not a Tensor, but it has the data type and shape
  55. // information of a Tensor. The following codes will create a 2x3 float
  56. // param data_type The data type of the tensor.
  57. // param shape The shape of the tensor.
  58. MetaTensor(const TypeId data_type, const std::vector<int> &shape);
  59. MetaTensor(const TypePtr &type_ptr, const std::vector<int> &shape);
  60. // brief Constructs a MetaTensor object from an existing MetaTensor instance.
  61. //
  62. // The constructed MetaTensor object will have the same data type and shape as the
  63. // meta_tensor.
  64. //
  65. // param meta_tensor An existing MetaTensor object.
  66. MetaTensor(const MetaTensor &meta_tensor);
  67. ~MetaTensor() override = default;
  68. MS_DECLARE_PARENT(MetaTensor, Value)
  69. // brief Overloads operator = for MetaTensor.
  70. //
  71. // The constructed MetaTensor object has the same type and shape with meta_tensor.
  72. //
  73. // param meta_tensor An existing MetaTensor object.
  74. virtual MetaTensor &operator=(const MetaTensor &meta_tensor);
  75. // brief Compares two MetaTensor objects.
  76. //
  77. // The constructed MetaTensor object has the same type and shape with meta_tensor.
  78. //
  79. // param meta_tensor The MetaTensor object to be compared.
  80. // return true: If having same type and shape, return true, or return false.
  81. virtual bool operator==(const MetaTensor &meta_tensor) const;
  82. // brief Returns the data type of the tensor in its MetaTensor.
  83. //
  84. // All the types are defined in "ir/dtype.h".
  85. TypePtr Dtype() const;
  86. abstract::AbstractBasePtr ToAbstract() override;
  87. TypeId data_type() const { return data_type_; }
  88. std::string ToString() const override;
  89. std::string DumpText() const override;
  90. // brief Sets the data type of a tensor in its MetaTensor.
  91. //
  92. // param data_type The data type of the tensor to be set.
  93. virtual TypeId set_data_type(const TypeId data_type) {
  94. data_type_ = data_type;
  95. return data_type_;
  96. }
  97. virtual TypePtr SetDtype(const TypePtr type_ptr);
  98. // brief Get tensor's shape.
  99. //
  100. // The shape of a tensor is stored in a vector<int>. Each
  101. // element of the vector represents the size of a dimension of the tensor.
  102. // The order of each element in the vector is as same as the the dimension's
  103. // order it represents.
  104. //
  105. // return A const vector<int> which represents the shape of the tensor.
  106. std::vector<int> shape() const { return shape_; }
  107. // brief Sets the shape of a tensor.
  108. //
  109. // The shape of a tensor is stored in a vector<int>. Each
  110. // element of the vector represents the size of a dimension of the tensor.
  111. // The order of each element in the vector is as same as the the dimension's
  112. // order it represents.
  113. //
  114. // param shape The shape of the tensor.
  115. // return The shape's size.
  116. size_t set_shape(const std::vector<int> &shape) {
  117. this->shape_ = shape;
  118. return shape_.size();
  119. }
  120. // Get tensor's device info.
  121. DeviceInfo device_info() const { return device_info_; }
  122. // Set tensor's device info.
  123. void set_device_info(const DeviceInfo &device_info) { device_info_ = device_info; }
  124. void SetDeviceInfo(const std::string &format, const TypePtr &data_type);
  125. // Get the size of a given dimension by its index number.
  126. int DimensionSize(size_t index) const;
  127. // Get total number of elements in a tensor.
  128. int ElementsNum() const;
  129. std::size_t hash() const override {
  130. std::size_t hash_value = std::hash<int>{}(SizeToInt(data_type_));
  131. hash_value = hash_combine(hash_value, std::hash<size_t>{}(shape_.size()));
  132. // hash all elements may costly, so only take at most 4 elements into account based on
  133. // some experiments.
  134. for (size_t i = 0; (i < shape_.size()) && (i < 4); ++i) {
  135. hash_value = hash_combine(hash_value, (std::hash<int>{}(shape_[i])));
  136. }
  137. return hash_value;
  138. }
  139. bool operator==(const Value &other) const override {
  140. if (other.isa<MetaTensor>()) {
  141. auto other_ = static_cast<const MetaTensor &>(other);
  142. return *this == other_;
  143. } else {
  144. return false;
  145. }
  146. }
  147. const bool parse_info_ = true;
  148. protected:
  149. // brief Data type of the tensor.
  150. //
  151. // All support data type is in Number Types of [TypeId],
  152. // including [kNumberTypeBool], [kNumberTypeInt],
  153. // [kNumberTypeUInt32], [kNumberTypeFloat32] and [kNumberTypeFloat64].
  154. TypeId data_type_;
  155. // brief Shape of the tensor.
  156. //
  157. // A std::vector<int> container is used to store the shape of a tensor.
  158. // Each element of the vector represents the size of a dimension of the tensor.
  159. // The order of each element in the vector is as same as the the dimension's
  160. // order it represents. If the dimension size is not set, its value will be -1.
  161. std::vector<int> shape_;
  162. // brief Device info of Tensor
  163. //
  164. // Includes the format and data type of a tensor on device.
  165. DeviceInfo device_info_;
  166. };
  167. using MetaTensorPtr = std::shared_ptr<MetaTensor>;
  168. } // namespace tensor
  169. } // namespace mindspore
  170. #endif // MINDSPORE_CCSRC_IR_META_TENSOR_H_