You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node_util.h 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_PREDICT_NODE_UTIL_H
  17. #define MINDSPORE_PREDICT_NODE_UTIL_H
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "schema/inner/model_generated.h"
#include "src/common/common.h"
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
#include "securec/include/securec.h"
  27. namespace mindspore {
  28. namespace lite {
  29. using STATUS = int;
  30. STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<schema::CNodeT> &node);
  31. inline schema::PrimitiveType GetCNodeTType(const schema::CNodeT &cNodeT) { return cNodeT.primitive->value.type; }
  32. inline std::string GetCNodeTTypeName(const schema::CNodeT &cNodeT) {
  33. return schema::EnumNamePrimitiveType(GetCNodeTType(cNodeT));
  34. }
  35. inline schema::PrimitiveType GetOpType(const schema::CNode &opDef) { return opDef.primitive()->value_type(); }
  36. inline std::string GetOpTypeName(const schema::CNode &opDef) { return schema::EnumNamePrimitiveType(GetOpType(opDef)); }
  37. std::unordered_map<int, int> GetNc2NhAxisMap();
  38. std::vector<schema::PrimitiveType> GetInsertOpList();
  39. std::vector<schema::PrimitiveType> GetNhwcOpList();
  40. std::vector<schema::PrimitiveType> GetNhwcDualInputOpList();
  41. std::vector<schema::PrimitiveType> GetNhwcAllInputOpList();
  42. std::vector<schema::PrimitiveType> Getfp32FullOpList();
  43. std::vector<schema::PrimitiveType> GetUint8NhwcOpList();
  44. std::vector<schema::PrimitiveType> GetInt8OpList();
  45. class NodeUtils {
  46. public:
  47. static STATUS ConvertDims(schema::Format src_format, const std::vector<int32_t> &src_dims, schema::Format dst_format,
  48. std::vector<int32_t> *dst_dims);
  49. static void SliceData(std::vector<char *> &input, int64_t chunk_size, std::vector<char *> &output, int64_t begin,
  50. int64_t out_dim, int64_t stride);
  51. static STATUS SetOutputSliceData(void *data, int64_t data_size, int32_t data_type, std::vector<int32_t> &input_dims,
  52. std::vector<int32_t> &begin, std::vector<int32_t> &output_dims,
  53. schema::TensorT *output, std::vector<int32_t> &stride);
  54. };
  55. enum kTransFilterType {
  56. kKCHW2HWCK, // 0
  57. kKCHW2KHWC,
  58. kCKHW2KHWC,
  59. kCKHW2HWCK,
  60. kKCHW2HWKC,
  61. kCKHW2HWKC,
  62. kHWCK2KCHW,
  63. kHWCK2CKHW,
  64. kHWKC2KCHW,
  65. kHWKC2CKHW,
  66. kNHWC2KCHW, // 10
  67. kNHWC2CKHW,
  68. kNHWC2HWCK,
  69. kKHWC2HWCK,
  70. kCHWK2HWCK,
  71. kKHWC2CHWK,
  72. kCHWK2KHWC,
  73. kKHWC2KCHW,
  74. kCKHW2KCHW,
  75. kCHWK2KCHW,
  76. kKCHW2CKHW // 20
  77. };
  78. STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
  79. int32_t *filterH, int32_t *filterW);
  80. STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH,
  81. int32_t filterW);
  82. template <typename T>
  83. static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
  84. int32_t filterH, int32_t filterW) {
  85. MS_ASSERT(tensor != nullptr);
  86. int count = filterH * filterW * filterC * filterK;
  87. if (count <= 0) {
  88. MS_LOG(ERROR) << "Dim size invalid";
  89. return RET_ERROR;
  90. }
  91. std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
  92. if (buf == nullptr) {
  93. MS_LOG(ERROR) << "new buf failed";
  94. return RET_ERROR;
  95. }
  96. void *originWeightDate = tensor->data.data();
  97. T *weightData = static_cast<T *>(originWeightDate);
  98. if (weightData == nullptr) {
  99. MS_LOG(ERROR) << "weightData is nullptr";
  100. return RET_ERROR;
  101. }
  102. T *p1Buff = nullptr;
  103. T *p2Buff = nullptr;
  104. switch (type) {
  105. case kCHWK2HWCK:
  106. case kCHWK2KHWC: {
  107. for (int c = 0; c < filterC; ++c) {
  108. for (int h = 0; h < filterH; ++h) {
  109. for (int w = 0; w < filterW; ++w) {
  110. for (int k = 0; k < filterK; ++k) {
  111. p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
  112. if (type == kCHWK2HWCK) {
  113. p2Buff =
  114. buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  115. } else if (type == kCHWK2KHWC) {
  116. p2Buff =
  117. buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  118. }
  119. *p2Buff = *p1Buff;
  120. }
  121. }
  122. }
  123. }
  124. } break;
  125. case kKHWC2HWCK: {
  126. for (int k = 0; k < filterK; ++k) {
  127. for (int h = 0; h < filterH; ++h) {
  128. for (int w = 0; w < filterW; ++w) {
  129. for (int c = 0; c < filterC; ++c) {
  130. p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  131. p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  132. *p2Buff = *p1Buff;
  133. }
  134. }
  135. }
  136. }
  137. } break;
  138. case kKCHW2HWCK:
  139. case kKCHW2CKHW:
  140. case kKCHW2KHWC:
  141. case kKCHW2HWKC: {
  142. for (int k = 0; k < filterK; ++k) {
  143. for (int c = 0; c < filterC; ++c) {
  144. for (int h = 0; h < filterH; ++h) {
  145. for (int w = 0; w < filterW; ++w) {
  146. p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  147. if (type == kKCHW2HWCK) {
  148. p2Buff =
  149. buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  150. } else if (type == kKCHW2KHWC) {
  151. p2Buff =
  152. buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  153. } else if (type == kKCHW2CKHW) {
  154. p2Buff =
  155. buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  156. } else {
  157. p2Buff =
  158. buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
  159. }
  160. *p2Buff = *p1Buff;
  161. }
  162. }
  163. }
  164. }
  165. } break;
  166. case kCKHW2HWCK:
  167. case kCKHW2KHWC:
  168. case kCKHW2HWKC: {
  169. for (int c = 0; c < filterC; ++c) {
  170. for (int k = 0; k < filterK; ++k) {
  171. for (int h = 0; h < filterH; ++h) {
  172. for (int w = 0; w < filterW; ++w) {
  173. p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  174. if (type == kCKHW2HWCK) {
  175. p2Buff =
  176. buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  177. } else if (type == kCKHW2KHWC) {
  178. p2Buff =
  179. buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  180. } else {
  181. p2Buff =
  182. buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
  183. }
  184. *p2Buff = *p1Buff;
  185. }
  186. }
  187. }
  188. }
  189. } break;
  190. case kHWCK2KCHW:
  191. case kHWCK2CKHW: {
  192. for (int h = 0; h < filterH; ++h) {
  193. for (int w = 0; w < filterW; ++w) {
  194. for (int c = 0; c < filterC; ++c) {
  195. for (int k = 0; k < filterK; ++k) {
  196. p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  197. if (type == kHWCK2KCHW) {
  198. p2Buff =
  199. buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  200. } else {
  201. p2Buff =
  202. buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  203. }
  204. *p2Buff = *p1Buff;
  205. }
  206. }
  207. }
  208. }
  209. } break;
  210. case kHWKC2KCHW:
  211. case kHWKC2CKHW: {
  212. for (int h = 0; h < filterH; ++h) {
  213. for (int w = 0; w < filterW; ++w) {
  214. for (int c = 0; c < filterC; ++c) {
  215. for (int k = 0; k < filterK; ++k) {
  216. p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
  217. if (type == kHWKC2KCHW) {
  218. p2Buff =
  219. buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  220. } else {
  221. p2Buff =
  222. buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  223. }
  224. *p2Buff = *p1Buff;
  225. }
  226. }
  227. }
  228. }
  229. } break;
  230. case kNHWC2HWCK:
  231. case kNHWC2KCHW:
  232. case kNHWC2CKHW: {
  233. for (int k = 0; k < filterK; ++k) {
  234. for (int h = 0; h < filterH; ++h) {
  235. for (int w = 0; w < filterW; ++w) {
  236. for (int c = 0; c < filterC; ++c) {
  237. p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
  238. if (type == kNHWC2HWCK) {
  239. p2Buff =
  240. buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  241. } else if (type == kNHWC2CKHW) {
  242. p2Buff =
  243. buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  244. } else {
  245. p2Buff =
  246. buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  247. }
  248. *p2Buff = *p1Buff;
  249. }
  250. }
  251. }
  252. }
  253. } break;
  254. case kKHWC2CHWK: {
  255. for (int k = 0; k < filterK; ++k) {
  256. for (int h = 0; h < filterH; ++h) {
  257. for (int w = 0; w < filterW; ++w) {
  258. for (int c = 0; c < filterC; ++c) {
  259. p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  260. p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k));
  261. *p2Buff = *p1Buff;
  262. }
  263. }
  264. }
  265. }
  266. } break;
  267. default: {
  268. MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
  269. return RET_ERROR;
  270. }
  271. }
  272. auto ret = ::memcpy_s(tensor->data.data(), count * sizeof(T), buf.get(), count * sizeof(T));
  273. if (ret != EOK) {
  274. MS_LOG(ERROR) << "memcpy_s failed: " << ret;
  275. return RET_ERROR;
  276. }
  277. return RET_OK;
  278. }
  279. template <typename T>
  280. static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type) {
  281. MS_ASSERT(tensor != nullptr);
  282. std::vector<int32_t> oriDims = tensor->dims;
  283. if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) {
  284. MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
  285. return RET_ERROR;
  286. }
  287. int32_t filterH;
  288. int32_t filterW;
  289. int32_t filterC;
  290. int32_t filterK;
  291. auto status = GetFilterDim(oriDims, type, &filterK, &filterC, &filterH, &filterW);
  292. if (status != RET_OK) {
  293. MS_LOG(ERROR) << "GetFilterDim failed: " << status;
  294. return status;
  295. }
  296. status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW);
  297. if (status != RET_OK) {
  298. MS_LOG(ERROR) << "SetFilterDim failed: " << status;
  299. return status;
  300. }
  301. status = TransFilterData<T>(tensor, type, filterK, filterC, filterH, filterW);
  302. if (status != RET_OK) {
  303. MS_LOG(ERROR) << "TransFilterData failed: " << status;
  304. return status;
  305. }
  306. return RET_OK;
  307. }
  308. STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat);
  309. } // namespace lite
  310. } // namespace mindspore
  311. #endif // MINDSPORE_PREDICT_NODE_UTIL_H