You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node_util.h 12 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_PREDICT_NODE_UTIL_H
  17. #define MINDSPORE_PREDICT_NODE_UTIL_H
  18. #include <memory>
  19. #include <vector>
  20. #include <unordered_map>
  21. #include "schema/inner/model_generated.h"
  22. #include "src/common/common.h"
  23. #include "src/common/log_adapter.h"
  24. #include "include/errorcode.h"
  25. #include "securec/include/securec.h"
  26. namespace mindspore {
  27. namespace lite {
// Integer status code returned by the helpers declared in this header. The inline
// templates further down compare against RET_OK and return RET_ERROR on failure.
using STATUS = int;

// NOTE(review): defined outside this header. Presumably propagates quantization
// parameters among the tensors attached to `node` inside `graphT` -- confirm against
// the implementation file.
STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<schema::CNodeT> &node);

// NOTE(review): returns an int -> int map; judging by the name it presumably maps an
// axis index in an NC* layout to the matching axis index in an NH* layout -- confirm.
std::unordered_map<int, int> GetNc2NhAxisMap();

// The getters below each return a list of primitive (operator) types by value.
// NOTE(review): what membership in each list means (format-transpose insertion, NHWC-only
// kernels, fp32-only, uint8/int8 support, ...) is only suggested by the names; the lists
// themselves are defined in the implementation file -- confirm there before relying on it.
std::vector<schema::PrimitiveType> GetInsertOpList();
std::vector<schema::PrimitiveType> GetNhwcOpList();
std::vector<schema::PrimitiveType> GetNhwcDualInputOpList();
std::vector<schema::PrimitiveType> GetNhwcAllInputOpList();
std::vector<schema::PrimitiveType> Getfp32FullOpList();
std::vector<schema::PrimitiveType> GetUint8NhwcOpList();
std::vector<schema::PrimitiveType> GetInt8OpList();
// Groups static helper functions; the class declares no data members and no
// non-static members, so it is used purely as a scope. Implementations live outside
// this header.
class NodeUtils {
 public:
  // NOTE(review): presumably rewrites `src_dims` (expressed in `src_format`) into
  // `dst_format` order and stores the result in `*dst_dims` -- confirm in the .cc file.
  // `dst_dims` is an output pointer; whether nullptr is tolerated is not visible here.
  static STATUS ConvertDims(schema::Format src_format, const std::vector<int32_t> &src_dims, schema::Format dst_format,
                            std::vector<int32_t> *dst_dims);

  // NOTE(review): both vectors are taken by non-const reference, so the callee may modify
  // either one. Presumably selects `out_dim` chunks of `chunk_size` bytes from `input`,
  // starting at `begin` and stepping by `stride`, appending them to `output` -- confirm.
  static void SliceData(std::vector<char *> &input, int64_t chunk_size, std::vector<char *> &output, int64_t begin,
                        int64_t out_dim, int64_t stride);

  // NOTE(review): presumably slices the buffer `data` (`data_size` bytes, element type
  // `data_type`) according to `begin` / `stride` / `output_dims` and stores the result in
  // `output` -- confirm. All dim vectors are non-const references and may be modified.
  static STATUS SetOutputSliceData(void *data, int64_t data_size, int32_t data_type, std::vector<int32_t> &input_dims,
                                   std::vector<int32_t> &begin, std::vector<int32_t> &output_dims,
                                   schema::TensorT *output, std::vector<int32_t> &stride);
};
  48. enum kTransFilterType {
  49. kKCHW2HWCK, // 0
  50. kKCHW2KHWC,
  51. kCKHW2KHWC,
  52. kCKHW2HWCK,
  53. kKCHW2HWKC,
  54. kCKHW2HWKC,
  55. kHWCK2KCHW,
  56. kHWCK2CKHW,
  57. kHWKC2KCHW,
  58. kHWKC2CKHW,
  59. kNHWC2KCHW, // 10
  60. kNHWC2CKHW,
  61. kNHWC2HWCK,
  62. kKHWC2HWCK,
  63. kCHWK2HWCK,
  64. kKHWC2CHWK,
  65. kCHWK2KHWC,
  66. kKHWC2KCHW,
  67. kCKHW2KCHW,
  68. kCHWK2KCHW,
  69. kKCHW2CKHW // 20
  70. };
// Defined outside this header. As used by TransFilterFormat<T>() below: receives the
// tensor's current dims (already checked to have DIM_DEFAULT_SIZE entries) plus the
// conversion `type`, writes the four filter extents through the output pointers, and
// returns RET_OK on success.
// NOTE(review): presumably picks which entry of `oriDims` is K/C/H/W from the source
// layout encoded in `type` -- confirm in the implementation.
STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
                    int32_t *filterH, int32_t *filterW);

