You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

node_util.cc 20 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "tools/common/node_util.h"
#include <memory>
#include <unordered_map>
#include <vector>
#include "src/common/common.h"
#include "src/common/log_adapter.h"
#include "tools/common/graph_util.h"
#include "tools/common/tensor_util.h"
  23. namespace mindspore {
  24. namespace lite {
// Ops whose activations are laid out in NHWC.
// NOTE(review): purpose inferred from the accessor name (GetNhwcOpList) and
// the op kinds listed — confirm against the format-transform pass that
// consumes this table.
static const std::vector<schema::PrimitiveType> nhwcOpList = {
#ifdef SUPPORT_TRAIN
  // Training-only gradient / optimizer ops.
  schema::PrimitiveType_Conv2DGradFilter,
  schema::PrimitiveType_Conv2DGradInput,
  schema::PrimitiveType_GroupConv2DGradInput,
  schema::PrimitiveType_PoolingGrad,
  schema::PrimitiveType_BiasGrad,
  schema::PrimitiveType_BNGrad,
  schema::PrimitiveType_ActivationGrad,
  schema::PrimitiveType_ApplyMomentum,
  schema::PrimitiveType_Sgd,
  schema::PrimitiveType_Adam,
#endif
  schema::PrimitiveType_Conv2D,
  schema::PrimitiveType_DeConv2D,
  schema::PrimitiveType_DepthwiseConv2D,
  schema::PrimitiveType_DeDepthwiseConv2D,
  schema::PrimitiveType_Pooling,
  schema::PrimitiveType_Resize,
  schema::PrimitiveType_BatchNorm,
  schema::PrimitiveType_FusedBatchNorm,
  schema::PrimitiveType_PReLU,
  schema::PrimitiveType_BiasAdd,
  schema::PrimitiveType_InstanceNorm};
// NHWC ops for which two inputs carry activations (training build only).
// NOTE(review): "dual input" semantics inferred from the accessor name
// GetNhwcDualInputOpList — confirm against the consumer.
static const std::vector<schema::PrimitiveType> nhwcOpDualInputList = {
#ifdef SUPPORT_TRAIN
  schema::PrimitiveType_Conv2DGradFilter, schema::PrimitiveType_BNGrad
#endif
};
// NHWC ops for which every input carries activations (training build only).
// NOTE(review): "all input" semantics inferred from the accessor name
// GetNhwcAllInputOpList — confirm against the consumer.
static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = {
#ifdef SUPPORT_TRAIN
  schema::PrimitiveType_PoolingGrad, schema::PrimitiveType_ActivationGrad
#endif
};
// fp32 ops that support both C4-packed and plain NHWC layouts in fp32
// (per the original trailing comment).
static const std::vector<schema::PrimitiveType> fp32FullOpList = {
  schema::PrimitiveType_Concat, schema::PrimitiveType_Add,
  schema::PrimitiveType_Floor};  // fp32 ops support C4 and nhwc in fp32
// Int8 ops that additionally require NHWC layout — currently none.
static const std::vector<schema::PrimitiveType> int8NeedNhwcOpList = {};
  63. static const std::vector<schema::PrimitiveType> int8OpList = {
  64. schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
  65. schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
  66. schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
  67. schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax,
  68. schema::PrimitiveType_Reshape, schema::PrimitiveType_Activation,
  69. schema::PrimitiveType_Resize, schema::PrimitiveType_FullConnection,
  70. schema::PrimitiveType_ArgMax, schema::PrimitiveType_ArgMin,
  71. schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm,
  72. schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Div,
  73. schema::PrimitiveType_Mul, schema::PrimitiveType_Slice,
  74. schema::PrimitiveType_SoftMax, schema::PrimitiveType_Split,
  75. schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub,
  76. schema::PrimitiveType_StridedSlice, schema::PrimitiveType_TopK,
  77. schema::PrimitiveType_Unsqueeze, schema::PrimitiveType_MatMul,
  78. schema::PrimitiveType_Pad, schema::PrimitiveType_DeConv2D,
  79. schema::PrimitiveType_Scale};
// Ops around which the converter may need to insert extra nodes.
// NOTE(review): inferred from the accessor name GetInsertOpList — confirm
// against the pass that consumes this table.
static const std::vector<schema::PrimitiveType> needInsertOpList = {
  schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation,   schema::PrimitiveType_Concat,
  schema::PrimitiveType_Power,   schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add,
  schema::PrimitiveType_Split,   schema::PrimitiveType_Slice,        schema::PrimitiveType_Crop};
// Axis remap from NCHW to NHWC: N(0)->0, C(1)->-1, H(2)->1, W(3)->2.
// NOTE(review): -1 presumably denotes "last axis" for the channel dimension —
// confirm against the consumer of GetNc2NhAxisMap().
static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}};
// Accessors for the tables above. Each returns by value, so callers receive
// an independent copy and cannot mutate the shared file-local definitions.
std::unordered_map<int, int> GetNc2NhAxisMap() { return nc2NhAxisMap; }
std::vector<schema::PrimitiveType> GetInsertOpList() { return needInsertOpList; }
std::vector<schema::PrimitiveType> Getfp32FullOpList() { return fp32FullOpList; }
std::vector<schema::PrimitiveType> GetNhwcOpList() { return nhwcOpList; }
std::vector<schema::PrimitiveType> GetNhwcDualInputOpList() { return nhwcOpDualInputList; }
std::vector<schema::PrimitiveType> GetNhwcAllInputOpList() { return nhwcOpAllInputList; }
// NOTE(review): despite the "Uint8" name this returns int8NeedNhwcOpList —
// confirm the naming is intentional.
std::vector<schema::PrimitiveType> GetUint8NhwcOpList() { return int8NeedNhwcOpList; }
std::vector<schema::PrimitiveType> GetInt8OpList() { return int8OpList; }
  93. STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector<int32_t> &src_dims,
  94. mindspore::schema::Format dst_format, std::vector<int32_t> *dst_dims) {
  95. if ((src_dims.size() != DIM_DEFAULT_SIZE && src_dims.size() != 3) || src_format == dst_format) {
  96. MS_LOG(ERROR) << "Convert format , src size " << src_dims.size()
  97. << " <3 or src format is equal to dst format,not need convert";
  98. *dst_dims = src_dims;
  99. return RET_PARAM_INVALID;
  100. }
  101. std::vector<int32_t> nchw_dim;
  102. switch (src_format) {
  103. case schema::Format::Format_NCHW:
  104. nchw_dim = src_dims;
  105. break;
  106. case schema::Format::Format_NHWC:
  107. if (src_dims.size() == DIM_DEFAULT_SIZE) {
  108. nchw_dim.push_back(src_dims[NHWC_N]);
  109. nchw_dim.push_back(src_dims[NHWC_C]);
  110. nchw_dim.push_back(src_dims[NHWC_H]);
  111. nchw_dim.push_back(src_dims[NHWC_W]);
  112. } else {
  113. nchw_dim.push_back(src_dims[HWC_C]);
  114. nchw_dim.push_back(src_dims[HWC_H]);
  115. nchw_dim.push_back(src_dims[HWC_W]);
  116. }
  117. break;
  118. default:
  119. MS_LOG(ERROR) << "Not support src format: " << EnumNameFormat(src_format);
  120. return RET_ERROR;
  121. }
  122. if (nchw_dim.size() == 0) {
  123. MS_LOG(ERROR) << "Param nchw_dim is empty!";
  124. return RET_ERROR;
  125. }
  126. switch (dst_format) {
  127. case schema::Format::Format_NCHW:
  128. *dst_dims = nchw_dim;
  129. break;
  130. case schema::Format::Format_NHWC:
  131. if (src_dims.size() == DIM_DEFAULT_SIZE) {
  132. dst_dims->push_back(nchw_dim[NCHW_N]);
  133. dst_dims->push_back(nchw_dim[NCHW_H]);
  134. dst_dims->push_back(nchw_dim[NCHW_W]);
  135. dst_dims->push_back(nchw_dim[NCHW_C]);
  136. }
  137. break;
  138. default:
  139. MS_LOG(ERROR) << "Not support dst format: " << dst_format;
  140. return RET_ERROR;
  141. }
  142. return RET_OK;
  143. }
  144. STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
  145. int32_t *filterH, int32_t *filterW) {
  146. MS_ASSERT(oriDims.size() == 4);
  147. if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) {
  148. *filterK = oriDims.at(KCHW_K);
  149. *filterC = oriDims.at(KCHW_C);
  150. *filterH = oriDims.at(KCHW_H);
  151. *filterW = oriDims.at(KCHW_W);
  152. } else if (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC) {
  153. *filterC = oriDims.at(CKHW_C);
  154. *filterK = oriDims.at(CKHW_K);
  155. *filterH = oriDims.at(CKHW_H);
  156. *filterW = oriDims.at(CKHW_W);
  157. } else if (type == kHWCK2KCHW || type == kHWCK2CKHW) {
  158. *filterH = oriDims.at(HWCK_H);
  159. *filterW = oriDims.at(HWCK_W);
  160. *filterC = oriDims.at(HWCK_C);
  161. *filterK = oriDims.at(HWCK_K);
  162. } else if (type == kHWKC2KCHW || type == kHWKC2CKHW) {
  163. *filterH = oriDims.at(HWKC_H);
  164. *filterW = oriDims.at(HWKC_W);
  165. *filterK = oriDims.at(HWKC_K);
  166. *filterC = oriDims.at(HWKC_C);
  167. } else if (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW) {
  168. *filterK = oriDims.at(NHWC_N);
  169. *filterH = oriDims.at(NHWC_H);
  170. *filterW = oriDims.at(NHWC_W);
  171. *filterC = oriDims.at(NHWC_C);
  172. } else if (type == kCHWK2HWCK || type == kCHWK2KHWC) {
  173. *filterC = oriDims.at(CHWK_C);
  174. *filterH = oriDims.at(CHWK_H);
  175. *filterW = oriDims.at(CHWK_W);
  176. *filterK = oriDims.at(CHWK_K);
  177. } else if (type == kKHWC2HWCK || type == kKHWC2CHWK) {
  178. *filterK = oriDims.at(KHWC_K);
  179. *filterH = oriDims.at(KHWC_H);
  180. *filterW = oriDims.at(KHWC_W);
  181. *filterC = oriDims.at(KHWC_C);
  182. } else {
  183. MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
  184. return RET_ERROR;
  185. }
  186. return RET_OK;
  187. }
  188. STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH,
  189. int32_t filterW) {
  190. MS_ASSERT(tensor != nullptr);
  191. if (type == kKCHW2HWCK || type == kCKHW2HWCK || type == kNHWC2HWCK || type == kKHWC2HWCK || type == kCHWK2HWCK) {
  192. tensor->dims = {filterH, filterW, filterC, filterK};
  193. } else if (type == kKCHW2HWKC || type == kCKHW2HWKC) {
  194. tensor->dims = {filterH, filterW, filterK, filterC};
  195. } else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) {
  196. tensor->dims = {filterK, filterC, filterH, filterW};
  197. } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW || type == kKCHW2CKHW) {
  198. tensor->dims = {filterC, filterK, filterH, filterW};
  199. } else if (type == kKHWC2CHWK) {
  200. tensor->dims = {filterC, filterH, filterW, filterK};
  201. } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC) {
  202. tensor->dims = {filterK, filterH, filterW, filterC};
  203. } else {
  204. MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
  205. return RET_ERROR;
  206. }
  207. return RET_OK;
  208. }
  209. STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
  210. if (tensor == nullptr) {
  211. return RET_NULL_PTR;
  212. }
  213. std::vector<int32_t> oriDims = tensor->dims;
  214. if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) {
  215. MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
  216. return RET_ERROR;
  217. }
  218. auto srcFormat = tensor->format;
  219. auto dataType = tensor->dataType;
  220. STATUS status;
  221. switch (dstFormat) {
  222. case schema::Format::Format_KHWC: {
  223. switch (srcFormat) {
  224. case schema::Format::Format_KCHW:
  225. if (dataType == kNumberTypeFloat32) {
  226. status = TransFilterFormat<float>(tensor, kKCHW2KHWC);
  227. } else if (dataType == kNumberTypeUInt8) {
  228. status = TransFilterFormat<uint8_t>(tensor, kKCHW2KHWC);
  229. } else if (dataType == kNumberTypeInt8) {
  230. status = TransFilterFormat<int8_t>(tensor, kKCHW2KHWC);
  231. } else {
  232. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  233. return RET_ERROR;
  234. }
  235. break;
  236. case schema::Format::Format_CKHW:
  237. if (dataType == kNumberTypeFloat32) {
  238. status = TransFilterFormat<float>(tensor, kCKHW2KHWC);
  239. } else if (dataType == kNumberTypeUInt8) {
  240. status = TransFilterFormat<uint8_t>(tensor, kCKHW2KHWC);
  241. } else if (dataType == kNumberTypeInt8) {
  242. status = TransFilterFormat<int8_t>(tensor, kCKHW2KHWC);
  243. } else {
  244. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  245. return RET_ERROR;
  246. }
  247. break;
  248. case schema::Format::Format_CHWK:
  249. if (dataType == kNumberTypeFloat32) {
  250. status = TransFilterFormat<float>(tensor, kCHWK2KHWC);
  251. } else if (dataType == kNumberTypeUInt8) {
  252. status = TransFilterFormat<uint8_t>(tensor, kCHWK2KHWC);
  253. } else if (dataType == kNumberTypeInt8) {
  254. status = TransFilterFormat<int8_t>(tensor, kCHWK2KHWC);
  255. } else {
  256. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  257. return RET_ERROR;
  258. }
  259. break;
  260. case schema::Format::Format_KHWC:
  261. return RET_OK;
  262. default:
  263. MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to "
  264. << EnumNameFormat(dstFormat);
  265. return RET_ERROR;
  266. }
  267. } break;
  268. case schema::Format::Format_HWCK: {
  269. switch (srcFormat) {
  270. case schema::Format::Format_KCHW:
  271. if (dataType == kNumberTypeFloat32) {
  272. status = TransFilterFormat<float>(tensor, kKCHW2HWCK);
  273. } else if (dataType == kNumberTypeUInt8) {
  274. status = TransFilterFormat<uint8_t>(tensor, kKCHW2HWCK);
  275. } else if (dataType == kNumberTypeInt8) {
  276. status = TransFilterFormat<int8_t>(tensor, kKCHW2HWCK);
  277. } else {
  278. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  279. return RET_ERROR;
  280. }
  281. break;
  282. case schema::Format::Format_KHWC:
  283. if (dataType == kNumberTypeFloat32) {
  284. status = TransFilterFormat<float>(tensor, kKHWC2HWCK);
  285. } else if (dataType == kNumberTypeUInt8) {
  286. status = TransFilterFormat<uint8_t>(tensor, kKHWC2HWCK);
  287. } else if (dataType == kNumberTypeInt8) {
  288. status = TransFilterFormat<int8_t>(tensor, kKHWC2HWCK);
  289. } else {
  290. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  291. return RET_ERROR;
  292. }
  293. break;
  294. case schema::Format::Format_CKHW:
  295. if (dataType == kNumberTypeFloat32) {
  296. status = TransFilterFormat<float>(tensor, kCKHW2HWCK);
  297. } else if (dataType == kNumberTypeUInt8) {
  298. status = TransFilterFormat<uint8_t>(tensor, kCKHW2HWCK);
  299. } else if (dataType == kNumberTypeInt8) {
  300. status = TransFilterFormat<int8_t>(tensor, kCKHW2HWCK);
  301. } else {
  302. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  303. return RET_ERROR;
  304. }
  305. break;
  306. case schema::Format::Format_CHWK:
  307. if (dataType == kNumberTypeFloat32) {
  308. status = TransFilterFormat<float>(tensor, kCHWK2HWCK);
  309. } else if (dataType == kNumberTypeUInt8) {
  310. status = TransFilterFormat<uint8_t>(tensor, kCHWK2HWCK);
  311. } else if (dataType == kNumberTypeInt8) {
  312. status = TransFilterFormat<int8_t>(tensor, kCHWK2HWCK);
  313. } else {
  314. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  315. return RET_ERROR;
  316. }
  317. break;
  318. case schema::Format::Format_HWCK:
  319. return RET_OK;
  320. default:
  321. MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to "
  322. << EnumNameFormat(dstFormat);
  323. return RET_ERROR;
  324. }
  325. } break;
  326. case schema::Format::Format_KCHW: {
  327. switch (srcFormat) {
  328. case schema::Format::Format_KCHW:
  329. return RET_OK;
  330. case schema::Format::Format_HWCK:
  331. if (dataType == kNumberTypeFloat32) {
  332. status = TransFilterFormat<float>(tensor, kHWCK2KCHW);
  333. } else if (dataType == kNumberTypeUInt8) {
  334. status = TransFilterFormat<uint8_t>(tensor, kHWCK2KCHW);
  335. } else if (dataType == kNumberTypeInt8) {
  336. status = TransFilterFormat<int8_t>(tensor, kHWCK2KCHW);
  337. } else {
  338. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  339. return RET_ERROR;
  340. }
  341. break;
  342. case schema::Format::Format_HWKC:
  343. if (dataType == kNumberTypeFloat32) {
  344. status = TransFilterFormat<float>(tensor, kHWKC2KCHW);
  345. } else if (dataType == kNumberTypeUInt8) {
  346. status = TransFilterFormat<uint8_t>(tensor, kHWKC2KCHW);
  347. } else if (dataType == kNumberTypeInt8) {
  348. status = TransFilterFormat<int8_t>(tensor, kHWKC2KCHW);
  349. } else {
  350. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  351. return RET_ERROR;
  352. }
  353. break;
  354. case schema::Format::Format_KHWC:
  355. if (dataType == kNumberTypeFloat32) {
  356. status = TransFilterFormat<float>(tensor, kKHWC2KCHW);
  357. } else if (dataType == kNumberTypeUInt8) {
  358. status = TransFilterFormat<uint8_t>(tensor, kKHWC2KCHW);
  359. } else if (dataType == kNumberTypeInt8) {
  360. status = TransFilterFormat<int8_t>(tensor, kKHWC2KCHW);
  361. } else {
  362. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  363. return RET_ERROR;
  364. }
  365. break;
  366. case schema::Format::Format_CKHW:
  367. if (dataType == kNumberTypeFloat32) {
  368. status = TransFilterFormat<float>(tensor, kCKHW2KCHW);
  369. } else if (dataType == kNumberTypeUInt8) {
  370. status = TransFilterFormat<uint8_t>(tensor, kCKHW2KCHW);
  371. } else if (dataType == kNumberTypeInt8) {
  372. status = TransFilterFormat<int8_t>(tensor, kCKHW2KCHW);
  373. } else {
  374. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  375. return RET_ERROR;
  376. }
  377. break;
  378. case schema::Format::Format_CHWK:
  379. if (dataType == kNumberTypeFloat32) {
  380. status = TransFilterFormat<float>(tensor, kCHWK2KCHW);
  381. } else if (dataType == kNumberTypeUInt8) {
  382. status = TransFilterFormat<uint8_t>(tensor, kCHWK2KCHW);
  383. } else if (dataType == kNumberTypeInt8) {
  384. status = TransFilterFormat<int8_t>(tensor, kCHWK2KCHW);
  385. } else {
  386. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  387. return RET_ERROR;
  388. }
  389. break;
  390. default:
  391. MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to "
  392. << EnumNameFormat(dstFormat);
  393. return RET_ERROR;
  394. }
  395. } break;
  396. case schema::Format::Format_CKHW: {
  397. switch (srcFormat) {
  398. case schema::Format::Format_HWCK:
  399. if (dataType == kNumberTypeFloat32) {
  400. status = TransFilterFormat<float>(tensor, kHWCK2CKHW);
  401. } else if (dataType == kNumberTypeUInt8) {
  402. status = TransFilterFormat<uint8_t>(tensor, kHWCK2CKHW);
  403. } else if (dataType == kNumberTypeInt8) {
  404. status = TransFilterFormat<int8_t>(tensor, kHWCK2CKHW);
  405. } else {
  406. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  407. return RET_ERROR;
  408. }
  409. break;
  410. case schema::Format::Format_HWKC:
  411. if (dataType == kNumberTypeFloat32) {
  412. status = TransFilterFormat<float>(tensor, kHWKC2CKHW);
  413. } else if (dataType == kNumberTypeUInt8) {
  414. status = TransFilterFormat<uint8_t>(tensor, kHWKC2CKHW);
  415. } else if (dataType == kNumberTypeInt8) {
  416. status = TransFilterFormat<int8_t>(tensor, kHWKC2CKHW);
  417. } else {
  418. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  419. return RET_ERROR;
  420. }
  421. break;
  422. case schema::Format::Format_KCHW:
  423. if (dataType == kNumberTypeFloat32) {
  424. status = TransFilterFormat<float>(tensor, kKCHW2CKHW);
  425. } else if (dataType == kNumberTypeUInt8) {
  426. status = TransFilterFormat<uint8_t>(tensor, kKCHW2CKHW);
  427. } else if (dataType == kNumberTypeInt8) {
  428. status = TransFilterFormat<int8_t>(tensor, kKCHW2CKHW);
  429. } else {
  430. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  431. return RET_ERROR;
  432. }
  433. break;
  434. case schema::Format::Format_CKHW:
  435. return RET_OK;
  436. default:
  437. MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to "
  438. << EnumNameFormat(dstFormat);
  439. return RET_ERROR;
  440. }
  441. } break;
  442. default:
  443. MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to "
  444. << EnumNameFormat(dstFormat);
  445. return RET_ERROR;
  446. }
  447. if (status != RET_OK) {
  448. MS_LOG(ERROR) << "TransFilterData failed: " << status;
  449. return status;
  450. }
  451. return RET_OK;
  452. }
  453. } // namespace lite
  454. } // namespace mindspore