You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

trans.h 8.9 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_COMMON_TRANS_H
  17. #define MINDSPORE_CCSRC_COMMON_TRANS_H
  18. #include <algorithm>
  19. #include <functional>
  20. #include <map>
  21. #include <memory>
  22. #include <string>
  23. #include <utility>
  24. #include <vector>
  25. #include "ir/dtype.h"
  26. #include "backend/kernel_compiler/kernel.h"
  27. #include "ir/dtype/type.h"
  28. #include "utils/shape_utils.h"
  29. #include "backend/session/anf_runtime_algorithm.h"
  30. namespace mindspore {
  31. namespace trans {
// Alignment unit (in elements) used by the fractal/RNN device formats below.
constexpr int64_t kAlign16 = 16;
// Axis indices of a 4-D NCHW host shape; kNchwDims doubles as the rank (4).
enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims };
// Axis indices for 5-D/6-D layouts. The two naming groups deliberately share
// values: N_ncdhw..W_ncdhw index a 5-D NCDHW host shape (kNcdhw is the
// rank, 5), while N_ndc1hwc0..C0_ndc1hwc0 index the 6-D NDC1HWC0 device shape.
enum Axis5D : int {
  N_ncdhw = 0,
  C_ncdhw,
  D_ncdhw,
  H_ncdhw,
  W_ncdhw,
  kNcdhw,
  N_ndc1hwc0 = 0,
  D_ndc1hwc0,
  C1_ndc1hwc0,
  H_ndc1hwc0,
  W_ndc1hwc0,
  C0_ndc1hwc0
};
// Arguments describing a data-type cast between host and device buffers
// (consumed by TransDataType).
struct TypeIdArgs {
  const void *data;        // source buffer; not owned by this struct
  size_t host_shape_size;  // Multiply each dimension elements. [a, b, c, d] => a*b*c*d
  TypeId host_data_type;
  TypeId device_data_type;
  size_t data_size;  // NOTE(review): presumably bytes of `data` — confirm against callers
};
// Arguments describing one tensor-layout (format) conversion
// (consumed by TransFormat / TransFormatFromDeviceToHost).
struct FormatArgs {
  const void *data;          // source buffer; not owned by this struct
  const size_t device_size;  // NOTE(review): presumably device buffer size in bytes — confirm
  std::string host_format;
  std::string device_format;
  std::vector<size_t> host_shape;
  std::vector<size_t> device_shape;
  TypeId src_data_type;
};
// Fetches the group count attribute from `node`; used by the templated
// TransShapeToDevice below when the target format is FRAC_Z.
int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index);
// Fetches the {input_size, hidden_size} pair from `node`; used for the
// FRACTAL_ZN_RNN / ND_RNN_BIAS formats.
std::vector<int64_t> GetAttrInputAndHiddenSize(const AnfNodePtr &node);
// Parse a reshape-type string into an explicit 4-D / 5-D axis list
// (consumed by PaddingShapeTo4d / PaddingShapeTo5d).
void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec);
void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec);
ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
// True when a shape of `shape_size` dims must be padded before converting to `format`.
bool IsNeedPadding(const std::string &format, const size_t shape_size);
int64_t GetNodeGroups(const AnfNodePtr &node);
// Translate a host shape into the device shape for `format`. `groups` is only
// meaningful for FRAC_Z; `input_hidden_size` only for the RNN formats.
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
                                       const int64_t groups = 1,
                                       const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16});