// Defined outside this header. As used by TransFilterFormat<T>() below: called with the
// extents obtained from GetFilterDim() and returns RET_OK on success.
// NOTE(review): presumably rewrites `tensor->dims` into the destination layout encoded
// in `type` -- confirm in the implementation.
STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH,
                    int32_t filterW);
  75. template <typename T>
  76. static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
  77. int32_t filterH, int32_t filterW) {
  78. MS_ASSERT(tensor != nullptr);
  79. int count = filterH * filterW * filterC * filterK;
  80. if (count <= 0) {
  81. MS_LOG(ERROR) << "Dim size invalid";
  82. return RET_ERROR;
  83. }
  84. std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
  85. if (buf == nullptr) {
  86. MS_LOG(ERROR) << "new buf failed";
  87. return RET_ERROR;
  88. }
  89. void *originWeightDate = tensor->data.data();
  90. T *weightData = static_cast<T *>(originWeightDate);
  91. if (weightData == nullptr) {
  92. MS_LOG(ERROR) << "weightData is nullptr";
  93. return RET_ERROR;
  94. }
  95. T *p1Buff = nullptr;
  96. T *p2Buff = nullptr;
  97. switch (type) {
  98. case kCHWK2HWCK:
  99. case kCHWK2KHWC: {
  100. for (int c = 0; c < filterC; ++c) {
  101. for (int h = 0; h < filterH; ++h) {
  102. for (int w = 0; w < filterW; ++w) {
  103. for (int k = 0; k < filterK; ++k) {
  104. p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
  105. if (type == kCHWK2HWCK) {
  106. p2Buff =
  107. buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  108. } else if (type == kCHWK2KHWC) {
  109. p2Buff =
  110. buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  111. }
  112. *p2Buff = *p1Buff;
  113. }
  114. }
  115. }
  116. }
  117. } break;
  118. case kKHWC2HWCK: {
  119. for (int k = 0; k < filterK; ++k) {
  120. for (int h = 0; h < filterH; ++h) {
  121. for (int w = 0; w < filterW; ++w) {
  122. for (int c = 0; c < filterC; ++c) {
  123. p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  124. p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  125. *p2Buff = *p1Buff;
  126. }
  127. }
  128. }
  129. }
  130. } break;
  131. case kKCHW2HWCK:
  132. case kKCHW2CKHW:
  133. case kKCHW2KHWC:
  134. case kKCHW2HWKC: {
  135. for (int k = 0; k < filterK; ++k) {
  136. for (int c = 0; c < filterC; ++c) {
  137. for (int h = 0; h < filterH; ++h) {
  138. for (int w = 0; w < filterW; ++w) {
  139. p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  140. if (type == kKCHW2HWCK) {
  141. p2Buff =
  142. buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  143. } else if (type == kKCHW2KHWC) {
  144. p2Buff =
  145. buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  146. } else if (type == kKCHW2CKHW) {
  147. p2Buff =
  148. buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  149. } else {
  150. p2Buff =
  151. buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
  152. }
  153. *p2Buff = *p1Buff;
  154. }
  155. }
  156. }
  157. }
  158. } break;
  159. case kCKHW2HWCK:
  160. case kCKHW2KHWC:
  161. case kCKHW2HWKC: {
  162. for (int c = 0; c < filterC; ++c) {
  163. for (int k = 0; k < filterK; ++k) {
  164. for (int h = 0; h < filterH; ++h) {
  165. for (int w = 0; w < filterW; ++w) {
  166. p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  167. if (type == kCKHW2HWCK) {
  168. p2Buff =
  169. buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  170. } else if (type == kCKHW2KHWC) {
  171. p2Buff =
  172. buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  173. } else {
  174. p2Buff =
  175. buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
  176. }
  177. *p2Buff = *p1Buff;
  178. }
  179. }
  180. }
  181. }
  182. } break;
  183. case kHWCK2KCHW:
  184. case kHWCK2CKHW: {
  185. for (int h = 0; h < filterH; ++h) {
  186. for (int w = 0; w < filterW; ++w) {
  187. for (int c = 0; c < filterC; ++c) {
  188. for (int k = 0; k < filterK; ++k) {
  189. p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  190. if (type == kHWCK2KCHW) {
  191. p2Buff =
  192. buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  193. } else {
  194. p2Buff =
  195. buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  196. }
  197. *p2Buff = *p1Buff;
  198. }
  199. }
  200. }
  201. }
  202. } break;
  203. case kHWKC2KCHW:
  204. case kHWKC2CKHW: {
  205. for (int h = 0; h < filterH; ++h) {
  206. for (int w = 0; w < filterW; ++w) {
  207. for (int c = 0; c < filterC; ++c) {
  208. for (int k = 0; k < filterK; ++k) {
  209. p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
  210. if (type == kHWKC2KCHW) {
  211. p2Buff =
  212. buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  213. } else {
  214. p2Buff =
  215. buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  216. }
  217. *p2Buff = *p1Buff;
  218. }
  219. }
  220. }
  221. }
  222. } break;
  223. case kNHWC2HWCK:
  224. case kNHWC2KCHW:
  225. case kNHWC2CKHW: {
  226. for (int k = 0; k < filterK; ++k) {
  227. for (int h = 0; h < filterH; ++h) {
  228. for (int w = 0; w < filterW; ++w) {
  229. for (int c = 0; c < filterC; ++c) {
  230. p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
  231. if (type == kNHWC2HWCK) {
  232. p2Buff =
  233. buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  234. } else if (type == kNHWC2CKHW) {
  235. p2Buff =
  236. buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  237. } else {
  238. p2Buff =
  239. buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  240. }
  241. *p2Buff = *p1Buff;
  242. }
  243. }
  244. }
  245. }
  246. } break;
  247. case kKHWC2CHWK: {
  248. for (int k = 0; k < filterK; ++k) {
  249. for (int h = 0; h < filterH; ++h) {
  250. for (int w = 0; w < filterW; ++w) {
  251. for (int c = 0; c < filterC; ++c) {
  252. p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  253. p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k));
  254. *p2Buff = *p1Buff;
  255. }
  256. }
  257. }
  258. }
  259. } break;
  260. default: {
  261. MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
  262. return RET_ERROR;
  263. }
  264. }
  265. auto ret = ::memcpy_s(tensor->data.data(), count * sizeof(T), buf.get(), count * sizeof(T));
  266. if (ret != EOK) {
  267. MS_LOG(ERROR) << "memcpy_s failed: " << ret;
  268. return RET_ERROR;
  269. }
  270. return RET_OK;
  271. }
  272. template <typename T>
  273. static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type) {
  274. MS_ASSERT(tensor != nullptr);
  275. std::vector<int32_t> oriDims = tensor->dims;
  276. if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) {
  277. MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
  278. return RET_ERROR;
  279. }
  280. int32_t filterH;
  281. int32_t filterW;
  282. int32_t filterC;
  283. int32_t filterK;
  284. auto status = GetFilterDim(oriDims, type, &filterK, &filterC, &filterH, &filterW);
  285. if (status != RET_OK) {
  286. MS_LOG(ERROR) << "GetFilterDim failed: " << status;
  287. return status;
  288. }
  289. status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW);
  290. if (status != RET_OK) {
  291. MS_LOG(ERROR) << "SetFilterDim failed: " << status;
  292. return status;
  293. }
  294. status = TransFilterData<T>(tensor, type, filterK, filterC, filterH, filterW);
  295. if (status != RET_OK) {
  296. MS_LOG(ERROR) << "TransFilterData failed: " << status;
  297. return status;
  298. }
  299. return RET_OK;
  300. }
// Non-template overload that takes the destination format instead of an explicit
// conversion type. Defined outside this header.
// NOTE(review): presumably derives the kTransFilterType (and the element type) from the
// tensor's current format/data type and `dstFormat`, then dispatches to the template
// overload above -- confirm in the implementation file.
STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat);
  302. } // namespace lite
  303. } // namespace mindspore
  304. #endif // MINDSPORE_PREDICT_NODE_UTIL_H