You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

matmul.c 17 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "nnacl/fp32/matmul.h"
  17. void RowMajor2Row4Major(float *src_ptr, float *dst_ptr, int row, int col) {
  18. for (int r = 0; r < row; r++) {
  19. float *src = src_ptr + r * col;
  20. for (int c = 0; c < col; c++) {
  21. int cd8 = c / 4;
  22. int cm8 = c % 4;
  23. dst_ptr[cd8 * 4 * row + r * 4 + cm8] = src[c];
  24. }
  25. }
  26. return;
  27. }
  28. void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col) {
  29. for (int r = 0; r < row; r++) {
  30. float *src = src_ptr + r * col;
  31. for (int c = 0; c < col; c++) {
  32. int cd8 = c / 8;
  33. int cm8 = c % 8;
  34. dst_ptr[cd8 * 8 * row + r * 8 + cm8] = src[c];
  35. }
  36. }
  37. return;
  38. }
  39. void RowMajor2Row12Major(float *src_ptr, float *dst_ptr, int row, int col) {
  40. for (int r = 0; r < row; r++) {
  41. float *src = src_ptr + r * col;
  42. for (int c = 0; c < col; c++) {
  43. int cd8 = c / C12NUM;
  44. int cm8 = c % C12NUM;
  45. dst_ptr[cd8 * C12NUM * row + r * C12NUM + cm8] = src[c];
  46. }
  47. }
  48. return;
  49. }
/* Repacks a row-major [row x col] float matrix into the Col12 layout consumed
 * by the 12-wide matmul kernels: rows are processed in tiles of C12NUM (12),
 * and inside a tile the data is stored column-major (12 consecutive floats per
 * source column).  Rows beyond `row`, up to UP_ROUND(row, 12), are zero-filled
 * so the kernels can always read complete 12-row tiles. */