std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const std::string &format,
                                        const int64_t groups = 1,
                                        const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16});
  77. template <typename T>
  78. std::vector<T> TransShapeToDevice(const std::vector<T> &shape, const std::string &format, const AnfNodePtr &node,
  79. const size_t index, bool is_output = true) {
  80. int64_t groups = 1;
  81. if (format == kOpFormat_FRAC_Z) {
  82. groups = GetAttrGroups(node, index);
  83. }
  84. std::vector<int64_t> input_hidden_size = {kAlign16, kAlign16};
  85. if (format == kOpFormat_FRACTAL_ZN_RNN || format == kOpFormat_ND_RNN_BIAS) {
  86. input_hidden_size = GetAttrInputAndHiddenSize(node);
  87. }
  88. if (node != nullptr) {
  89. MS_LOG(DEBUG) << "Start trans infer shape to device shape for node: " << node->DebugString()
  90. << ", format: " << format;
  91. }
  92. return TransShapeToDevice(shape, format, groups, input_hidden_size);
  93. }
// Convert the buffer described by `args` from its host data type to the
// device data type, writing into `result`; returns false on failure.
bool TransDataType(const TypeIdArgs &args, void *result);
// Host -> device layout conversion. `groups` (or `node` + `index`, from which
// the group count is fetched) only matters for group-aware formats like FRAC_Z.
bool TransFormat(const FormatArgs &args, void *result, int64_t groups = 1);
bool TransFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index);
// Device -> host counterparts of TransFormat.
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, int64_t groups = 1);
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index);
// host to device
// Each routine repacks `args.data` (host NCHW/NCDHW layout) into the device
// layout named in the function, writing into `result`; returns false on failure.
bool NchwTo4D(const FormatArgs &args, void *result);
bool NchwToFracZ(const FormatArgs &args, void *result);
bool NchwToFracNz(const FormatArgs &args, void *result);
bool NchwToNc1hwc0(const FormatArgs &args, void *result);
bool NcdhwToFracZ3D(const FormatArgs &args, void *result);
bool NchwToFracZc04(const FormatArgs &args, void *result);
bool NchwToNc1hwc04(const FormatArgs &args, void *result);
bool NchwToC1hwncoc0(const FormatArgs &args, void *result);
bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result);
// Group-aware variant of NchwToFracZ.
bool NchwToFracZWithGroups(const FormatArgs &args, void *result, int64_t groups);
// device to host
// Inverse conversions: repack `args.data` from the named device layout back
// into host NCHW/NCDHW layout in `result`; returns false on failure.
bool ToNchw(const FormatArgs &args, void *result);
bool FracZToNchw(const FormatArgs &args, void *result);
bool FracNzToNchw(const FormatArgs &args, void *result);
bool Nc1hwc0ToNchw(const FormatArgs &args, void *result);
bool Nc1hwc04ToNchw(const FormatArgs &args, void *result);
bool FracZ3DToNcdhw(const FormatArgs &args, void *result);
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result);
bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result);
// Group-aware variant of FracZToNchw.
bool FracZToNchwWithGroups(const FormatArgs &args, void *result, int64_t groups);
// Signature shared by all host<->device transfer routines above.
using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
// Dispatch table: device format name -> host-to-device transfer routine.
// NOTE(review): a namespace-scope `const` map defined in a header gets one
// copy per translation unit; consider `inline` (C++17) if that matters here.
const std::map<std::string, FormatTransfer> kTransFormatMapOfHostToDevice{
  {kOpFormat_FRAC_Z, NchwToFracZ},           {kOpFormat_FRAC_NZ, NchwToFracNz},
  {kOpFormat_NC1HWC0, NchwToNc1hwc0},        {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
  {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04},
  {kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0},     {kOpFormat_FRACTAL_Z_3D, NcdhwToFracZ3D}};
  126. template <typename T>
  127. std::vector<T> PaddingShapeTo5dDefault(const std::vector<T> &shape) {
  128. if (shape.size() >= kNcdhw) {
  129. return shape;
  130. }
  131. std::vector<T> shape_5d(kNcdhw, 1);
  132. switch (shape.size()) {
  133. case N_ncdhw:
  134. return shape_5d;
  135. case C_ncdhw:
  136. shape_5d[C_ncdhw] = shape[N_ncdhw];
  137. break;
  138. case D_ncdhw:
  139. shape_5d[C_ncdhw] = shape[N_ncdhw];
  140. shape_5d[D_ncdhw] = shape[C_ncdhw];
  141. break;
  142. case H_ncdhw:
  143. shape_5d[C_ncdhw] = shape[N_ncdhw];
  144. shape_5d[D_ncdhw] = shape[C_ncdhw];
  145. shape_5d[H_ncdhw] = shape[D_ncdhw];
  146. break;
  147. case W_ncdhw:
  148. shape_5d[C_ncdhw] = shape[N_ncdhw];
  149. shape_5d[D_ncdhw] = shape[C_ncdhw];
  150. shape_5d[H_ncdhw] = shape[D_ncdhw];
  151. shape_5d[W_ncdhw] = shape[H_ncdhw];
  152. break;
  153. default:
  154. MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
  155. }
  156. return shape_5d;
  157. }
  158. template <typename T>
  159. std::vector<T> PaddingShapeTo4dDefault(const std::vector<T> &shape) {
  160. std::vector<T> shape_4d(kNchwDims, 1);
  161. switch (shape.size()) {
  162. case kN:
  163. return shape_4d;
  164. case kC:
  165. shape_4d[kC] = shape[kN];
  166. break;
  167. case kH:
  168. shape_4d[kC] = shape[kN];
  169. shape_4d[kH] = shape[kC];
  170. break;
  171. case kW:
  172. shape_4d[kC] = shape[kN];
  173. shape_4d[kH] = shape[kC];
  174. shape_4d[kW] = shape[kH];
  175. break;
  176. case kNchwDims:
  177. std::copy(shape.begin(), shape.end(), shape_4d.begin());
  178. break;
  179. default:
  180. MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
  181. }
  182. return shape_4d;
  183. }
  184. template <typename T>
  185. std::vector<T> PaddingShapeTo5d(const std::vector<T> &shape, const std::string &padding_str = {""}) {
  186. std::vector<Axis5D> padding_axis;
  187. StringToAxisVector5D(padding_str, &padding_axis);
  188. if (padding_axis.empty() || shape.size() != padding_axis.size()) {
  189. return PaddingShapeTo5dDefault(shape);
  190. }
  191. std::vector<T> shape_5d(kNcdhw, 1);
  192. for (size_t index = 0; index < padding_axis.size(); index++) {
  193. shape_5d[padding_axis[index]] = shape[index];
  194. }
  195. return shape_5d;
  196. }
  197. template <typename T>
  198. std::vector<T> PaddingShapeTo4d(const std::vector<T> &shape, const std::string &padding_str = {""}) {
  199. std::vector<Axis> padding_axis;
  200. StringToAxisVector4D(padding_str, &padding_axis);
  201. if (padding_axis.empty() || shape.size() != padding_axis.size()) {
  202. return PaddingShapeTo4dDefault(shape);
  203. }
  204. std::vector<T> shape_4d(kNchwDims, 1);
  205. for (size_t index = 0; index < padding_axis.size(); index++) {
  206. shape_4d[padding_axis[index]] = shape[index];
  207. }
  208. return shape_4d;
  209. }
  210. template <typename T>
  211. std::vector<T> PaddingShape(const std::vector<T> &shape, const std::string &format,
  212. const std::string &pad_index = {""}) {
  213. std::vector<T> host_shape;
  214. if (k3DFormatSet.find(format) != k3DFormatSet.end()) {
  215. if (shape.size() >= kNcdhw) {
  216. return shape;
  217. }
  218. host_shape = trans::PaddingShapeTo5d(shape, pad_index);
  219. } else {
  220. host_shape = trans::PaddingShapeTo4d(shape, pad_index);
  221. }
  222. return host_shape;
  223. }
  224. } // namespace trans
  225. } // namespace mindspore
  226. #endif // MINDSPORE_CCSRC_COMMON_TRANS_H