You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node_util.cc 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "tools/common/node_util.h"
  17. #include <memory>
  18. #include <vector>
  19. #include "src/common/common.h"
  20. #include "utils/log_adapter.h"
  21. #include "tools/common/graph_util.h"
  22. #include "tools/common/tensor_util.h"
  23. namespace mindspore {
  24. namespace lite {
  25. static const std::vector<schema::PrimitiveType> nhwcOpList = {
  26. schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D,
  27. schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
  28. schema::PrimitiveType_Pooling, schema::PrimitiveType_Resize,
  29. schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm,
  30. schema::PrimitiveType_CaffePReLU};
  31. static const std::vector<schema::PrimitiveType> fp32FullOpList = {
  32. schema::PrimitiveType_Concat, schema::PrimitiveType_Add,
  33. schema::PrimitiveType_Floor}; // fp32 ops support C4 and nhwc in fp32
  34. static const std::vector<schema::PrimitiveType> int8NeedNhwcOpList = {};
  35. static const std::vector<schema::PrimitiveType> int8OpList = {
  36. schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
  37. schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
  38. schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
  39. schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax,
  40. schema::PrimitiveType_Reshape, schema::PrimitiveType_Activation,
  41. schema::PrimitiveType_Resize, schema::PrimitiveType_FullConnection,
  42. schema::PrimitiveType_ArgMax, schema::PrimitiveType_ArgMin,
  43. schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm,
  44. schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Div,
  45. schema::PrimitiveType_Mul, schema::PrimitiveType_Slice,
  46. schema::PrimitiveType_SoftMax, schema::PrimitiveType_Split,
  47. schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub,
  48. schema::PrimitiveType_TopK, schema::PrimitiveType_Unsqueeze};
  49. std::vector<schema::PrimitiveType> Getfp32FullOpList() { return fp32FullOpList; }
  50. std::vector<schema::PrimitiveType> GetNhwcOpList() { return nhwcOpList; }
  51. std::vector<schema::PrimitiveType> GetUint8NhwcOpList() { return int8NeedNhwcOpList; }
  52. std::vector<schema::PrimitiveType> GetUint8OpList() { return int8OpList; }
  53. STATUS NodeUtils::ConvertDims(mindspore::lite::Format src_format, const std::vector<int32_t> &src_dims,
  54. mindspore::lite::Format dst_format, std::vector<int32_t> *dst_dims) {
  55. if ((src_dims.size() != DIM_DEFAULT_SIZE && src_dims.size() != 3) || src_format == dst_format) {
  56. MS_LOG(ERROR) << "Convert format , src size " << src_dims.size()
  57. << " <3 or src format is equal to dst format,not need convert";
  58. *dst_dims = src_dims;
  59. return RET_PARAM_INVALID;
  60. }
  61. std::vector<int32_t> nchw_dim;
  62. switch (src_format) {
  63. case Format_NCHW:
  64. nchw_dim = src_dims;
  65. break;
  66. case Format_NHWC:
  67. if (src_dims.size() == DIM_DEFAULT_SIZE) {
  68. nchw_dim.push_back(src_dims[NHWC_N]);
  69. nchw_dim.push_back(src_dims[NHWC_C]);
  70. nchw_dim.push_back(src_dims[NHWC_H]);
  71. nchw_dim.push_back(src_dims[NHWC_W]);
  72. } else {
  73. nchw_dim.push_back(src_dims[HWC_C]);
  74. nchw_dim.push_back(src_dims[HWC_H]);
  75. nchw_dim.push_back(src_dims[HWC_W]);
  76. }
  77. break;
  78. default:
  79. MS_LOG(ERROR) << "Not support src format: " << schema::EnumNameFormat(src_format);
  80. return RET_ERROR;
  81. }
  82. if (nchw_dim.size() == 0) {
  83. MS_LOG(ERROR) << "Param nchw_dim is empty!";
  84. return RET_ERROR;
  85. }
  86. switch (dst_format) {
  87. case Format_NCHW:
  88. *dst_dims = nchw_dim;
  89. break;
  90. case Format_NHWC:
  91. if (src_dims.size() == DIM_DEFAULT_SIZE) {
  92. dst_dims->push_back(nchw_dim[NCHW_N]);
  93. dst_dims->push_back(nchw_dim[NCHW_H]);
  94. dst_dims->push_back(nchw_dim[NCHW_W]);
  95. dst_dims->push_back(nchw_dim[NCHW_C]);
  96. }
  97. break;
  98. default:
  99. MS_LOG(ERROR) << "Not support dst format: " << dst_format;
  100. return RET_ERROR;
  101. }
  102. return RET_OK;
  103. }
  104. STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t* filterK, int32_t* filterC,
  105. int32_t* filterH, int32_t* filterW) {
  106. MS_ASSERT(oriDims.size() == 4);
  107. if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) {
  108. *filterK = oriDims.at(KCHW_K);
  109. *filterC = oriDims.at(KCHW_C);
  110. *filterH = oriDims.at(KCHW_H);
  111. *filterW = oriDims.at(KCHW_W);
  112. } else if (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC) {
  113. *filterC = oriDims.at(CKHW_C);
  114. *filterK = oriDims.at(CKHW_K);
  115. *filterH = oriDims.at(CKHW_H);
  116. *filterW = oriDims.at(CKHW_W);
  117. } else if (type == kHWCK2KCHW || type == kHWCK2CKHW) {
  118. *filterH = oriDims.at(HWCK_H);
  119. *filterW = oriDims.at(HWCK_W);
  120. *filterC = oriDims.at(HWCK_C);
  121. *filterK = oriDims.at(HWCK_K);
  122. } else if (type == kHWKC2KCHW || type == kHWKC2CKHW) {
  123. *filterH = oriDims.at(HWKC_H);
  124. *filterW = oriDims.at(HWKC_W);
  125. *filterK = oriDims.at(HWKC_K);
  126. *filterC = oriDims.at(HWKC_C);
  127. } else if (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW) {
  128. *filterK = oriDims.at(NHWC_N);
  129. *filterH = oriDims.at(NHWC_H);
  130. *filterW = oriDims.at(NHWC_W);
  131. *filterC = oriDims.at(NHWC_C);
  132. } else if (type == kCHWK2HWCK || type == kCHWK2KHWC) {
  133. *filterC = oriDims.at(CHWK_C);
  134. *filterH = oriDims.at(CHWK_H);
  135. *filterW = oriDims.at(CHWK_W);
  136. *filterK = oriDims.at(CHWK_K);
  137. } else if (type == kKHWC2HWCK || type == kKHWC2CHWK) {
  138. *filterK = oriDims.at(KHWC_K);
  139. *filterH = oriDims.at(KHWC_H);
  140. *filterW = oriDims.at(KHWC_W);
  141. *filterC = oriDims.at(KHWC_C);
  142. } else {
  143. MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
  144. return RET_ERROR;
  145. }
  146. return RET_OK;
  147. }
  148. STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
  149. int32_t filterH, int32_t filterW) {
  150. MS_ASSERT(tensor != nullptr);
  151. if (type == kKCHW2HWCK || type == kCKHW2HWCK || type == kNHWC2HWCK || type == kKHWC2HWCK || type == kCHWK2HWCK) {
  152. tensor->dims = {filterH, filterW, filterC, filterK};
  153. } else if (type == kKCHW2HWKC || type == kCKHW2HWKC) {
  154. tensor->dims = {filterH, filterW, filterK, filterC};
  155. } else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) {
  156. tensor->dims = {filterK, filterC, filterH, filterW};
  157. } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW || type == kKCHW2CKHW) {
  158. tensor->dims = {filterC, filterK, filterH, filterW};
  159. } else if (type == kKHWC2CHWK) {
  160. tensor->dims = {filterC, filterH, filterW, filterK};
  161. } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC) {
  162. tensor->dims = {filterK, filterH, filterW, filterC};
  163. } else {
  164. MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
  165. return RET_ERROR;
  166. }
  167. return RET_OK;
  168. }
  169. STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
  170. if (tensor == nullptr) {
  171. return RET_NULL_PTR;
  172. }
  173. std::vector<int32_t> oriDims = tensor->dims;
  174. if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) {
  175. MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
  176. return RET_ERROR;
  177. }
  178. auto srcFormat = tensor->format;
  179. auto dataType = tensor->dataType;
  180. STATUS status;
  181. switch (dstFormat) {
  182. case schema::Format_KHWC: {
  183. switch (srcFormat) {
  184. case schema::Format_KCHW:
  185. if (dataType == kNumberTypeFloat32) {
  186. status = TransFilterFormat<float>(tensor, kKCHW2KHWC);
  187. } else if (dataType == kNumberTypeUInt8) {
  188. status = TransFilterFormat<uint8_t>(tensor, kKCHW2KHWC);
  189. } else if (dataType == kNumberTypeInt8) {
  190. status = TransFilterFormat<int8_t>(tensor, kKCHW2KHWC);
  191. } else {
  192. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  193. return RET_ERROR;
  194. }
  195. break;
  196. case schema::Format_CKHW:
  197. if (dataType == kNumberTypeFloat32) {
  198. status = TransFilterFormat<float>(tensor, kCKHW2KHWC);
  199. } else if (dataType == kNumberTypeUInt8) {
  200. status = TransFilterFormat<uint8_t>(tensor, kCKHW2KHWC);
  201. } else if (dataType == kNumberTypeInt8) {
  202. status = TransFilterFormat<int8_t>(tensor, kCKHW2KHWC);
  203. } else {
  204. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  205. return RET_ERROR;
  206. }
  207. break;
  208. case schema::Format_CHWK:
  209. if (dataType == kNumberTypeFloat32) {
  210. status = TransFilterFormat<float>(tensor, kCHWK2KHWC);
  211. } else if (dataType == kNumberTypeUInt8) {
  212. status = TransFilterFormat<uint8_t>(tensor, kCHWK2KHWC);
  213. } else if (dataType == kNumberTypeInt8) {
  214. status = TransFilterFormat<int8_t>(tensor, kCHWK2KHWC);
  215. } else {
  216. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  217. return RET_ERROR;
  218. }
  219. break;
  220. case schema::Format_KHWC:
  221. return RET_OK;
  222. default:
  223. MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(srcFormat) << " to "
  224. << schema::EnumNameFormat(dstFormat);
  225. return RET_ERROR;
  226. }
  227. } break;
  228. case schema::Format_HWCK: {
  229. switch (srcFormat) {
  230. case schema::Format_KCHW:
  231. if (dataType == kNumberTypeFloat32) {
  232. status = TransFilterFormat<float>(tensor, kKCHW2HWCK);
  233. } else if (dataType == kNumberTypeUInt8) {
  234. status = TransFilterFormat<uint8_t>(tensor, kKCHW2HWCK);
  235. } else if (dataType == kNumberTypeInt8) {
  236. status = TransFilterFormat<int8_t>(tensor, kKCHW2HWCK);
  237. } else {
  238. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  239. return RET_ERROR;
  240. }
  241. break;
  242. case schema::Format_KHWC:
  243. if (dataType == kNumberTypeFloat32) {
  244. status = TransFilterFormat<float>(tensor, kKHWC2HWCK);
  245. } else if (dataType == kNumberTypeUInt8) {
  246. status = TransFilterFormat<uint8_t>(tensor, kKHWC2HWCK);
  247. } else if (dataType == kNumberTypeInt8) {
  248. status = TransFilterFormat<int8_t>(tensor, kKHWC2HWCK);
  249. } else {
  250. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  251. return RET_ERROR;
  252. }
  253. break;
  254. case schema::Format_CKHW:
  255. if (dataType == kNumberTypeFloat32) {
  256. status = TransFilterFormat<float>(tensor, kCKHW2HWCK);
  257. } else if (dataType == kNumberTypeUInt8) {
  258. status = TransFilterFormat<uint8_t>(tensor, kCKHW2HWCK);
  259. } else if (dataType == kNumberTypeInt8) {
  260. status = TransFilterFormat<int8_t>(tensor, kCKHW2HWCK);
  261. } else {
  262. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  263. return RET_ERROR;
  264. }
  265. break;
  266. case schema::Format_CHWK:
  267. if (dataType == kNumberTypeFloat32) {
  268. status = TransFilterFormat<float>(tensor, kCHWK2HWCK);
  269. } else if (dataType == kNumberTypeUInt8) {
  270. status = TransFilterFormat<uint8_t>(tensor, kCHWK2HWCK);
  271. } else if (dataType == kNumberTypeInt8) {
  272. status = TransFilterFormat<int8_t>(tensor, kCHWK2HWCK);
  273. } else {
  274. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  275. return RET_ERROR;
  276. }
  277. break;
  278. case schema::Format_HWCK:
  279. return RET_OK;
  280. default:
  281. MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(srcFormat) << " to "
  282. << schema::EnumNameFormat(dstFormat);
  283. return RET_ERROR;
  284. }
  285. } break;
  286. case schema::Format_KCHW: {
  287. switch (srcFormat) {
  288. case schema::Format_KCHW:
  289. return RET_OK;
  290. case schema::Format_HWCK:
  291. if (dataType == kNumberTypeFloat32) {
  292. status = TransFilterFormat<float>(tensor, kHWCK2KCHW);
  293. } else if (dataType == kNumberTypeUInt8) {
  294. status = TransFilterFormat<uint8_t>(tensor, kHWCK2KCHW);
  295. } else if (dataType == kNumberTypeInt8) {
  296. status = TransFilterFormat<int8_t>(tensor, kHWCK2KCHW);
  297. } else {
  298. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  299. return RET_ERROR;
  300. }
  301. break;
  302. case schema::Format_HWKC:
  303. if (dataType == kNumberTypeFloat32) {
  304. status = TransFilterFormat<float>(tensor, kHWKC2KCHW);
  305. } else if (dataType == kNumberTypeUInt8) {
  306. status = TransFilterFormat<uint8_t>(tensor, kHWKC2KCHW);
  307. } else if (dataType == kNumberTypeInt8) {
  308. status = TransFilterFormat<int8_t>(tensor, kHWKC2KCHW);
  309. } else {
  310. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  311. return RET_ERROR;
  312. }
  313. break;
  314. case schema::Format_KHWC:
  315. if (dataType == kNumberTypeFloat32) {
  316. status = TransFilterFormat<float>(tensor, kKHWC2KCHW);
  317. } else if (dataType == kNumberTypeUInt8) {
  318. status = TransFilterFormat<uint8_t>(tensor, kKHWC2KCHW);
  319. } else if (dataType == kNumberTypeInt8) {
  320. status = TransFilterFormat<int8_t>(tensor, kKHWC2KCHW);
  321. } else {
  322. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  323. return RET_ERROR;
  324. }
  325. break;
  326. case schema::Format_CKHW:
  327. if (dataType == kNumberTypeFloat32) {
  328. status = TransFilterFormat<float>(tensor, kCKHW2KCHW);
  329. } else if (dataType == kNumberTypeUInt8) {
  330. status = TransFilterFormat<uint8_t>(tensor, kCKHW2KCHW);
  331. } else if (dataType == kNumberTypeInt8) {
  332. status = TransFilterFormat<int8_t>(tensor, kCKHW2KCHW);
  333. } else {
  334. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  335. return RET_ERROR;
  336. }
  337. break;
  338. case schema::Format_CHWK:
  339. if (dataType == kNumberTypeFloat32) {
  340. status = TransFilterFormat<float>(tensor, kCHWK2KCHW);
  341. } else if (dataType == kNumberTypeUInt8) {
  342. status = TransFilterFormat<uint8_t>(tensor, kCHWK2KCHW);
  343. } else if (dataType == kNumberTypeInt8) {
  344. status = TransFilterFormat<int8_t>(tensor, kCHWK2KCHW);
  345. } else {
  346. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  347. return RET_ERROR;
  348. }
  349. break;
  350. default:
  351. MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(srcFormat) << " to "
  352. << schema::EnumNameFormat(dstFormat);
  353. return RET_ERROR;
  354. }
  355. } break;
  356. case schema::Format_CKHW: {
  357. switch (srcFormat) {
  358. case schema::Format_HWCK:
  359. if (dataType == kNumberTypeFloat32) {
  360. status = TransFilterFormat<float>(tensor, kHWCK2CKHW);
  361. } else if (dataType == kNumberTypeUInt8) {
  362. status = TransFilterFormat<uint8_t>(tensor, kHWCK2CKHW);
  363. } else if (dataType == kNumberTypeInt8) {
  364. status = TransFilterFormat<int8_t>(tensor, kHWCK2CKHW);
  365. } else {
  366. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  367. return RET_ERROR;
  368. }
  369. break;
  370. case schema::Format_HWKC:
  371. if (dataType == kNumberTypeFloat32) {
  372. status = TransFilterFormat<float>(tensor, kHWKC2CKHW);
  373. } else if (dataType == kNumberTypeUInt8) {
  374. status = TransFilterFormat<uint8_t>(tensor, kHWKC2CKHW);
  375. } else if (dataType == kNumberTypeInt8) {
  376. status = TransFilterFormat<int8_t>(tensor, kHWKC2CKHW);
  377. } else {
  378. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  379. return RET_ERROR;
  380. }
  381. break;
  382. case schema::Format_KCHW:
  383. if (dataType == kNumberTypeFloat32) {
  384. status = TransFilterFormat<float>(tensor, kKCHW2CKHW);
  385. } else if (dataType == kNumberTypeUInt8) {
  386. status = TransFilterFormat<uint8_t>(tensor, kKCHW2CKHW);
  387. } else if (dataType == kNumberTypeInt8) {
  388. status = TransFilterFormat<int8_t>(tensor, kKCHW2CKHW);
  389. } else {
  390. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  391. return RET_ERROR;
  392. }
  393. break;
  394. case schema::Format_CKHW:
  395. return RET_OK;
  396. default:
  397. MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(srcFormat) << " to "
  398. << schema::EnumNameFormat(dstFormat);
  399. return RET_ERROR;
  400. }
  401. } break;
  402. default:
  403. MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(srcFormat) << " to "
  404. << schema::EnumNameFormat(dstFormat);
  405. return RET_ERROR;
  406. }
  407. if (status != RET_OK) {
  408. MS_LOG(ERROR) << "TransFilterData failed: " << status;
  409. return status;
  410. }
  411. return RET_OK;
  412. }
  413. } // namespace lite
  414. } // namespace mindspore