You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

util.h 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef TRANSFORM_UTIL_H_
  17. #define TRANSFORM_UTIL_H_
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>
#include "securec/include/securec.h"
#include "ir/anf.h"
#include "ir/dtype.h"
#include "ir/meta_tensor.h"
#include "transform/types.h"
#include "graph/tensor.h"
  28. namespace mindspore {
  29. namespace transform {
  30. class TransformUtil {
  31. public:
  32. /*
  33. * Parameters:
  34. * type: [MeDataType] the data type for ME tensor
  35. * Return:
  36. * [GeDataType] the data type for ge tensor
  37. * */
  38. static std::vector<int64_t> ConvertIntToList(int64_t data, int size);
  39. /*
  40. * Parameters:
  41. * type: [MeDataType] the data type for ME tensor
  42. * Return:
  43. * [GeDataType] the data type for ge tensor
  44. * */
  45. static GeDataType ConvertDataType(const MeDataType &type);
  46. /*
  47. * Parameters:
  48. * type: [string] the data format in ME op
  49. * Return:
  50. * [GeFormat] the data format for ge tensor
  51. * */
  52. static GeFormat ConvertFormat(const std::string &format);
  53. /*
  54. * Parameters:
  55. * type: [MeDataType] the data type for ME tensor
  56. * Return:
  57. * [size_t] the buff size for the type in ME
  58. * */
  59. static size_t GetDataTypeSize(const MeDataType &type);
  60. /*
  61. * Parameters:
  62. * tensor: [MeTensorPtr] the me tensor to get description from
  63. * format: [string] the data format in ME
  64. * is_input: [bool] whether the tensor is used as input, default:false
  65. * Return:
  66. * [shared_ptr<GeTensorDesc>] the shared pointer of ge tensor description
  67. * */
  68. static std::shared_ptr<GeTensorDesc> GetGeTensorDesc(const std::vector<int> &shape, const MeDataType &me_type,
  69. const std::string &format);
  70. /*
  71. * Parameters:
  72. * tensor: [MeTensor] the data tensor in ME
  73. * format: [string] the data format in ME op
  74. * is_input: [bool] whether the tensor is used as input, default:false
  75. * Return:
  76. * [GeTensor] the data tensor in GE
  77. * */
  78. static GeTensorPtr ConvertTensor(const MeTensorPtr &tensor, const std::string &format);
  79. /*
  80. * Parameters:
  81. * me_tensors: [vector<MeTensorPtr>] the data tensors in ME
  82. * format: [string] the data format in ME op
  83. * Return:
  84. * [std::vector<GeTensorPtr>] the data tensors in GE
  85. * */
  86. static std::vector<GeTensorPtr> ConvertInputTensors(const std::vector<MeTensorPtr> &me_tensors,
  87. const std::string &format);
  88. /*
  89. * Parameters:
  90. * tensor: [GeTensor] the data tensor in GE
  91. * Return:
  92. * [MeTensor] the data tensor in ME
  93. * */
  94. static MeTensorPtr ConvertGeTensor(const GeTensorPtr &tensor);
  95. /*
  96. * Parameters:
  97. * tensor: [GeTensor] the data tensor in GE
  98. * request_dims [std::vector<int>] the output Me tensors must adjust to this shapes
  99. * Return:
  100. * [MeTensor] the data tensor in ME
  101. * */
  102. static MeTensorPtr ConvertGeTensor(GeTensorPtr ge_tensor, const std::vector<int> &request_dims);
  103. /*
  104. * Parameters:
  105. * ge_tensors: [std::vector<GeTensorPtr>] the data tensor in GE
  106. * request_dims [std::vector<std::vector<int>>] the output Me tensors must adjust to this shapes
  107. * Return:
  108. * [std::vector<MeTensorPtr>] the data tensor in ME
  109. * */
  110. static std::vector<MeTensorPtr> ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors,
  111. const std::vector<std::vector<int>> &request_dims);
  112. /*
  113. * Parameters:
  114. * ge_tensors: [std::vector<GeTensorPtr>] the data tensor in GE
  115. * Return:
  116. * [std::vector<MeTensorPtr>] the data tensor in ME
  117. * */
  118. static std::vector<MeTensorPtr> ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors);
  119. /*
  120. * Parameters:
  121. * ge_tensor: [GeTensor] the data tensor in GE
  122. * me_dims: [std::vector<int>] the shape of created Me tensor
  123. * me_type: [TypeId] the type of created Me tensor
  124. * Return:
  125. * [MeTensor] the data tensor in ME
  126. * */
  127. static MeTensorPtr GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector<int> &me_dims,
  128. const TypeId &me_type);
  129. /*
  130. * Parameters:
  131. * type: [GeDataType] the ge tensor data type
  132. * Return:
  133. * [MeDataType] the me tensor data type
  134. * */
  135. static MeDataType ConvertGeDataType(const GeDataType &type);
  136. /*
  137. * Parameters:
  138. * me_dims: [std::vector<int>] the me shape
  139. * Return:
  140. * [GeShape] the ge shape
  141. * */
  142. static GeShape ConvertMeShape(const std::vector<int> &me_dims);
  143. /*
  144. * Parameters:
  145. * ge_shape: [GeShape] the ge shape
  146. * Return:
  147. * [vector<int>] the me shape
  148. * */
  149. static std::vector<int> ConvertGeShape(const GeShape &ge_shape);
  150. /* Function:
  151. * Convert GeShape to Me request shape, Support pattern:
  152. * {1, x, 1, 1} --> {x}
  153. * {x, 1, 1, 1} --> {x}
  154. * {x, x, 1, 1} --> {x, x}
  155. * {x, x, x, 1} --> {x, x, x}
  156. * {x, x, x, x} --> {x, x, x, x}
  157. * If unmatch upon patterns, return original ge dims
  158. * Parameters:
  159. * ge_shape: [GeShape] the ge shape
  160. * request_dims: [vector<int>] request dims
  161. * Return:
  162. * [vector<int>] the me shape
  163. * */
  164. static std::vector<int> ConvertGeShape(const GeShape &ge_shape, const std::vector<int> &request_dims);
  165. /*
  166. * Parameters:
  167. * vec: [std::vector<int>] the vector to print
  168. * Return:
  169. * [string] value string
  170. * */
  171. template <typename T, typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
  172. static std::string PrintVector(const std::vector<T> &vec) {
  173. const int MAX_PRINT_NUM = 100;
  174. std::stringstream ss;
  175. ss << "{ ";
  176. int i = 0;
  177. for (auto it = vec.begin(); it != vec.end(); ++it) {
  178. ss << std::to_string(*it) << ", ";
  179. i++;
  180. if (i >= MAX_PRINT_NUM) {
  181. break;
  182. }
  183. }
  184. if (i >= MAX_PRINT_NUM) {
  185. ss << "... to be continue}";
  186. } else {
  187. ss << "}";
  188. }
  189. return ss.str();
  190. }
  191. /*
  192. * Parameters:
  193. * ge_tensor: [GeTensorPtr] the ge tensor
  194. * Return:
  195. * [stringstream] value string
  196. * */
  197. static std::string PrintGeTensor(const GeTensorPtr ge_tensor);
  198. /*
  199. * Parameters:
  200. * data: [uint8_t *] the ge tensor data pointer
  201. * size: [size_t] the ge tensor data bytes
  202. * Return:
  203. * [shared_ptr<std::vector<T>] vector pointer
  204. * */
  205. template <typename T, typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
  206. static std::vector<T> MakeVector(const uint8_t *const data, size_t size) {
  207. auto dest = std::vector<T>(size / sizeof(T));
  208. if (data == nullptr) {
  209. return dest;
  210. }
  211. errno_t ret = memcpy_s(dest.data(), dest.size() * sizeof(T), data, size);
  212. if (EOK != ret) {
  213. return std::vector<T>();
  214. }
  215. return dest;
  216. }
  217. };
  218. } // namespace transform
  219. } // namespace mindspore
  220. #endif // TRANSFORM_UTIL_H_