You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node_util.h 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_PREDICT_NODE_UTIL_H
  17. #define MINDSPORE_PREDICT_NODE_UTIL_H
  #include <limits>
#include <memory>
#include <new>
#include <vector>
#include "schema/inner/model_generated.h"
#include "src/common/common.h"
#include "utils/log_adapter.h"
#include "include/errorcode.h"
#include "securec/include/securec.h"
  25. namespace mindspore {
  26. namespace lite {
  27. using STATUS = int;
  28. STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<schema::CNodeT> &node);
  29. std::vector<schema::PrimitiveType> GetNhwcOpList();
  30. std::vector<schema::PrimitiveType> Getfp32FullOpList();
  31. std::vector<schema::PrimitiveType> GetUint8NhwcOpList();
  32. std::vector<schema::PrimitiveType> GetUint8OpList();
  33. class NodeUtils {
  34. public:
  35. static STATUS ConvertDims(schema::Format src_format, const std::vector<int32_t> &src_dims, schema::Format dst_format,
  36. std::vector<int32_t> *dst_dims);
  37. static void SliceData(std::vector<char *> &input, int64_t chunk_size, std::vector<char *> &output, int64_t begin,
  38. int64_t out_dim, int64_t stride);
  39. static STATUS SetOutputSliceData(void *data, int64_t data_size, int32_t data_type, std::vector<int32_t> &input_dims,
  40. std::vector<int32_t> &begin, std::vector<int32_t> &output_dims,
  41. schema::TensorT *output, std::vector<int32_t> &stride);
  42. };
  /**
 * @brief Filter-layout conversion selectors, named kSRC2DST.
 *
 * SRC and DST spell the axis order of a 4-D filter with the letters K, C, H and W, matching the
 * filterK / filterC / filterH / filterW extents consumed by TransFilterData. In the kNHWC2* entries N plays
 * the role of K. The values are written out explicitly (they equal the implicit 0..20 numbering) because
 * they are logged as integers and must stay stable.
 */
enum kTransFilterType {
  kKCHW2HWCK = 0,
  kKCHW2KHWC = 1,
  kCKHW2KHWC = 2,
  kCKHW2HWCK = 3,
  kKCHW2HWKC = 4,
  kCKHW2HWKC = 5,
  kHWCK2KCHW = 6,
  kHWCK2CKHW = 7,
  kHWKC2KCHW = 8,
  kHWKC2CKHW = 9,
  kNHWC2KCHW = 10,
  kNHWC2CKHW = 11,
  kNHWC2HWCK = 12,
  kKHWC2HWCK = 13,
  kCHWK2HWCK = 14,
  kKHWC2CHWK = 15,
  kCHWK2KHWC = 16,
  kKHWC2KCHW = 17,
  kCKHW2KCHW = 18,
  kCHWK2KCHW = 19,
  kKCHW2CKHW = 20
};
  66. STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t* filterK, int32_t* filterC,
  67. int32_t* filterH, int32_t* filterW);
  68. STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
  69. int32_t filterH, int32_t filterW);
  70. template <typename T>
  71. static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
  72. int32_t filterH, int32_t filterW) {
  73. MS_ASSERT(tensor != nullptr);
  74. int count = filterH * filterW * filterC * filterK;
  75. if (count <= 0) {
  76. MS_LOG(ERROR) << "Dim size invalid";
  77. return RET_ERROR;
  78. }
  79. std::unique_ptr<T> buf(new (std::nothrow) T[count]);
  80. if (buf == nullptr) {
  81. MS_LOG(ERROR) << "new buf failed";
  82. return RET_ERROR;
  83. }
  84. void *originWeightDate = tensor->data.data();
  85. T *weightData = static_cast<T *>(originWeightDate);
  86. if (weightData == nullptr) {
  87. MS_LOG(ERROR) << "weightData is nullptr";
  88. return RET_ERROR;
  89. }
  90. T *p1Buff = nullptr;
  91. T *p2Buff = nullptr;
  92. switch (type) {
  93. case kCHWK2HWCK:
  94. case kCHWK2KHWC: {
  95. for (int c = 0; c < filterC; ++c) {
  96. for (int h = 0; h < filterH; ++h) {
  97. for (int w = 0; w < filterW; ++w) {
  98. for (int k = 0; k < filterK; ++k) {
  99. p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
  100. if (type == kCHWK2HWCK) {
  101. p2Buff =
  102. buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  103. } else if (type == kCHWK2KHWC) {
  104. p2Buff =
  105. buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  106. }
  107. *p2Buff = *p1Buff;
  108. }
  109. }
  110. }
  111. }
  112. } break;
  113. case kKHWC2HWCK: {
  114. for (int k = 0; k < filterK; ++k) {
  115. for (int h = 0; h < filterH; ++h) {
  116. for (int w = 0; w < filterW; ++w) {
  117. for (int c = 0; c < filterC; ++c) {
  118. p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  119. p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  120. *p2Buff = *p1Buff;
  121. }
  122. }
  123. }
  124. }
  125. } break;
  126. case kKCHW2HWCK:
  127. case kKCHW2CKHW:
  128. case kKCHW2KHWC:
  129. case kKCHW2HWKC: {
  130. for (int k = 0; k < filterK; ++k) {
  131. for (int c = 0; c < filterC; ++c) {
  132. for (int h = 0; h < filterH; ++h) {
  133. for (int w = 0; w < filterW; ++w) {
  134. p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  135. if (type == kKCHW2HWCK) {
  136. p2Buff =
  137. buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  138. } else if (type == kKCHW2KHWC) {
  139. p2Buff =
  140. buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  141. } else if (type == kKCHW2CKHW) {
  142. p2Buff =
  143. buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  144. } else {
  145. p2Buff =
  146. buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
  147. }
  148. *p2Buff = *p1Buff;
  149. }
  150. }
  151. }
  152. }
  153. } break;
  154. case kCKHW2HWCK:
  155. case kCKHW2KHWC:
  156. case kCKHW2HWKC: {
  157. for (int c = 0; c < filterC; ++c) {
  158. for (int k = 0; k < filterK; ++k) {
  159. for (int h = 0; h < filterH; ++h) {
  160. for (int w = 0; w < filterW; ++w) {
  161. p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  162. if (type == kCKHW2HWCK) {
  163. p2Buff =
  164. buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  165. } else if (type == kCKHW2KHWC) {
  166. p2Buff =
  167. buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  168. } else {
  169. p2Buff =
  170. buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
  171. }
  172. *p2Buff = *p1Buff;
  173. }
  174. }
  175. }
  176. }
  177. } break;
  178. case kHWCK2KCHW:
  179. case kHWCK2CKHW: {
  180. for (int h = 0; h < filterH; ++h) {
  181. for (int w = 0; w < filterW; ++w) {
  182. for (int c = 0; c < filterC; ++c) {
  183. for (int k = 0; k < filterK; ++k) {
  184. p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  185. if (type == kHWCK2KCHW) {
  186. p2Buff =
  187. buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  188. } else {
  189. p2Buff =
  190. buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  191. }
  192. *p2Buff = *p1Buff;
  193. }
  194. }
  195. }
  196. }
  197. } break;
  198. case kHWKC2KCHW:
  199. case kHWKC2CKHW: {
  200. for (int h = 0; h < filterH; ++h) {
  201. for (int w = 0; w < filterW; ++w) {
  202. for (int c = 0; c < filterC; ++c) {
  203. for (int k = 0; k < filterK; ++k) {
  204. p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
  205. if (type == kHWKC2KCHW) {
  206. p2Buff =
  207. buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  208. } else {
  209. p2Buff =
  210. buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  211. }
  212. *p2Buff = *p1Buff;
  213. }
  214. }
  215. }
  216. }
  217. } break;
  218. case kNHWC2HWCK:
  219. case kNHWC2KCHW:
  220. case kNHWC2CKHW: {
  221. for (int k = 0; k < filterK; ++k) {
  222. for (int h = 0; h < filterH; ++h) {
  223. for (int w = 0; w < filterW; ++w) {
  224. for (int c = 0; c < filterC; ++c) {
  225. p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
  226. if (type == kNHWC2HWCK) {
  227. p2Buff =
  228. buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  229. } else if (type == kNHWC2CKHW) {
  230. p2Buff =
  231. buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  232. } else {
  233. p2Buff =
  234. buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  235. }
  236. *p2Buff = *p1Buff;
  237. }
  238. }
  239. }
  240. }
  241. } break;
  242. case kKHWC2CHWK: {
  243. for (int k = 0; k < filterK; ++k) {
  244. for (int h = 0; h < filterH; ++h) {
  245. for (int w = 0; w < filterW; ++w) {
  246. for (int c = 0; c < filterC; ++c) {
  247. p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  248. p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k));
  249. *p2Buff = *p1Buff;
  250. }
  251. }
  252. }
  253. }
  254. } break;
  255. default: {
  256. MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
  257. return RET_ERROR;
  258. }
  259. }
  260. auto ret = ::memcpy_s(tensor->data.data(), count * sizeof(T), buf.get(), count * sizeof(T));
  261. if (ret != EOK) {
  262. MS_LOG(ERROR) << "memcpy_s failed: " << ret;
  263. return RET_ERROR;
  264. }
  265. return RET_OK;
  266. }
  267. template <typename T>
  268. static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type) {
  269. MS_ASSERT(tensor != nullptr);
  270. std::vector<int32_t> oriDims = tensor->dims;
  271. if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) {
  272. MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
  273. return RET_ERROR;
  274. }
  275. int32_t filterH;
  276. int32_t filterW;
  277. int32_t filterC;
  278. int32_t filterK;
  279. auto status = GetFilterDim(oriDims, type, &filterK, &filterC, &filterH, &filterW);
  280. if (status != RET_OK) {
  281. MS_LOG(ERROR) << "GetFilterDim failed: " << status;
  282. return status;
  283. }
  284. status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW);
  285. if (status != RET_OK) {
  286. MS_LOG(ERROR) << "SetFilterDim failed: " << status;
  287. return status;
  288. }
  289. status = TransFilterData<T>(tensor, type, filterK, filterC, filterH, filterW);
  290. if (status != RET_OK) {
  291. MS_LOG(ERROR) << "TransFilterData failed: " << status;
  292. return status;
  293. }
  294. return RET_OK;
  295. }
  296. STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat);
  297. } // namespace lite
  298. } // namespace mindspore
  299. #endif // MINDSPORE_PREDICT_NODE_UTIL_H