You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node_util.h 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H
  17. #define MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H
  18. #include <memory>
  19. #include <vector>
  20. #include <string>
  21. #include <unordered_map>
  22. #include "schema/inner/model_generated.h"
  23. #include "src/common/common.h"
  24. #include "src/common/log_adapter.h"
  25. #include "include/errorcode.h"
  26. #include "securec/include/securec.h"
  27. #include "tools/optimizer/common/gllo_utils.h"
  28. namespace mindspore {
  29. namespace lite {
  30. template <typename T>
  31. int CreateOperator(const std::unique_ptr<schema::PrimitiveT> &primitive, schema::PrimitiveType type) {
  32. auto attr = std::make_unique<T>();
  33. if (attr == nullptr) {
  34. MS_LOG(ERROR) << "new attr failed";
  35. return RET_NULL_PTR;
  36. }
  37. primitive->value.type = type;
  38. primitive->value.value = attr.release();
  39. return RET_OK;
  40. }
  // Integer status code returned by the helpers in this header (RET_OK / RET_ERROR / RET_NULL_PTR are used below).
  using STATUS = int;
  // Declared here, defined elsewhere.
  // NOTE(review): presumably propagates quantization parameters for `node` within `graphT` - confirm against the
  // implementation.
  STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<schema::CNodeT> &node);
  // Declared here, defined elsewhere. Takes the node's input tensors and an output vector to fill.
  // NOTE(review): the name looks like a misspelling of "NodeInferShape"; renaming would change the public symbol,
  // so it is left as is.
  STATUS NodeInferShpae(const schema::CNodeT &node, const std::vector<Tensor *> &inputs, std::vector<Tensor *> *outputs);
  44. inline schema::PrimitiveType GetCNodeTType(const schema::CNodeT &cNodeT) { return cNodeT.primitive->value.type; }
  45. inline std::string GetCNodeTTypeName(const schema::CNodeT &cNodeT) {
  46. return schema::EnumNamePrimitiveType(GetCNodeTType(cNodeT));
  47. }
  48. inline schema::PrimitiveType GetOpType(const schema::CNode &opDef) { return opDef.primitive()->value_type(); }
  49. inline std::string GetOpTypeName(const schema::CNode &opDef) { return schema::EnumNamePrimitiveType(GetOpType(opDef)); }
  // The functions below are only declared in this header; their definitions live elsewhere.
  // NOTE(review): the descriptions are inferred from the names and signatures - confirm against the implementation.

  // Returns an int -> int map; presumably maps an axis index in NCHW order to its index in NHWC order.
  std::unordered_map<int, int> GetNc2NhAxisMap();
  // Each of the following returns a list of primitive types by value.
  std::vector<schema::PrimitiveType> GetInsertOpList();
  std::vector<schema::PrimitiveType> GetNhwcOpList();
  std::vector<schema::PrimitiveType> GetNchwOpList();
  std::vector<schema::PrimitiveType> GetNhwcAllInputOpList();
  // Returns a map from primitive type to a list of integer indexes.
  std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes();
  std::vector<schema::PrimitiveType> Getfp32FullOpList();
  std::vector<schema::PrimitiveType> GetUint8NhwcOpList();
  std::vector<schema::PrimitiveType> GetInt8OpList();
  // Takes an object-API primitive and a FlatBufferBuilder and returns a pointer to a serialized schema::Primitive.
  // NOTE(review): the lifetime of the returned pointer presumably depends on `fbb` - confirm.
  const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb);
  // Returns a size_t index for `tensor_index` with respect to `cnode`.
  // NOTE(review): presumably the position of the tensor among the node's inputs - confirm.
  size_t GetTensorInputIndexInCNode(const uint32_t &tensor_index, const schema::CNodeT &cnode);
  // Static-only helper class; ConvertDims is its single member declared here.
  class NodeUtils {
   public:
    // Takes dims in `src_format` and writes the result for `dst_format` through `dst_dims`.
    // NOTE(review): presumably permutes the dims between the two layouts; the definition is not in this header -
    // confirm supported format pairs and error codes there.
    static STATUS ConvertDims(schema::Format src_format, const std::vector<int32_t> &src_dims, schema::Format dst_format,
                              std::vector<int32_t> *dst_dims);
  };
  // Filter layout conversions, named k<SRC>2<DST>.
  // In the Trans* helpers of this header the letters K, C, H and W correspond to the filterK, filterC, filterH and
  // filterW arguments, and the letter order gives the storage order (first letter has the largest stride).
  // The pointer overload of TransFilterData in this header has no case for kKHWC2KCHW, kCKHW2KCHW or kCHWK2KCHW and
  // reports them as unsupported.
  enum kTransFilterType {
    kKCHW2HWCK,  // 0
    kKCHW2KHWC,
    kCKHW2KHWC,
    kCKHW2HWCK,
    kKCHW2HWKC,
    kCKHW2HWKC,
    kHWCK2KCHW,
    kHWCK2CKHW,
    kHWKC2KCHW,
    kHWKC2CKHW,
    kNHWC2KCHW,  // 10
    kNHWC2CKHW,
    kNHWC2HWCK,
    kKHWC2HWCK,
    kCHWK2HWCK,
    kKHWC2CHWK,
    kCHWK2KHWC,
    kKHWC2KCHW,
    kCKHW2KCHW,
    kCHWK2KCHW,
    kKCHW2CKHW  // 20
  };
  // Declared here, defined elsewhere. Writes the K, C, H and W extents for `oriDims` under conversion `type`
  // through the four output pointers.
  // NOTE(review): presumably reads the extents according to the source layout encoded in `type` - confirm.
  STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
                      int32_t *filterH, int32_t *filterW);
  // Declared here, defined elsewhere. Takes the tensor plus the four filter extents.
  // NOTE(review): presumably rewrites tensor->dims according to the destination layout encoded in `type` - confirm.
  STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH,
                      int32_t filterW);
  // Copies a 4-D filter stored in KHWC order (srcData) into CHWK order (dstData).
  // Both buffers are indexed up to filterK * filterH * filterW * filterC elements. dstData is written while
  // srcData is still being read, so the two buffers are expected to be distinct.
  template <typename T>
  static void TransKHWC2CHWK(int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, T *srcData, T *dstData) {
    // The loops walk the source in its storage order, so the source offset simply advances by one.
    int srcOffset = 0;
    for (int k = 0; k < filterK; ++k) {
      for (int h = 0; h < filterH; ++h) {
        for (int w = 0; w < filterW; ++w) {
          for (int c = 0; c < filterC; ++c, ++srcOffset) {
            const int dstOffset = (c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k);
            dstData[dstOffset] = srcData[srcOffset];
          }
        }
      }
    }
  }
  // Copies a 4-D filter stored in KHWC order (srcData) into HWCK order (dstData).
  // Both buffers are indexed up to filterK * filterH * filterW * filterC elements. dstData is written while
  // srcData is still being read, so the two buffers are expected to be distinct.
  template <typename T>
  static void TransKHWC2HWCK(int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, T *srcData, T *dstData) {
    // The loops walk the source in its storage order, so the source offset simply advances by one.
    int srcOffset = 0;
    for (int k = 0; k < filterK; ++k) {
      for (int h = 0; h < filterH; ++h) {
        for (int w = 0; w < filterW; ++w) {
          for (int c = 0; c < filterC; ++c, ++srcOffset) {
            const int dstOffset = (h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k);
            dstData[dstOffset] = srcData[srcOffset];
          }
        }
      }
    }
  }
  125. template <typename T>
  126. static void TransCKHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
  127. T *srcData, T *dstData) {
  128. T *p1Buff = nullptr;
  129. T *p2Buff = nullptr;
  130. for (int c = 0; c < filterC; ++c) {
  131. for (int k = 0; k < filterK; ++k) {
  132. for (int h = 0; h < filterH; ++h) {
  133. for (int w = 0; w < filterW; ++w) {
  134. p1Buff = srcData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  135. if (type == kCKHW2HWCK) {
  136. p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  137. } else if (type == kCKHW2KHWC) {
  138. p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  139. } else {
  140. p2Buff = dstData + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
  141. }
  142. *p2Buff = *p1Buff;
  143. }
  144. }
  145. }
  146. }
  147. }
  148. template <typename T>
  149. static void TransKCHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
  150. T *srcData, T *dstData) {
  151. T *p1Buff = nullptr;
  152. T *p2Buff = nullptr;
  153. for (int k = 0; k < filterK; ++k) {
  154. for (int c = 0; c < filterC; ++c) {
  155. for (int h = 0; h < filterH; ++h) {
  156. for (int w = 0; w < filterW; ++w) {
  157. p1Buff = srcData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  158. if (type == kKCHW2HWCK) {
  159. p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  160. } else if (type == kKCHW2KHWC) {
  161. p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  162. } else if (type == kKCHW2CKHW) {
  163. p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  164. } else {
  165. p2Buff = dstData + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
  166. }
  167. *p2Buff = *p1Buff;
  168. }
  169. }
  170. }
  171. }
  172. }
  173. template <typename T>
  174. static void TransCHWK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
  175. T *srcData, T *dstData) {
  176. T *p1Buff = nullptr;
  177. T *p2Buff = nullptr;
  178. for (int c = 0; c < filterC; ++c) {
  179. for (int h = 0; h < filterH; ++h) {
  180. for (int w = 0; w < filterW; ++w) {
  181. for (int k = 0; k < filterK; ++k) {
  182. p1Buff = srcData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
  183. if (type == kCHWK2HWCK) {
  184. p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  185. } else {
  186. p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  187. }
  188. *p2Buff = *p1Buff;
  189. }
  190. }
  191. }
  192. }
  193. }
  194. template <typename T>
  195. static void TransHWCK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
  196. T *srcData, T *dstData) {
  197. T *p1Buff = nullptr;
  198. T *p2Buff = nullptr;
  199. for (int h = 0; h < filterH; ++h) {
  200. for (int w = 0; w < filterW; ++w) {
  201. for (int c = 0; c < filterC; ++c) {
  202. for (int k = 0; k < filterK; ++k) {
  203. p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  204. if (type == kHWCK2KCHW) {
  205. p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  206. } else {
  207. p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  208. }
  209. *p2Buff = *p1Buff;
  210. }
  211. }
  212. }
  213. }
  214. }
  215. template <typename T>
  216. static void TransHWKC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
  217. T *srcData, T *dstData) {
  218. T *p1Buff = nullptr;
  219. T *p2Buff = nullptr;
  220. for (int h = 0; h < filterH; ++h) {
  221. for (int w = 0; w < filterW; ++w) {
  222. for (int c = 0; c < filterC; ++c) {
  223. for (int k = 0; k < filterK; ++k) {
  224. p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
  225. if (type == kHWKC2KCHW) {
  226. p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  227. } else {
  228. p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  229. }
  230. *p2Buff = *p1Buff;
  231. }
  232. }
  233. }
  234. }
  235. }
  236. template <typename T>
  237. static void TransNHWC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
  238. T *srcData, T *dstData) {
  239. T *p1Buff = nullptr;
  240. T *p2Buff = nullptr;
  241. for (int k = 0; k < filterK; ++k) {
  242. for (int h = 0; h < filterH; ++h) {
  243. for (int w = 0; w < filterW; ++w) {
  244. for (int c = 0; c < filterC; ++c) {
  245. p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
  246. if (type == kNHWC2HWCK) {
  247. p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  248. } else if (type == kNHWC2CKHW) {
  249. p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  250. } else {
  251. p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  252. }
  253. *p2Buff = *p1Buff;
  254. }
  255. }
  256. }
  257. }
  258. }
  259. template <typename T>
  260. static STATUS TransFilterData(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
  261. T *srcData, T *dstData) {
  262. switch (type) {
  263. case kCHWK2HWCK:
  264. case kCHWK2KHWC: {
  265. TransCHWK(type, filterK, filterC, filterH, filterW, srcData, dstData);
  266. } break;
  267. case kKHWC2HWCK: {
  268. TransKHWC2HWCK(filterK, filterC, filterH, filterW, srcData, dstData);
  269. } break;
  270. case kKCHW2HWCK:
  271. case kKCHW2CKHW:
  272. case kKCHW2KHWC:
  273. case kKCHW2HWKC: {
  274. TransKCHW(type, filterK, filterC, filterH, filterW, srcData, dstData);
  275. } break;
  276. case kCKHW2HWCK:
  277. case kCKHW2KHWC:
  278. case kCKHW2HWKC: {
  279. TransCKHW(type, filterK, filterC, filterH, filterW, srcData, dstData);
  280. } break;
  281. case kHWCK2KCHW:
  282. case kHWCK2CKHW: {
  283. TransHWCK(type, filterK, filterC, filterH, filterW, srcData, dstData);
  284. } break;
  285. case kHWKC2KCHW:
  286. case kHWKC2CKHW: {
  287. TransHWKC(type, filterK, filterC, filterH, filterW, srcData, dstData);
  288. } break;
  289. case kNHWC2HWCK:
  290. case kNHWC2KCHW:
  291. case kNHWC2CKHW: {
  292. TransNHWC(type, filterK, filterC, filterH, filterW, srcData, dstData);
  293. } break;
  294. case kKHWC2CHWK: {
  295. TransKHWC2CHWK(filterK, filterC, filterH, filterW, srcData, dstData);
  296. } break;
  297. default: {
  298. MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
  299. return RET_ERROR;
  300. }
  301. }
  302. return RET_OK;
  303. }
  304. template <typename T>
  305. static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
  306. int32_t filterH, int32_t filterW) {
  307. MS_ASSERT(tensor != nullptr);
  308. int count = filterH * filterW * filterC * filterK;
  309. if (count <= 0) {
  310. MS_LOG(ERROR) << "Dim size invalid";
  311. return RET_ERROR;
  312. }
  313. std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
  314. if (buf == nullptr) {
  315. MS_LOG(ERROR) << "new buf failed";
  316. return RET_ERROR;
  317. }
  318. void *originWeightDate = tensor->data.data();
  319. T *weightData = static_cast<T *>(originWeightDate);
  320. if (weightData == nullptr) {
  321. MS_LOG(ERROR) << "weightData is nullptr";
  322. return RET_ERROR;
  323. }
  324. if (TransFilterData(type, filterK, filterC, filterH, filterW, weightData, buf.get()) != RET_OK) {
  325. MS_LOG(ERROR) << "TransFilterData failed";
  326. return RET_ERROR;
  327. }
  328. auto ret = ::memcpy_s(tensor->data.data(), count * sizeof(T), buf.get(), count * sizeof(T));
  329. if (ret != EOK) {
  330. MS_LOG(ERROR) << "memcpy_s failed: " << ret;
  331. return RET_ERROR;
  332. }
  333. return RET_OK;
  334. }
  335. template <typename T>
  336. static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type) {
  337. MS_ASSERT(tensor != nullptr);
  338. std::vector<int32_t> oriDims = tensor->dims;
  339. if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) {
  340. MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
  341. return RET_ERROR;
  342. }
  343. int32_t filterH;
  344. int32_t filterW;
  345. int32_t filterC;
  346. int32_t filterK;
  347. auto status = GetFilterDim(oriDims, type, &filterK, &filterC, &filterH, &filterW);
  348. if (status != RET_OK) {
  349. MS_LOG(ERROR) << "GetFilterDim failed: " << status;
  350. return status;
  351. }
  352. status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW);
  353. if (status != RET_OK) {
  354. MS_LOG(ERROR) << "SetFilterDim failed: " << status;
  355. return status;
  356. }
  357. status = TransFilterData<T>(tensor, type, filterK, filterC, filterH, filterW);
  358. if (status != RET_OK) {
  359. MS_LOG(ERROR) << "TransFilterData failed: " << status;
  360. return status;
  361. }
  362. return RET_OK;
  363. }
  // Non-template overload taking a destination schema::Format; declared here, defined elsewhere.
  // NOTE(review): presumably picks the kTransFilterType and element type from the tensor itself - confirm.
  STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat);
  // Declared here, defined elsewhere. `train_flag` defaults to false.
  // NOTE(review): presumably returns the number of outputs of the CNode behind `anf_node` - confirm.
  size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag = false);
  366. } // namespace lite
  367. } // namespace mindspore
  368. #endif // MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H