void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
  size_t row_up_12 = UP_ROUND(row, C12NUM);
  size_t row12 = row / C12NUM * C12NUM; /* rows covered by complete 12-row tiles */
  size_t col4 = col / C4NUM * C4NUM;    /* cols covered by complete 4-column groups */
  float *src_r = src_ptr;
  float *dst_r = dst_ptr;
  size_t ri = 0;
  /* Complete 12-row tiles, 4 source columns at a time. */
  for (; ri < row12; ri += C12NUM) {
    size_t ci = 0;
    for (; ci < col4; ci += C4NUM) {
      float *src_c = src_r + ci;
      float *dst_c = dst_r + ci * C12NUM;
      /* 12x4 row-major to col-major */
#ifdef ENABLE_ARM64
      /* NEON transpose of a 12x4 sub-block: load 12 rows of 4 floats
       * (stride = one source row in bytes), interleave them with zip/trn,
       * then store four columns of 12 floats via three 64-byte stores. */
      size_t stride = col * sizeof(float);
      asm volatile(
        "mov x10, %[src_c]\n"
        "mov x11, %[dst_c]\n"
        "ld1 {v0.4s}, [x10], %[stride]\n"
        "ld1 {v1.4s}, [x10], %[stride]\n"
        "ld1 {v2.4s}, [x10], %[stride]\n"
        "ld1 {v3.4s}, [x10], %[stride]\n"
        "ld1 {v4.4s}, [x10], %[stride]\n"
        "ld1 {v5.4s}, [x10], %[stride]\n"
        "ld1 {v6.4s}, [x10], %[stride]\n"
        "ld1 {v7.4s}, [x10], %[stride]\n"
        "zip1 v12.4s, v0.4s, v1.4s\n"
        "zip2 v13.4s, v0.4s, v1.4s\n"
        "zip1 v14.4s, v2.4s, v3.4s\n"
        "zip2 v15.4s, v2.4s, v3.4s\n"
        "ld1 {v8.4s}, [x10], %[stride]\n"
        "ld1 {v9.4s}, [x10], %[stride]\n"
        "ld1 {v10.4s}, [x10], %[stride]\n"
        "ld1 {v11.4s}, [x10], %[stride]\n"
        "zip1 v16.4s, v4.4s, v5.4s\n"
        "zip2 v17.4s, v4.4s, v5.4s\n"
        "zip1 v18.4s, v6.4s, v7.4s\n"
        "zip2 v19.4s, v6.4s, v7.4s\n"
        "trn1 v20.2d, v12.2d, v14.2d\n"
        "trn2 v23.2d, v12.2d, v14.2d\n"
        "trn1 v26.2d, v13.2d, v15.2d\n"
        "trn2 v29.2d, v13.2d, v15.2d\n"
        "trn1 v21.2d, v16.2d, v18.2d\n"
        "trn2 v24.2d, v16.2d, v18.2d\n"
        "trn1 v27.2d, v17.2d, v19.2d\n"
        "trn2 v30.2d, v17.2d, v19.2d\n"
        "zip1 v12.4s, v8.4s, v9.4s\n"
        "zip2 v13.4s, v8.4s, v9.4s\n"
        "zip1 v14.4s, v10.4s, v11.4s\n"
        "zip2 v15.4s, v10.4s, v11.4s\n"
        "trn1 v22.2d, v12.2d, v14.2d\n"
        "trn2 v25.2d, v12.2d, v14.2d\n"
        "trn1 v28.2d, v13.2d, v15.2d\n"
        "trn2 v31.2d, v13.2d, v15.2d\n"
        "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x11], #64\n"
        "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], #64\n"
        "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x11], #64\n"
        :
        : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
        : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
          "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
          "v30", "v31");
#elif ENABLE_ARM32
      /* ARM32 NEON variant of the same 12x4 transpose using vtrn/vswp. */
      size_t stride = col * sizeof(float);
      asm volatile(
        "mov r10, %[src_c]\n"
        "mov r12, %[dst_c]\n"
        "vld1.32 {q0}, [r10], %[stride]\n"
        "vld1.32 {q3}, [r10], %[stride]\n"
        "vld1.32 {q10}, [r10], %[stride]\n"
        "vld1.32 {q13}, [r10], %[stride]\n"
        "vtrn.32 d0, d6\n"
        "vtrn.32 d1, d7\n"
        "vtrn.32 d20, d26\n"
        "vtrn.32 d21, d27\n"
        "vld1.32 {q1}, [r10], %[stride]\n"
        "vld1.32 {q8}, [r10], %[stride]\n"
        "vld1.32 {q11}, [r10], %[stride]\n"
        "vld1.32 {q14}, [r10], %[stride]\n"
        "vswp d1, d20\n"
        "vswp d7, d26\n"
        "vld1.32 {q2}, [r10], %[stride]\n"
        "vld1.32 {q9}, [r10], %[stride]\n"
        "vld1.32 {q12}, [r10], %[stride]\n"
        "vld1.32 {q15}, [r10], %[stride]\n"
        "vtrn.32 d2, d16\n"
        "vtrn.32 d3, d17\n"
        "vtrn.32 d22, d28\n"
        "vtrn.32 d23, d29\n"
        "vswp d3, d22\n"
        "vswp d17, d28\n"
        "vtrn.32 d4, d18\n"
        "vtrn.32 d5, d19\n"
        "vtrn.32 d24, d30\n"
        "vtrn.32 d25, d31\n"
        "vswp d5, d24\n"
        "vswp d19, d30\n"
        "vst1.32 {q0, q1}, [r12]!\n"
        "vst1.32 {q2, q3}, [r12]!\n"
        "vst1.32 {q8, q9}, [r12]!\n"
        "vst1.32 {q10, q11}, [r12]!\n"
        "vst1.32 {q12, q13}, [r12]!\n"
        "vst1.32 {q14, q15}, [r12]!\n"
        :
        : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
        : "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#else
      /* Portable scalar 12x4 transpose. */
      for (int tr = 0; tr < C12NUM; tr++) {
        for (int tc = 0; tc < C4NUM; tc++) {
          dst_c[tc * C12NUM + tr] = src_c[tr * col + tc];
        }
      }
#endif
    }
    /* Leftover columns (col % 4) of this tile, one column at a time. */
    for (; ci < col; ci++) {
      float *src_c = src_r + ci;
      float *dst_c = dst_r + ci * C12NUM;
      for (size_t i = 0; i < C12NUM; i++) {
        dst_c[i] = src_c[i * col];
      }
    }
    src_r += C12NUM * col;
    dst_r += C12NUM * col;
  }
  /* Leftover rows (row % 12): scatter each row into the final, partial tile. */
  for (; ri < row; ri++) {
    for (size_t i = 0; i < col; i++) {
      dst_r[i * C12NUM] = src_r[i];
    }
    src_r += col;
    dst_r += 1;
  }
  /* Zero-fill the padding rows up to the 12-row boundary. */
  for (; ri < row_up_12; ri++) {
    for (size_t i = 0; i < col; i++) {
      dst_r[i * C12NUM] = 0;
    }
    dst_r += 1;
  }
  return;
}
/* Repacks a row-major [row x col] float matrix into Col8 layout: rows are
 * processed in tiles of C8NUM (8), stored column-major inside each tile
 * (8 consecutive floats per source column).  Unlike RowMajor2Col12Major,
 * this variant does NOT zero-pad past `row` — the caller is responsible for
 * any padding the consuming kernel needs. */
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
  size_t row8 = row / C8NUM * C8NUM; /* rows covered by complete 8-row tiles */
  size_t col4 = col / C4NUM * C4NUM; /* cols covered by complete 4-column groups */
  float *src_r = src_ptr;
  float *dst_r = dst_ptr;
  size_t ri = 0;
  /* Complete 8-row tiles, 4 source columns at a time. */
  for (; ri < row8; ri += C8NUM) {
    size_t ci = 0;
    for (; ci < col4; ci += C4NUM) {
      float *src_c = src_r + ci;
      float *dst_c = dst_r + ci * C8NUM;
      /* 8x4 row-major to col-major */
#ifdef ENABLE_ARM64
      /* NEON transpose of an 8x4 sub-block: load 8 rows of 4 floats
       * (stride = one source row in bytes), interleave with zip/trn, and
       * store the 4 columns as pairs of 4-float halves. */
      size_t stride = col * sizeof(float);
      asm volatile(
        "mov x10, %[src_c]\n"
        "mov x11, %[dst_c]\n"
        "ld1 {v0.4s}, [x10], %[stride]\n"
        "ld1 {v1.4s}, [x10], %[stride]\n"
        "ld1 {v2.4s}, [x10], %[stride]\n"
        "ld1 {v3.4s}, [x10], %[stride]\n"
        "zip1 v4.4s, v0.4s, v1.4s\n"
        "zip2 v5.4s, v0.4s, v1.4s\n"
        "zip1 v6.4s, v2.4s, v3.4s\n"
        "zip2 v7.4s, v2.4s, v3.4s\n"
        "ld1 {v8.4s}, [x10], %[stride]\n"
        "ld1 {v9.4s}, [x10], %[stride]\n"
        "ld1 {v10.4s}, [x10], %[stride]\n"
        "ld1 {v11.4s}, [x10], %[stride]\n"
        "trn1 v0.2d, v4.2d, v6.2d\n"
        "trn2 v1.2d, v4.2d, v6.2d\n"
        "trn1 v2.2d, v5.2d, v7.2d\n"
        "trn2 v3.2d, v5.2d, v7.2d\n"
        "zip1 v12.4s, v8.4s, v9.4s\n"
        "zip2 v13.4s, v8.4s, v9.4s\n"
        "zip1 v14.4s, v10.4s, v11.4s\n"
        "zip2 v15.4s, v10.4s, v11.4s\n"
        "trn1 v8.2d, v12.2d, v14.2d\n"
        "trn2 v9.2d, v12.2d, v14.2d\n"
        "trn1 v10.2d, v13.2d, v15.2d\n"
        "trn2 v11.2d, v13.2d, v15.2d\n"
        "st1 {v0.4s}, [x11], #16\n"
        "st1 {v8.4s}, [x11], #16\n"
        "st1 {v1.4s}, [x11], #16\n"
        "st1 {v9.4s}, [x11], #16\n"
        "st1 {v2.4s}, [x11],#16\n"
        "st1 {v10.4s}, [x11], #16\n"
        "st1 {v3.4s}, [x11],#16\n"
        "st1 {v11.4s}, [x11], #16\n"
        :
        : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
        : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
          "v15");
#elif ENABLE_ARM32
      /* ARM32 NEON variant of the same 8x4 transpose using vtrn/vswp. */
      size_t stride = col * sizeof(float);
      asm volatile(
        "mov r10, %[src_c]\n"
        "mov r11, %[dst_c]\n"
        "vld1.32 {q0}, [r10], %[stride]\n"
        "vld1.32 {q2}, [r10], %[stride]\n"
        "vld1.32 {q4}, [r10], %[stride]\n"
        "vld1.32 {q6}, [r10], %[stride]\n"
        "vtrn.32 d0, d4\n"
        "vtrn.32 d1, d5\n"
        "vtrn.32 d8, d12\n"
        "vtrn.32 d9, d13\n"
        "vld1.32 {q1}, [r10], %[stride]\n"
        "vld1.32 {q3}, [r10], %[stride]\n"
        "vld1.32 {q5}, [r10], %[stride]\n"
        "vld1.32 {q7}, [r10], %[stride]\n"
        "vswp d1, d8\n"
        "vswp d5, d12\n"
        "vtrn.32 d2, d6\n"
        "vtrn.32 d3, d7\n"
        "vtrn.32 d10, d14\n"
        "vtrn.32 d11, d15\n"
        "vswp d3, d10\n"
        "vswp d7, d14\n"
        "vst1.32 {q0, q1}, [r11]!\n"
        "vst1.32 {q2, q3}, [r11]!\n"
        "vst1.32 {q4, q5}, [r11]!\n"
        "vst1.32 {q6, q7}, [r11]!\n"
        :
        : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
        : "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
#else
      /* Portable scalar 8x4 transpose. */
      for (int tr = 0; tr < 8; tr++) {
        for (int tc = 0; tc < 4; tc++) {
          dst_c[tc * 8 + tr] = src_c[tr * col + tc];
        }
      }
#endif
    }
    /* Leftover columns (col % 4) of this tile, one column at a time. */
    for (; ci < col; ci++) {
      float *src_c = src_r + ci;
      float *dst_c = dst_r + ci * C8NUM;
      for (size_t i = 0; i < C8NUM; i++) {
        dst_c[i] = src_c[i * col];
      }
    }
    src_r += C8NUM * col;
    dst_r += C8NUM * col;
  }
  /* Leftover rows (row % 8): scatter each row into the final, partial tile. */
  for (; ri < row; ri++) {
    for (size_t i = 0; i < col; i++) {
      dst_r[i * C8NUM] = src_r[i];
    }
    src_r += col;
    dst_r += 1;
  }
  return;
}
/* Repacks a row-major [row x col] float matrix into Col4 layout: rows are
 * processed in tiles of C4NUM (4), stored column-major inside each tile
 * (4 consecutive floats per source column).  No zero-padding past `row`. */
