You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

util.h 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_UTIL_H_
  17. #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_UTIL_H_
  #include <algorithm>
  #include <cstdint>
  #include <memory>
  #include <sstream>
  #include <string>
  #include <type_traits>
  #include <vector>
  #include "securec/include/securec.h"
  #include "ir/anf.h"
  #include "ir/dtype.h"
  #include "ir/tensor.h"
  #include "transform/graph_ir/types.h"
  #include "graph/tensor.h"
  #include "utils/shape_utils.h"
  29. namespace mindspore {
  30. namespace transform {
  31. class TransformUtil {
  32. public:
  33. /*
  34. * Parameters:
  35. * type: [MeDataType] the data type for ME tensor
  36. * Return:
  37. * [GeDataType] the data type for ge tensor
  38. * */
  39. static std::vector<int64_t> ConvertIntToList(int64_t data, int size);
  40. /*
  41. * Parameters:
  42. * type: [MeDataType] the data type for ME tensor
  43. * Return:
  44. * [GeDataType] the data type for ge tensor
  45. * */
  46. static GeDataType ConvertDataType(const MeDataType &type);
  47. /*
  48. * Parameters:
  49. * type: [string] the data format in ME op
  50. * Return:
  51. * [GeFormat] the data format for ge tensor
  52. * */
  53. static GeFormat ConvertFormat(const std::string &format);
  54. /*
  55. * Parameters:
  56. * type: [MeDataType] the data type for ME tensor
  57. * Return:
  58. * [size_t] the buff size for the type in ME
  59. * */
  60. static size_t GetDataTypeSize(const MeDataType &type);
  61. /*
  62. * Parameters:
  63. * tensor: [MeTensorPtr] the me tensor to get description from
  64. * format: [string] the data format in ME
  65. * is_input: [bool] whether the tensor is used as input, default:false
  66. * Return:
  67. * [shared_ptr<GeTensorDesc>] the shared pointer of ge tensor description
  68. * */
  69. static std::shared_ptr<GeTensorDesc> GetGeTensorDesc(const ShapeVector &shape, const MeDataType &me_type,
  70. const std::string &format);
  71. /*
  72. * Parameters:
  73. * tensor: [MeTensor] the data tensor in ME
  74. * format: [string] the data format in ME op
  75. * is_input: [bool] whether the tensor is used as input, default:false
  76. * Return:
  77. * [GeTensor] the data tensor in GE
  78. * */
  79. static GeTensorPtr ConvertTensor(const MeTensorPtr &tensor, const std::string &format);
  80. /*
  81. * Parameters:
  82. * me_tensors: [vector<MeTensorPtr>] the data tensors in ME
  83. * format: [string] the data format in ME op
  84. * Return:
  85. * [std::vector<GeTensorPtr>] the data tensors in GE
  86. * */
  87. static std::vector<GeTensorPtr> ConvertInputTensors(const std::vector<MeTensorPtr> &me_tensors,
  88. const std::string &format);
  89. /*
  90. * Parameters:
  91. * tensor: [GeTensor] the data tensor in GE
  92. * Return:
  93. * [MeTensor] the data tensor in ME
  94. * */
  95. static MeTensorPtr ConvertGeTensor(const GeTensorPtr &tensor);
  96. /*
  97. * Parameters:
  98. * tensor: [GeTensor] the data tensor in GE
  99. * request_dims [ShapeVector] the output Me tensors must adjust to this shapes
  100. * Return:
  101. * [MeTensor] the data tensor in ME
  102. * */
  103. static MeTensorPtr ConvertGeTensor(GeTensorPtr ge_tensor, const ShapeVector &request_dims);
  104. /*
  105. * Parameters:
  106. * ge_tensors: [std::vector<GeTensorPtr>] the data tensor in GE
  107. * request_dims [std::vector<ShapeVector>] the output Me tensors must adjust to this shapes
  108. * Return:
  109. * [std::vector<MeTensorPtr>] the data tensor in ME
  110. * */
  111. static std::vector<MeTensorPtr> ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors,
  112. const std::vector<ShapeVector> &request_dims);
  113. /*
  114. * Parameters:
  115. * ge_tensors: [std::vector<GeTensorPtr>] the data tensor in GE
  116. * Return:
  117. * [std::vector<MeTensorPtr>] the data tensor in ME
  118. * */
  119. static std::vector<MeTensorPtr> ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors);
  120. /*
  121. * Parameters:
  122. * ge_tensor: [GeTensor] the data tensor in GE
  123. * me_dims: [ShapeVector] the shape of created Me tensor
  124. * me_type: [TypeId] the type of created Me tensor
  125. * Return:
  126. * [MeTensor] the data tensor in ME
  127. * */
  128. static MeTensorPtr GenerateMeTensor(const GeTensorPtr &ge_tensor, const ShapeVector &me_dims, const TypeId &me_type);
  129. /*
  130. * Parameters:
  131. * type: [GeDataType] the ge tensor data type
  132. * Return:
  133. * [MeDataType] the me tensor data type
  134. * */
  135. static MeDataType ConvertGeDataType(const GeDataType &type);
  136. /*
  137. * Parameters:
  138. * me_dims: [ShapeVector] the me shape
  139. * Return:
  140. * [GeShape] the ge shape
  141. * */
  142. static GeShape ConvertMeShape(const ShapeVector &me_dims);
  143. /*
  144. * Parameters:
  145. * ge_shape: [GeShape] the ge shape
  146. * Return:
  147. * [vector<int>] the me shape
  148. * */
  149. static ShapeVector ConvertGeShape(const GeShape &ge_shape);
  150. /* Function:
  151. * Convert GeShape to Me request shape, Support pattern:
  152. * {1, x, 1, 1} --> {x}
  153. * {x, 1, 1, 1} --> {x}
  154. * {x, x, 1, 1} --> {x, x}
  155. * {x, x, x, 1} --> {x, x, x}
  156. * {x, x, x, x} --> {x, x, x, x}
  157. * If unmatch upon patterns, return original ge dims
  158. * Parameters:
  159. * ge_shape: [GeShape] the ge shape
  160. * request_dims: [vector<int>] request dims
  161. * Return:
  162. * [vector<int>] the me shape
  163. * */
  164. static ShapeVector ConvertGeShape(const GeShape &ge_shape, const ShapeVector &request_dims);
  165. /*
  166. * Parameters:
  167. * vec: [ShapeVector] the vector to print
  168. * Return:
  169. * [string] value string
  170. * */
  171. template <typename T, typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
  172. static std::string PrintVector(const std::vector<T> &vec) {
  173. const int MAX_PRINT_NUM = 100;
  174. std::stringstream ss;
  175. ss << "{ ";
  176. int i = 0;
  177. for (auto it = vec.begin(); it != vec.end(); ++it) {
  178. ss << std::to_string(*it) << ", ";
  179. i++;
  180. if (i >= MAX_PRINT_NUM) {
  181. break;
  182. }
  183. }
  184. if (i >= MAX_PRINT_NUM) {
  185. ss << "... to be continue}";
  186. } else {
  187. ss << "}";
  188. }
  189. return ss.str();
  190. }
  191. /*
  192. * Parameters:
  193. * ge_tensor: [GeTensorPtr] the ge tensor
  194. * Return:
  195. * [stringstream] value string
  196. * */
  197. static std::string PrintGeTensor(const GeTensorPtr ge_tensor);
  198. /*
  199. * Parameters:
  200. * data: [uint8_t *] the ge tensor data pointer
  201. * size: [size_t] the ge tensor data bytes
  202. * Return:
  203. * [shared_ptr<std::vector<T>] vector pointer
  204. * */
  205. template <typename T, typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
  206. static std::vector<T> MakeVector(const uint8_t *const data, size_t size) {
  207. auto dest = std::vector<T>(size / sizeof(T));
  208. if (data == nullptr) {
  209. return dest;
  210. }
  211. errno_t ret = memcpy_s(dest.data(), dest.size() * sizeof(T), data, size);
  212. if (EOK != ret) {
  213. return std::vector<T>();
  214. }
  215. return dest;
  216. }
  217. };
  218. } // namespace transform
  219. } // namespace mindspore
  220. #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_UTIL_H_