void RowMajor2Col4Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
  /* NOTE(review): the name `row8` is historical — it actually holds `row`
   * rounded down to a multiple of C4NUM (4). */
  size_t row8 = row / C4NUM * C4NUM;
  size_t col4 = col / C4NUM * C4NUM; /* cols covered by complete 4-column groups */
  float *src_r = src_ptr;
  float *dst_r = dst_ptr;
  size_t ri = 0;
  /* Complete 4-row tiles, 4 source columns at a time. */
  for (; ri < row8; ri += C4NUM) {
    size_t ci = 0;
    for (; ci < col4; ci += C4NUM) {
      float *src_c = src_r + ci;
      float *dst_c = dst_r + ci * C4NUM;
      /* 4x4 row-major to col-major */
#ifdef ENABLE_ARM32
      /* NEON 4x4 transpose; stride is one source row in bytes
       * (`col * 4` == col * sizeof(float)). */
      size_t stride = col * 4;
      asm volatile(
        "mov r10, %[src_c]\n"
        "mov r12, %[dst_c]\n"
        "vld1.32 {q0}, [r10], %[stride]\n"
        "vld1.32 {q1}, [r10], %[stride]\n"
        "vld1.32 {q2}, [r10], %[stride]\n"
        "vld1.32 {q3}, [r10], %[stride]\n"
        "vtrn.32 d0, d2\n"
        "vtrn.32 d1, d3\n"
        "vtrn.32 d4, d6\n"
        "vtrn.32 d5, d7\n"
        "vswp d1, d4\n"
        "vswp d3, d6\n"
        "vst1.32 {q0}, [r12]!\n"
        "vst1.32 {q1}, [r12]!\n"
        "vst1.32 {q2}, [r12]!\n"
        "vst1.32 {q3}, [r12]!\n"
        :
        : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
        : "r10", "r12", "q0", "q1", "q2", "q3");
#else
      /* Portable scalar 4x4 transpose (also used on ARM64). */
      for (int tr = 0; tr < C4NUM; tr++) {
        for (int tc = 0; tc < C4NUM; tc++) {
          dst_c[tc * C4NUM + tr] = src_c[tr * col + tc];
        }
      }
#endif
    }
    /* Leftover columns (col % 4) of this tile, one column at a time. */
    for (; ci < col; ci++) {
      float *src_c = src_r + ci;
      float *dst_c = dst_r + ci * C4NUM;
      for (size_t i = 0; i < C4NUM; i++) {
        dst_c[i] = src_c[i * col];
      }
    }
    src_r += C4NUM * col;
    dst_r += C4NUM * col;
  }
  /* Leftover rows (row % 4): scatter each row into the final, partial tile. */
  for (; ri < row; ri++) {
    for (size_t i = 0; i < col; i++) {
      dst_r[i * C4NUM] = src_r[i];
    }
    src_r += col;
    dst_r += 1;
  }
  return;
}
  362. void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
  363. int col, int stride, int out_type) {
  364. if (out_type == OutType_Nhwc) {
  365. for (int r = 0; r < row; r++) {
  366. for (int c = 0; c < col; c++) {
  367. int r12div = r / 12, r12mod = r % 12;
  368. int c8div = c / 8, c8mod = c % 8;
  369. size_t ci = r * stride + c;
  370. float value = 0;
  371. for (int d = 0; d < deep; d++) {
  372. size_t ai = r12div * deep * 12 + d * 12 + r12mod;
  373. size_t bi = c8div * deep * 8 + d * 8 + c8mod;
  374. value = value + a[ai] * b[bi];
  375. }
  376. if (bias != NULL) value += bias[c];
  377. if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
  378. if (act_type != ActType_No) value = MSMAX(0.0f, value);
  379. dst[ci] = value;
  380. }
  381. }
  382. } else if (out_type == OutType_C8) {
  383. int col_8 = UP_ROUND(col, C8NUM);
  384. int row_12 = UP_ROUND(row, C12NUM);
  385. for (int r = 0; r < row_12; r++) {
  386. for (int c = 0; c < col_8; c++) {
  387. int r12div = r / C12NUM, r12mod = r % C12NUM;
  388. int c8div = c / C8NUM, c8mod = c % C8NUM;
  389. size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod);
  390. float value = 0;
  391. for (int d = 0; d < deep; d++) {
  392. size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod;
  393. size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
  394. value = value + a[ai] * b[bi];
  395. }
  396. if (bias != NULL) value += bias[c];
  397. if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
  398. if (act_type != ActType_No) value = MSMAX(0.0f, value);
  399. dst[ci] = value;
  400. }
  401. }
  402. } else {
  403. for (int i = 0; i < row; ++i) {
  404. int src_r_offset = i;
  405. int dst_r_offset = i * col * stride;
  406. for (int j = 0; j < col; ++j) {
  407. int c8div = j / 8, c8mod = j % 8;
  408. size_t ci = dst_r_offset + c8div * 8 * stride + c8mod;
  409. float value = 0;
  410. for (int d = 0; d < deep; ++d) {
  411. size_t ai = src_r_offset + d * C12NUM;
  412. size_t bi = c8div * deep * 8 + d * 8 + c8mod;
  413. value = value + a[ai] * b[bi];
  414. }
  415. if (bias != NULL) value += bias[j];
  416. if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
  417. if (act_type != ActType_No) value = MSMAX(0.0f, value);
  418. dst[ci] = value;
  419. }
  420. }
  421. }
  422. return;
  423. }
  424. void MatMul4x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
  425. int col, int stride, int out_type) {
  426. if (out_type == OutType_C8) {
  427. int col_8 = UP_ROUND(col, C8NUM);
  428. int row_4 = UP_ROUND(row, C4NUM);
  429. for (int r = 0; r < row_4; r++) {
  430. for (int c = 0; c < col_8; c++) {
  431. int r4div = r / C4NUM, r4mod = r % C4NUM;
  432. int c8div = c / C8NUM, c8mod = c % C8NUM;
  433. size_t ci = (c8div * C8NUM * row_4 + r * C8NUM + c8mod);
  434. float value = 0;
  435. for (int d = 0; d < deep; d++) {
  436. size_t ai = r4div * deep * C4NUM + d * C4NUM + r4mod;
  437. size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
  438. value = value + a[ai] * b[bi];
  439. }
  440. if (bias != NULL) value += bias[c];
  441. if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
  442. if (act_type != ActType_No) value = MSMAX(0.0f, value);
  443. dst[ci] = value;
  444. }
  445. }
  446. }
  447. return;
  448. }
/* Dispatch wrapper for the packed matmul: on ARM builds it calls the
 * hand-written NEON assembly kernels (declared elsewhere); on all other
 * targets it falls back to the portable MatMul12x8 reference loop. */
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
               int col, size_t stride, int out_type) {
#ifdef ENABLE_ARM64
  if (out_type == OutType_C8) {
    /* Trailing 0, 0 arguments — presumably output-layout/write-mode flags;
     * confirm against MatmulFloatNeon64's signature. */
    MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
  } else {
    MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
  }
#elif ENABLE_ARM32
  if (out_type == OutType_C8) {
    /* Same trailing-flag convention as the ARM64 kernel above. */
    MatmulFloatNeon32(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
  } else {
    MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
  }
#else
  MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
#endif
}
  467. #ifdef ENABLE_NNACL_INFER_SHAPE
  468. static void SwapDims(int *dims, int index1, int index2) {
  469. int tmp = dims[index1];
  470. dims[index1] = dims[index2];
  471. dims[index2] = tmp;
  472. }
  473. int MatMulInferShape(int **in_shape, int in_num, size_t *dim_size, int *out_shape, int *in_format, int *out_format,
  474. int *in_datatype, int *out_datatype, OpParameter *param) {
  475. *out_datatype = in_datatype[0];
  476. *out_format = in_format[0];
  477. if (dim_size[0] < 2 || dim_size[1] < 2) {
  478. return NNACL_PARAM_INVALID;
  479. }
  480. for (int i = 0; i < dim_size[0] - 2; ++i) {
  481. if (in_shape[0][i] != in_shape[1][i]) {
  482. return NNACL_PARAM_INVALID;
  483. }
  484. }
  485. MatMulParameter *matmul_param = (MatMulParameter *)param;
  486. if (matmul_param->a_transpose_) {
  487. SwapDims(in_shape[0], dim_size[0] - 1, dim_size[0] - 2);
  488. }
  489. if (matmul_param->b_transpose_) {
  490. SwapDims(in_shape[1], dim_size[1] - 1, dim_size[1] - 2);
  491. }
  492. for (int i = 0; i < dim_size[0] - 1; ++i) {
  493. out_shape[i] = in_shape[0][i];
  494. }
  495. out_shape[dim_size[0] - 1] = in_shape[1][dim_size[1] - 1];
  496. return NNACL_OK;
  497. }
  498. #endif