You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel_mk4_8x8x8.h 57 kB


  1. /**
  2. * \file dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include <inttypes.h>
  13. #include "src/aarch64/matrix_mul/asm/common.h"
  14. #include "src/arm_common/simd_macro/marm_neon.h"
  15. namespace megdnn {
  16. namespace aarch64 {
  17. namespace matmul_mk4_8x8x8 {
  18. /**
  19. * Overview of register layout:
  20. *
  21. * A: an 8x8 cell of Lhs (packA) is stored as int8 in v16-v17, 8 bytes per k step
  22. * B: an 8x8 cell of Rhs (packB) is loaded as int8 into v20-v23 and broadcast lane by lane into v0-v15
  23. * C: the 8x8 block of accumulators is stored as int16 in v24-v31
  24. *
  25. * +---------------------------------+
  26. * | v0 ------------------------ v7 |
  27. * | v8 ------------------------ v15|
  28. * Rhs +---------------------------------+
  29. * Lhs | |
  30. * +--------+ - - - - +---------------------------------+
  31. * | v16 | | v24 |
  32. * | v17 | | v25 |
  33. * | v16 | | v26 |
  34. * | v17 | | v27 |
  35. * | v16 | | v28 |
  36. * | v17 | | v29 |
  37. * | v16 | | v30 |
  38. * | v17 | | v31 |
  39. * +--------+ - - - - +---------------------------------+
  40. *
  41. * Accumulator
  42. */
  43. static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
  44. int16_t* output, int LDC, bool is_first_k, int m_remain,
  45. int n_remain) {
  46. K /= 8;
  47. LDC = LDC * sizeof(int16_t);
  48. const int8_t* a_ptr = packB;//packA;
  49. const int8_t* b_ptr = packA;//packB;
  50. // clang-format off
  51. #define LOAD_C_8 \
  52. "ld1 {v0.8h}, [x0], #16\n" \
  53. "ld1 {v1.8h}, [x0], #16\n" \
  54. "ld1 {v2.8h}, [x0], #16\n" \
  55. "ld1 {v3.8h}, [x0], #16\n" \
  56. "ld1 {v4.8h}, [x1], #16\n" \
  57. "ld1 {v5.8h}, [x1], #16\n" \
  58. "ld1 {v6.8h}, [x1], #16\n" \
  59. "ld1 {v7.8h}, [x1], #16\n" \
  60. #define STORE_C_8 \
  61. "st1 {v0.8h}, [x0], #16\n" \
  62. "st1 {v1.8h}, [x0], #16\n" \
  63. "st1 {v2.8h}, [x0], #16\n" \
  64. "st1 {v3.8h}, [x0], #16\n" \
  65. "st1 {v4.8h}, [x1], #16\n" \
  66. "st1 {v5.8h}, [x1], #16\n" \
  67. "st1 {v6.8h}, [x1], #16\n" \
  68. "st1 {v7.8h}, [x1], #16\n" \
  69. register int16_t* outptr asm("x0") = output;
  70. asm volatile(
  71. "add x1, x0, %x[LDC]\n"
  72. "eor v24.16b, v24.16b, v24.16b\n"
  73. "PRFM PLDL1KEEP, [%[a_ptr], #512]\n"
  74. "eor v25.16b, v25.16b, v25.16b\n"
  75. "PRFM PLDL1KEEP, [%[b_ptr], #512]\n"
  76. "eor v26.16b, v26.16b, v26.16b\n"
  77. "ld1 {v20.16b}, [%[a_ptr]],#16\n"
  78. "eor v27.16b, v27.16b, v27.16b\n"
  79. "ld1 {v21.16b}, [%[a_ptr]],#16\n"
  80. "eor v28.16b, v28.16b, v28.16b\n"
  81. "eor v29.16b, v29.16b, v29.16b\n"
  82. "eor v30.16b, v30.16b, v30.16b\n"
  83. "eor v31.16b, v31.16b, v31.16b\n"
  84. // General loop.
  85. "1:\n"
  86. "dup v0.8b,v20.b[0]\n"
  87. "ld1 {v22.16b}, [%[a_ptr]],#16\n"
  88. "dup v1.8b,v20.b[1]\n"
  89. "ld1 {v23.16b}, [%[a_ptr]],#16\n"
  90. "dup v2.8b,v20.b[2]\n"
  91. "ld1 {v16.8b}, [%[b_ptr]], 8\n"
  92. "dup v3.8b,v20.b[3]\n"
  93. "dup v4.8b,v20.b[4]\n"
  94. "ld1 {v17.8b}, [%[b_ptr]], 8\n"
  95. "dup v5.8b,v20.b[5]\n"
  96. "dup v6.8b,v20.b[6]\n"
  97. "dup v7.8b,v20.b[7]\n"
  98. "dup v8.8b,v20.b[8]\n"
  99. "smlal v24.8h, v0.8b, v16.8b\n"
  100. "dup v9.8b,v20.b[9]\n"
  101. "smlal v25.8h, v1.8b, v16.8b\n"
  102. "dup v10.8b,v20.b[10]\n"
  103. "smlal v26.8h, v2.8b, v16.8b\n"
  104. "dup v11.8b,v20.b[11]\n"
  105. "smlal v27.8h, v3.8b, v16.8b\n"
  106. "dup v12.8b,v20.b[12]\n"
  107. "smlal v28.8h, v4.8b, v16.8b\n"
  108. "dup v13.8b,v20.b[13]\n"
  109. "smlal v29.8h, v5.8b, v16.8b\n"
  110. "dup v14.8b,v20.b[14]\n"
  111. "smlal v30.8h, v6.8b, v16.8b\n"
  112. "dup v15.8b,v20.b[15]\n"
  113. "smlal v31.8h, v7.8b, v16.8b\n"
  114. "ld1 {v16.8b}, [%[b_ptr]], 8\n"
  115. "dup v0.8b,v21.b[0]\n"
  116. "smlal v24.8h, v8.8b, v17.8b\n"
  117. "dup v1.8b,v21.b[1]\n"
  118. "smlal v25.8h, v9.8b, v17.8b\n"
  119. "dup v2.8b,v21.b[2]\n"
  120. "smlal v26.8h, v10.8b, v17.8b\n"
  121. "dup v3.8b,v21.b[3]\n"
  122. "smlal v27.8h, v11.8b, v17.8b\n"
  123. "dup v4.8b,v21.b[4]\n"
  124. "smlal v28.8h, v12.8b, v17.8b\n"
  125. "dup v5.8b,v21.b[5]\n"
  126. "smlal v29.8h, v13.8b, v17.8b\n"
  127. "dup v6.8b,v21.b[6]\n"
  128. "smlal v30.8h, v14.8b, v17.8b\n"
  129. "dup v7.8b,v21.b[7]\n"
  130. "smlal v31.8h, v15.8b, v17.8b\n"
  131. "ld1 {v17.8b}, [%[b_ptr]], 8\n"
  132. "dup v8.8b,v21.b[8]\n"
  133. "smlal v24.8h, v0.8b, v16.8b\n"
  134. "dup v9.8b,v21.b[9]\n"
  135. "smlal v25.8h, v1.8b, v16.8b\n"
  136. "dup v10.8b,v21.b[10]\n"
  137. "smlal v26.8h, v2.8b, v16.8b\n"
  138. "dup v11.8b,v21.b[11]\n"
  139. "smlal v27.8h, v3.8b, v16.8b\n"
  140. "dup v12.8b,v21.b[12]\n"
  141. "smlal v28.8h, v4.8b, v16.8b\n"
  142. "dup v13.8b,v21.b[13]\n"
  143. "smlal v29.8h, v5.8b, v16.8b\n"
  144. "dup v14.8b,v21.b[14]\n"
  145. "smlal v30.8h, v6.8b, v16.8b\n"
  146. "dup v15.8b,v21.b[15]\n"
  147. "smlal v31.8h, v7.8b, v16.8b\n"
  148. "ld1 {v16.8b}, [%[b_ptr]], 8\n"
  149. "dup v0.8b,v22.b[0]\n"
  150. "smlal v24.8h, v8.8b, v17.8b\n"
  151. "dup v1.8b,v22.b[1]\n"
  152. "smlal v25.8h, v9.8b, v17.8b\n"
  153. "dup v2.8b,v22.b[2]\n"
  154. "smlal v26.8h, v10.8b, v17.8b\n"
  155. "dup v3.8b,v22.b[3]\n"
  156. "smlal v27.8h, v11.8b, v17.8b\n"
  157. "dup v4.8b,v22.b[4]\n"
  158. "smlal v28.8h, v12.8b, v17.8b\n"
  159. "dup v5.8b,v22.b[5]\n"
  160. "smlal v29.8h, v13.8b, v17.8b\n"
  161. "dup v6.8b,v22.b[6]\n"
  162. "smlal v30.8h, v14.8b, v17.8b\n"
  163. "dup v7.8b,v22.b[7]\n"
  164. "smlal v31.8h, v15.8b, v17.8b\n"
  165. "ld1 {v17.8b}, [%[b_ptr]], 8\n"
  166. "dup v8.8b,v22.b[8]\n"
  167. "smlal v24.8h, v0.8b, v16.8b\n"
  168. "dup v9.8b,v22.b[9]\n"
  169. "smlal v25.8h, v1.8b, v16.8b\n"
  170. "dup v10.8b,v22.b[10]\n"
  171. "smlal v26.8h, v2.8b, v16.8b\n"
  172. "dup v11.8b,v22.b[11]\n"
  173. "smlal v27.8h, v3.8b, v16.8b\n"
  174. "dup v12.8b,v22.b[12]\n"
  175. "smlal v28.8h, v4.8b, v16.8b\n"
  176. "dup v13.8b,v22.b[13]\n"
  177. "smlal v29.8h, v5.8b, v16.8b\n"
  178. "dup v14.8b,v22.b[14]\n"
  179. "smlal v30.8h, v6.8b, v16.8b\n"
  180. "dup v15.8b,v22.b[15]\n"
  181. "smlal v31.8h, v7.8b, v16.8b\n"
  182. "ld1 {v16.8b}, [%[b_ptr]], 8\n"
  183. "dup v0.8b,v23.b[0]\n"
  184. "smlal v24.8h, v8.8b, v17.8b\n"
  185. "dup v1.8b,v23.b[1]\n"
  186. "smlal v25.8h, v9.8b, v17.8b\n"
  187. "dup v2.8b,v23.b[2]\n"
  188. "smlal v26.8h, v10.8b, v17.8b\n"
  189. "dup v3.8b,v23.b[3]\n"
  190. "smlal v27.8h, v11.8b, v17.8b\n"
  191. "dup v4.8b,v23.b[4]\n"
  192. "smlal v28.8h, v12.8b, v17.8b\n"
  193. "dup v5.8b,v23.b[5]\n"
  194. "smlal v29.8h, v13.8b, v17.8b\n"
  195. "dup v6.8b,v23.b[6]\n"
  196. "smlal v30.8h, v14.8b, v17.8b\n"
  197. "dup v7.8b,v23.b[7]\n"
  198. "smlal v31.8h, v15.8b, v17.8b\n"
  199. "ld1 {v17.8b}, [%[b_ptr]], 8\n"
  200. "dup v8.8b,v23.b[8]\n"
  201. "smlal v24.8h, v0.8b, v16.8b\n"
  202. "dup v9.8b,v23.b[9]\n"
  203. "smlal v25.8h, v1.8b, v16.8b\n"
  204. "dup v10.8b,v23.b[10]\n"
  205. "smlal v26.8h, v2.8b, v16.8b\n"
  206. "dup v11.8b,v23.b[11]\n"
  207. "smlal v27.8h, v3.8b, v16.8b\n"
  208. "dup v12.8b,v23.b[12]\n"
  209. "smlal v28.8h, v4.8b, v16.8b\n"
  210. "dup v13.8b,v23.b[13]\n"
  211. "smlal v29.8h, v5.8b, v16.8b\n"
  212. "dup v14.8b,v23.b[14]\n"
  213. "smlal v30.8h, v6.8b, v16.8b\n"
  214. "dup v15.8b,v23.b[15]\n"
  215. "smlal v31.8h, v7.8b, v16.8b\n"
  216. "ld1 {v20.16b}, [%[a_ptr]],#16\n"
  217. "smlal v24.8h, v8.8b, v17.8b\n"
  218. "smlal v25.8h, v9.8b, v17.8b\n"
  219. "smlal v26.8h, v10.8b, v17.8b\n"
  220. "smlal v27.8h, v11.8b, v17.8b\n"
  221. "ld1 {v21.16b}, [%[a_ptr]],#16\n"
  222. "smlal v28.8h, v12.8b, v17.8b\n"
  223. "smlal v29.8h, v13.8b, v17.8b\n"
  224. "smlal v30.8h, v14.8b, v17.8b\n"
  225. "smlal v31.8h, v15.8b, v17.8b\n"
  226. "subs %w[K], %w[K], #1\n"
  227. "cbnz %w[K], 1b\n"
  228. "cmp %w[is_first_k], #1\n"
  229. "beq 2f\n" LOAD_C_8
  230. "b 3f \n"
  231. "2: \n"
  232. "eor v0.16b, v0.16b, v0.16b\n"
  233. "eor v1.16b, v1.16b, v1.16b\n"
  234. "eor v2.16b, v2.16b, v2.16b\n"
  235. "eor v3.16b, v3.16b, v3.16b\n"
  236. "eor v4.16b, v4.16b, v4.16b\n"
  237. "eor v5.16b, v5.16b, v5.16b\n"
  238. "eor v6.16b, v6.16b, v6.16b\n"
  239. "eor v7.16b, v7.16b, v7.16b\n"
  240. "3:\n"
  241. "zip1 v8.2d, v24.2d, v25.2d\n"
  242. "zip2 v9.2d, v24.2d, v25.2d\n"
  243. "zip1 v10.2d, v26.2d, v27.2d\n"
  244. "zip2 v11.2d, v26.2d, v27.2d\n"
  245. "zip1 v12.2d, v28.2d, v29.2d\n"
  246. "zip2 v13.2d, v28.2d, v29.2d\n"
  247. "zip1 v14.2d, v30.2d, v31.2d\n"
  248. "zip2 v15.2d, v30.2d, v31.2d\n"
  249. "add v0.8h, v0.8h, v8.8h\n"
  250. "add v1.8h, v1.8h, v10.8h\n"
  251. "add v2.8h, v2.8h, v12.8h\n"
  252. "add v3.8h, v3.8h, v14.8h\n"
  253. "add v4.8h, v4.8h, v9.8h\n"
  254. "add v5.8h, v5.8h, v11.8h\n"
  255. "add v6.8h, v6.8h, v13.8h\n"
  256. "add v7.8h, v7.8h, v15.8h\n"
  257. // Store back into memory
  258. STORE_C_8
  259. :
  260. [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr),
  261. [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC),
  262. [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain),
  263. [ n_remain ] "+r"(n_remain)
  264. :
  265. : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
  266. "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
  267. "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
  268. "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
  269. "v29", "v30", "v31");
  270. // clang-format on
  271. }
  272. static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
  273. int16_t* output, int LDC, bool is_first_k, int m_remain,
  274. int n_remain) {
  275. K /= 8;
  276. LDC = LDC * sizeof(int16_t);
  277. const int8_t* a_ptr = packB;
  278. const int8_t* b_ptr = packA;
  279. // clang-format off
  280. register int16_t* outptr asm("x0") = output;
  281. asm volatile(
  282. "add x1, x0, %x[LDC]\n"
  283. "eor v24.16b, v24.16b, v24.16b\n"
  284. "eor v25.16b, v25.16b, v25.16b\n"
  285. "eor v26.16b, v26.16b, v26.16b\n"
  286. "eor v27.16b, v27.16b, v27.16b\n"
  287. "eor v28.16b, v28.16b, v28.16b\n"
  288. "eor v29.16b, v29.16b, v29.16b\n"
  289. "eor v30.16b, v30.16b, v30.16b\n"
  290. "eor v31.16b, v31.16b, v31.16b\n"
  291. // General loop.
  292. "ld1 {v20.16b}, [%[a_ptr]],#16\n"
  293. "ld1 {v21.16b}, [%[a_ptr]],#16\n"
  294. "PRFM PLDL1KEEP, [%[a_ptr], #512]\n"
  295. "PRFM PLDL1KEEP, [%[b_ptr], #512]\n"
  296. "1:\n"
  297. "dup v0.8b,v20.b[0]\n"
  298. "ld1 {v22.16b}, [%[a_ptr]],#16\n"
  299. "dup v1.8b,v20.b[1]\n"
  300. "ld1 {v23.16b}, [%[a_ptr]],#16\n"
  301. "dup v2.8b,v20.b[2]\n"
  302. "ld1 {v16.8b}, [%[b_ptr]], 8\n"
  303. "dup v3.8b,v20.b[3]\n"
  304. "dup v4.8b,v20.b[4]\n"
  305. "ld1 {v17.8b}, [%[b_ptr]], 8\n"
  306. "dup v5.8b,v20.b[5]\n"
  307. "dup v6.8b,v20.b[6]\n"
  308. "dup v7.8b,v20.b[7]\n"
  309. "dup v8.8b,v20.b[8]\n"
  310. "smlal v24.8h, v0.8b, v16.8b\n"
  311. "dup v9.8b,v20.b[9]\n"
  312. "smlal v25.8h, v1.8b, v16.8b\n"
  313. "dup v10.8b,v20.b[10]\n"
  314. "smlal v26.8h, v2.8b, v16.8b\n"
  315. "dup v11.8b,v20.b[11]\n"
  316. "smlal v27.8h, v3.8b, v16.8b\n"
  317. "dup v12.8b,v20.b[12]\n"
  318. "smlal v28.8h, v4.8b, v16.8b\n"
  319. "dup v13.8b,v20.b[13]\n"
  320. "smlal v29.8h, v5.8b, v16.8b\n"
  321. "dup v14.8b,v20.b[14]\n"
  322. "smlal v30.8h, v6.8b, v16.8b\n"
  323. "dup v15.8b,v20.b[15]\n"
  324. "smlal v31.8h, v7.8b, v16.8b\n"
  325. "ld1 {v16.8b}, [%[b_ptr]], 8\n"
  326. "dup v0.8b,v21.b[0]\n"
  327. "smlal v24.8h, v8.8b, v17.8b\n"
  328. "dup v1.8b,v21.b[1]\n"
  329. "smlal v25.8h, v9.8b, v17.8b\n"
  330. "dup v2.8b,v21.b[2]\n"
  331. "smlal v26.8h, v10.8b, v17.8b\n"
  332. "dup v3.8b,v21.b[3]\n"
  333. "smlal v27.8h, v11.8b, v17.8b\n"
  334. "dup v4.8b,v21.b[4]\n"
  335. "smlal v28.8h, v12.8b, v17.8b\n"
  336. "dup v5.8b,v21.b[5]\n"
  337. "smlal v29.8h, v13.8b, v17.8b\n"
  338. "dup v6.8b,v21.b[6]\n"
  339. "smlal v30.8h, v14.8b, v17.8b\n"
  340. "dup v7.8b,v21.b[7]\n"
  341. "smlal v31.8h, v15.8b, v17.8b\n"
  342. "ld1 {v17.8b}, [%[b_ptr]], 8\n"
  343. "dup v8.8b,v21.b[8]\n"
  344. "smlal v24.8h, v0.8b, v16.8b\n"
  345. "dup v9.8b,v21.b[9]\n"
  346. "smlal v25.8h, v1.8b, v16.8b\n"
  347. "dup v10.8b,v21.b[10]\n"
  348. "smlal v26.8h, v2.8b, v16.8b\n"
  349. "dup v11.8b,v21.b[11]\n"
  350. "smlal v27.8h, v3.8b, v16.8b\n"
  351. "dup v12.8b,v21.b[12]\n"
  352. "smlal v28.8h, v4.8b, v16.8b\n"
  353. "dup v13.8b,v21.b[13]\n"
  354. "smlal v29.8h, v5.8b, v16.8b\n"
  355. "dup v14.8b,v21.b[14]\n"
  356. "smlal v30.8h, v6.8b, v16.8b\n"
  357. "dup v15.8b,v21.b[15]\n"
  358. "smlal v31.8h, v7.8b, v16.8b\n"
  359. "ld1 {v16.8b}, [%[b_ptr]], 8\n"
  360. "dup v0.8b,v22.b[0]\n"
  361. "smlal v24.8h, v8.8b, v17.8b\n"
  362. "dup v1.8b,v22.b[1]\n"
  363. "smlal v25.8h, v9.8b, v17.8b\n"
  364. "dup v2.8b,v22.b[2]\n"
  365. "smlal v26.8h, v10.8b, v17.8b\n"
  366. "dup v3.8b,v22.b[3]\n"
  367. "smlal v27.8h, v11.8b, v17.8b\n"
  368. "dup v4.8b,v22.b[4]\n"
  369. "smlal v28.8h, v12.8b, v17.8b\n"
  370. "dup v5.8b,v22.b[5]\n"
  371. "smlal v29.8h, v13.8b, v17.8b\n"
  372. "dup v6.8b,v22.b[6]\n"
  373. "smlal v30.8h, v14.8b, v17.8b\n"
  374. "dup v7.8b,v22.b[7]\n"
  375. "smlal v31.8h, v15.8b, v17.8b\n"
  376. "ld1 {v17.8b}, [%[b_ptr]], 8\n"
  377. "dup v8.8b,v22.b[8]\n"
  378. "smlal v24.8h, v0.8b, v16.8b\n"
  379. "dup v9.8b,v22.b[9]\n"
  380. "smlal v25.8h, v1.8b, v16.8b\n"
  381. "dup v10.8b,v22.b[10]\n"
  382. "smlal v26.8h, v2.8b, v16.8b\n"
  383. "dup v11.8b,v22.b[11]\n"
  384. "smlal v27.8h, v3.8b, v16.8b\n"
  385. "dup v12.8b,v22.b[12]\n"
  386. "smlal v28.8h, v4.8b, v16.8b\n"
  387. "dup v13.8b,v22.b[13]\n"
  388. "smlal v29.8h, v5.8b, v16.8b\n"
  389. "dup v14.8b,v22.b[14]\n"
  390. "smlal v30.8h, v6.8b, v16.8b\n"
  391. "dup v15.8b,v22.b[15]\n"
  392. "smlal v31.8h, v7.8b, v16.8b\n"
  393. "ld1 {v16.8b}, [%[b_ptr]], 8\n"
  394. "dup v0.8b,v23.b[0]\n"
  395. "smlal v24.8h, v8.8b, v17.8b\n"
  396. "dup v1.8b,v23.b[1]\n"
  397. "smlal v25.8h, v9.8b, v17.8b\n"
  398. "dup v2.8b,v23.b[2]\n"
  399. "smlal v26.8h, v10.8b, v17.8b\n"
  400. "dup v3.8b,v23.b[3]\n"
  401. "smlal v27.8h, v11.8b, v17.8b\n"
  402. "dup v4.8b,v23.b[4]\n"
  403. "smlal v28.8h, v12.8b, v17.8b\n"
  404. "dup v5.8b,v23.b[5]\n"
  405. "smlal v29.8h, v13.8b, v17.8b\n"
  406. "dup v6.8b,v23.b[6]\n"
  407. "smlal v30.8h, v14.8b, v17.8b\n"
  408. "dup v7.8b,v23.b[7]\n"
  409. "smlal v31.8h, v15.8b, v17.8b\n"
  410. "ld1 {v17.8b}, [%[b_ptr]], 8\n"
  411. "dup v8.8b,v23.b[8]\n"
  412. "smlal v24.8h, v0.8b, v16.8b\n"
  413. "dup v9.8b,v23.b[9]\n"
  414. "smlal v25.8h, v1.8b, v16.8b\n"
  415. "dup v10.8b,v23.b[10]\n"
  416. "smlal v26.8h, v2.8b, v16.8b\n"
  417. "dup v11.8b,v23.b[11]\n"
  418. "smlal v27.8h, v3.8b, v16.8b\n"
  419. "dup v12.8b,v23.b[12]\n"
  420. "smlal v28.8h, v4.8b, v16.8b\n"
  421. "dup v13.8b,v23.b[13]\n"
  422. "smlal v29.8h, v5.8b, v16.8b\n"
  423. "dup v14.8b,v23.b[14]\n"
  424. "smlal v30.8h, v6.8b, v16.8b\n"
  425. "dup v15.8b,v23.b[15]\n"
  426. "smlal v31.8h, v7.8b, v16.8b\n"
  427. "ld1 {v20.16b}, [%[a_ptr]],#16\n"
  428. "smlal v24.8h, v8.8b, v17.8b\n"
  429. "smlal v25.8h, v9.8b, v17.8b\n"
  430. "smlal v26.8h, v10.8b, v17.8b\n"
  431. "smlal v27.8h, v11.8b, v17.8b\n"
  432. "ld1 {v21.16b}, [%[a_ptr]],#16\n"
  433. "smlal v28.8h, v12.8b, v17.8b\n"
  434. "smlal v29.8h, v13.8b, v17.8b\n"
  435. "smlal v30.8h, v14.8b, v17.8b\n"
  436. "smlal v31.8h, v15.8b, v17.8b\n"
  437. "subs %w[K], %w[K], #1\n"
  438. "cbnz %w[K], 1b\n"
  439. "cmp %w[is_first_k], #1\n"
  440. "beq 2f\n"
  441. "cmp %x[m_remain], #8 \n"
  442. "beq 8f \n"
  443. "cmp %x[m_remain], #4 \n"
  444. "beq 9f \n"
  445. "8: \n"
  446. "cmp %x[n_remain], #8\n"
  447. "beq 200f \n"
  448. "cmp %x[n_remain], #7\n"
  449. "beq 201f \n"
  450. "cmp %x[n_remain], #6\n"
  451. "beq 202f \n"
  452. "cmp %x[n_remain], #5\n"
  453. "beq 203f \n"
  454. "cmp %x[n_remain], #4\n"
  455. "beq 204f \n"
  456. "cmp %x[n_remain], #3\n"
  457. "beq 205f \n"
  458. "cmp %x[n_remain], #2\n"
  459. "beq 206f \n"
  460. "cmp %x[n_remain], #1\n"
  461. "beq 207f \n"
  462. "200: \n"
  463. "ld1 {v0.8h}, [x0], #16\n"
  464. "ld1 {v1.8h}, [x0], #16\n"
  465. "ld1 {v2.8h}, [x0], #16\n"
  466. "ld1 {v3.8h}, [x0], #16\n"
  467. "ld1 {v4.8h}, [x1], #16\n"
  468. "ld1 {v5.8h}, [x1], #16\n"
  469. "ld1 {v6.8h}, [x1], #16\n"
  470. "ld1 {v7.8h}, [x1], #16\n"
  471. "b 3f \n"
  472. "201: \n"
  473. "ld1 {v0.8h}, [x0], #16\n"
  474. "ld1 {v1.8h}, [x0], #16\n"
  475. "ld1 {v2.8h}, [x0], #16\n"
  476. "ld1 {v3.d}[0], [x0], #8\n"
  477. "ld1 {v4.8h}, [x1], #16\n"
  478. "ld1 {v5.8h}, [x1], #16\n"
  479. "ld1 {v6.8h}, [x1], #16\n"
  480. "ld1 {v7.d}[0], [x1], #8\n"
  481. "b 3f \n"
  482. "202: \n"
  483. "ld1 {v0.8h}, [x0], #16\n"
  484. "ld1 {v1.8h}, [x0], #16\n"
  485. "ld1 {v2.8h}, [x0], #16\n"
  486. "ld1 {v4.8h}, [x1], #16\n"
  487. "ld1 {v5.8h}, [x1], #16\n"
  488. "ld1 {v6.8h}, [x1], #16\n"
  489. "b 3f \n"
  490. "203: \n"
  491. "ld1 {v0.8h}, [x0], #16\n"
  492. "ld1 {v1.8h}, [x0], #16\n"
  493. "ld1 {v2.d}[0], [x0], #8\n"
  494. "ld1 {v4.8h}, [x1], #16\n"
  495. "ld1 {v5.8h}, [x1], #16\n"
  496. "ld1 {v6.d}[0], [x1], #8\n"
  497. "b 3f \n"
  498. "204: \n"
  499. "ld1 {v0.8h}, [x0], #16\n"
  500. "ld1 {v1.8h}, [x0], #16\n"
  501. "ld1 {v4.8h}, [x1], #16\n"
  502. "ld1 {v5.8h}, [x1], #16\n"
  503. "b 3f \n"
  504. "205: \n"
  505. "ld1 {v0.8h}, [x0], #16\n"
  506. "ld1 {v1.d}[0], [x0], #8\n"
  507. "ld1 {v4.8h}, [x1], #16\n"
  508. "ld1 {v5.d}[0], [x1], #8\n"
  509. "b 3f \n"
  510. "206: \n"
  511. "ld1 {v0.8h}, [x0], #16\n"
  512. "ld1 {v4.8h}, [x1], #16\n"
  513. "b 3f \n"
  514. "207: \n"
  515. "ld1 {v0.d}[0], [x0], #8\n"
  516. "ld1 {v4.d}[0], [x1], #8\n"
  517. "b 3f \n"
  518. "9: \n"
  519. "cmp %x[n_remain], #8\n"
  520. "beq 300f \n"
  521. "cmp %x[n_remain], #7\n"
  522. "beq 301f \n"
  523. "cmp %x[n_remain], #6\n"
  524. "beq 302f \n"
  525. "cmp %x[n_remain], #5\n"
  526. "beq 303f \n"
  527. "cmp %x[n_remain], #4\n"
  528. "beq 304f \n"
  529. "cmp %x[n_remain], #3\n"
  530. "beq 305f \n"
  531. "cmp %x[n_remain], #2\n"
  532. "beq 306f \n"
  533. "cmp %x[n_remain], #1\n"
  534. "beq 307f \n"
  535. "300: \n"
  536. "ld1 {v0.8h}, [x0], #16\n"
  537. "ld1 {v1.8h}, [x0], #16\n"
  538. "ld1 {v2.8h}, [x0], #16\n"
  539. "ld1 {v3.8h}, [x0], #16\n"
  540. "b 3f \n"
  541. "301: \n"
  542. "ld1 {v0.8h}, [x0], #16\n"
  543. "ld1 {v1.8h}, [x0], #16\n"
  544. "ld1 {v2.8h}, [x0], #16\n"
  545. "ld1 {v3.d}[0], [x0], #8\n"
  546. "b 3f \n"
  547. "302: \n"
  548. "ld1 {v0.8h}, [x0], #16\n"
  549. "ld1 {v1.8h}, [x0], #16\n"
  550. "ld1 {v2.8h}, [x0], #16\n"
  551. "b 3f \n"
  552. "303: \n"
  553. "ld1 {v0.8h}, [x0], #16\n"
  554. "ld1 {v1.8h}, [x0], #16\n"
  555. "ld1 {v2.d}[0], [x0], #8\n"
  556. "b 3f \n"
  557. "304: \n"
  558. "ld1 {v0.8h}, [x0], #16\n"
  559. "ld1 {v1.8h}, [x0], #16\n"
  560. "b 3f \n"
  561. "305: \n"
  562. "ld1 {v0.8h}, [x0], #16\n"
  563. "ld1 {v1.d}[0], [x0], #8\n"
  564. "b 3f \n"
  565. "306: \n"
  566. "ld1 {v0.8h}, [x0], #16\n"
  567. "b 3f \n"
  568. "307: \n"
  569. "ld1 {v0.d}[0], [x0], #8\n"
  570. "b 3f \n"
  571. "2: \n"
  572. "eor v0.16b, v0.16b, v0.16b\n"
  573. "eor v1.16b, v1.16b, v1.16b\n"
  574. "eor v2.16b, v2.16b, v2.16b\n"
  575. "eor v3.16b, v3.16b, v3.16b\n"
  576. "eor v4.16b, v4.16b, v4.16b\n"
  577. "eor v5.16b, v5.16b, v5.16b\n"
  578. "eor v6.16b, v6.16b, v6.16b\n"
  579. "eor v7.16b, v7.16b, v7.16b\n"
  580. "3:\n"
  581. "zip1 v8.2d, v24.2d, v25.2d\n"
  582. "zip1 v10.2d, v26.2d, v27.2d\n"
  583. "add v0.8h, v0.8h, v8.8h \n"
  584. "zip1 v12.2d, v28.2d, v29.2d\n"
  585. "add v1.8h, v1.8h, v10.8h \n"
  586. "zip1 v14.2d, v30.2d, v31.2d\n"
  587. "add v2.8h, v2.8h, v12.8h \n"
  588. "add v3.8h, v3.8h, v14.8h \n"
  589. "zip2 v9.2d, v24.2d, v25.2d\n"
  590. "zip2 v11.2d, v26.2d, v27.2d \n"
  591. "add v4.8h, v4.8h, v9.8h \n"
  592. "zip2 v13.2d, v28.2d, v29.2d \n"
  593. "add v5.8h, v5.8h, v11.8h \n"
  594. "zip2 v15.2d, v30.2d, v31.2d \n"
  595. "add v6.8h, v6.8h, v13.8h \n"
  596. "add v7.8h, v7.8h, v15.8h \n"
  597. //save to memory
  598. "cmp %x[m_remain], #8 \n"
  599. "beq 4f \n"
  600. "cmp %x[m_remain], #4 \n"
  601. "beq 5f \n"
  602. "4: \n"
  603. "cmp %x[n_remain], #8\n"
  604. "beq 100f \n"
  605. "cmp %x[n_remain], #7\n"
  606. "beq 101f \n"
  607. "cmp %x[n_remain], #6\n"
  608. "beq 102f \n"
  609. "cmp %x[n_remain], #5\n"
  610. "beq 103f \n"
  611. "cmp %x[n_remain], #4\n"
  612. "beq 104f \n"
  613. "cmp %x[n_remain], #3\n"
  614. "beq 105f \n"
  615. "cmp %x[n_remain], #2\n"
  616. "beq 106f \n"
  617. "cmp %x[n_remain], #1\n"
  618. "beq 107f \n"
  619. "100: \n"
  620. "st1 {v0.8h}, [x0], #16\n"
  621. "st1 {v1.8h}, [x0], #16\n"
  622. "st1 {v2.8h}, [x0], #16\n"
  623. "st1 {v3.8h}, [x0], #16\n"
  624. "st1 {v4.8h}, [x1], #16\n"
  625. "st1 {v5.8h}, [x1], #16\n"
  626. "st1 {v6.8h}, [x1], #16\n"
  627. "st1 {v7.8h}, [x1], #16\n"
  628. "b 1000f \n"
  629. "101: \n"
  630. "st1 {v0.8h}, [x0], #16\n"
  631. "st1 {v1.8h}, [x0], #16\n"
  632. "st1 {v2.8h}, [x0], #16\n"
  633. "st1 {v3.d}[0], [x0], #8\n"
  634. "st1 {v4.8h}, [x1], #16\n"
  635. "st1 {v5.8h}, [x1], #16\n"
  636. "st1 {v6.8h}, [x1], #16\n"
  637. "st1 {v7.d}[0], [x1], #8\n"
  638. "b 1000f \n"
  639. "102: \n"
  640. "st1 {v0.8h}, [x0], #16\n"
  641. "st1 {v1.8h}, [x0], #16\n"
  642. "st1 {v2.8h}, [x0], #16\n"
  643. "st1 {v4.8h}, [x1], #16\n"
  644. "st1 {v5.8h}, [x1], #16\n"
  645. "st1 {v6.8h}, [x1], #16\n"
  646. "b 1000f \n"
  647. "103: \n"
  648. "st1 {v0.8h}, [x0], #16\n"
  649. "st1 {v1.8h}, [x0], #16\n"
  650. "st1 {v2.d}[0], [x0], #8\n"
  651. "st1 {v4.8h}, [x1], #16\n"
  652. "st1 {v5.8h}, [x1], #16\n"
  653. "st1 {v6.d}[0], [x1], #8\n"
  654. "b 1000f \n"
  655. "104: \n"
  656. "st1 {v0.8h}, [x0], #16\n"
  657. "st1 {v1.8h}, [x0], #16\n"
  658. "st1 {v4.8h}, [x1], #16\n"
  659. "st1 {v5.8h}, [x1], #16\n"
  660. "b 1000f \n"
  661. "105: \n"
  662. "st1 {v0.8h}, [x0], #16\n"
  663. "st1 {v1.d}[0], [x0], #8\n"
  664. "st1 {v4.8h}, [x1], #16\n"
  665. "st1 {v5.d}[0], [x1], #8\n"
  666. "b 1000f \n"
  667. "106: \n"
  668. "st1 {v0.8h}, [x0], #16\n"
  669. "st1 {v4.8h}, [x1], #16\n"
  670. "b 1000f \n"
  671. "107: \n"
  672. "st1 {v0.d}[0], [x0], #8\n"
  673. "st1 {v4.d}[0], [x1], #8\n"
  674. "b 1000f \n"
  675. "5: \n"
  676. "cmp %x[n_remain], #8\n"
  677. "beq 200f \n"
  678. "cmp %x[n_remain], #7\n"
  679. "beq 201f \n"
  680. "cmp %x[n_remain], #6\n"
  681. "beq 202f \n"
  682. "cmp %x[n_remain], #5\n"
  683. "beq 203f \n"
  684. "cmp %x[n_remain], #4\n"
  685. "beq 204f \n"
  686. "cmp %x[n_remain], #3\n"
  687. "beq 205f \n"
  688. "cmp %x[n_remain], #2\n"
  689. "beq 206f \n"
  690. "cmp %x[n_remain], #1\n"
  691. "beq 207f \n"
  692. "200: \n"
  693. "st1 {v0.8h}, [x0], #16\n"
  694. "st1 {v1.8h}, [x0], #16\n"
  695. "st1 {v2.8h}, [x0], #16\n"
  696. "st1 {v3.8h}, [x0], #16\n"
  697. "b 1000f \n"
  698. "201: \n"
  699. "st1 {v0.8h}, [x0], #16\n"
  700. "st1 {v1.8h}, [x0], #16\n"
  701. "st1 {v2.8h}, [x0], #16\n"
  702. "st1 {v3.d}[0], [x0], #8\n"
  703. "b 1000f \n"
  704. "202: \n"
  705. "st1 {v0.8h}, [x0], #16\n"
  706. "st1 {v1.8h}, [x0], #16\n"
  707. "st1 {v2.8h}, [x0], #16\n"
  708. "b 1000f \n"
  709. "203: \n"
  710. "st1 {v0.8h}, [x0], #16\n"
  711. "st1 {v1.8h}, [x0], #16\n"
  712. "st1 {v2.d}[0], [x0], #8\n"
  713. "b 1000f \n"
  714. "204: \n"
  715. "st1 {v0.8h}, [x0], #16\n"
  716. "st1 {v1.8h}, [x0], #16\n"
  717. "b 1000f \n"
  718. "205: \n"
  719. "st1 {v0.8h}, [x0], #16\n"
  720. "st1 {v1.d}[0], [x0], #8\n"
  721. "b 1000f \n"
  722. "206: \n"
  723. "st1 {v0.8h}, [x0], #16\n"
  724. "b 1000f \n"
  725. "207: \n"
  726. "st1 {v0.d}[0], [x0], #8\n"
  727. "b 1000f \n"
  728. "1000: \n"
  729. :
  730. [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr),
  731. [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC),
  732. [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain),
  733. [ n_remain ] "+r"(n_remain)
  734. :
  735. : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
  736. "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
  737. "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
  738. "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
  739. "v29", "v30", "v31");
  740. // clang-format on
  741. #undef LOAD_C_8
  742. #undef STORE_C_8
  743. }
  744. static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
  745. int16_t* output, int LDC, bool is_first_k, int m_remain,
  746. int n_remain) {
  747. K /= 8;
  748. LDC = LDC * sizeof(int16_t);
  749. const int8_t* a_ptr = packB;//packA;
  750. const int8_t* b_ptr = packA;//packB;
  751. // clang-format off
  752. #define LOAD_C_4 \
  753. "ld1 {v0.8h}, [x0], #16\n" \
  754. "ld1 {v1.8h}, [x0], #16\n" \
  755. "ld1 {v2.8h}, [x0], #16\n" \
  756. "ld1 {v3.8h}, [x0], #16\n" \
  757. #define STORE_C_4 \
  758. "st1 {v0.8h}, [x0], #16\n" \
  759. "st1 {v1.8h}, [x0], #16\n" \
  760. "st1 {v2.8h}, [x0], #16\n" \
  761. "st1 {v3.8h}, [x0], #16\n" \
  762. register int16_t* outptr asm("x0") = output;
  763. asm volatile(
  764. "eor v24.16b, v24.16b, v24.16b\n"
  765. "eor v25.16b, v25.16b, v25.16b\n"
  766. "eor v26.16b, v26.16b, v26.16b\n"
  767. "eor v27.16b, v27.16b, v27.16b\n"
  768. "eor v28.16b, v28.16b, v28.16b\n"
  769. "eor v29.16b, v29.16b, v29.16b\n"
  770. "eor v30.16b, v30.16b, v30.16b\n"
  771. "eor v31.16b, v31.16b, v31.16b\n"
  772. // General loop.
  773. "ld1 {v20.16b}, [%[a_ptr]],#16\n"
  774. "ld1 {v21.16b}, [%[a_ptr]],#16\n"
  775. "PRFM PLDL1KEEP, [%[a_ptr], #512]\n"
  776. "PRFM PLDL1KEEP, [%[b_ptr], #512]\n"
  777. "1:\n"
  778. "dup v0.8b,v20.b[0]\n"
  779. "ld1 {v22.16b}, [%[a_ptr]],#16\n"
  780. "dup v1.8b,v20.b[1]\n"
  781. "ld1 {v23.16b}, [%[a_ptr]],#16\n"
  782. "dup v2.8b,v20.b[2]\n"
  783. "ld1 {v16.8b}, [%[b_ptr]], 8\n"
  784. "dup v3.8b,v20.b[3]\n"
  785. "dup v4.8b,v20.b[4]\n"
  786. "ld1 {v17.8b}, [%[b_ptr]], 8\n"
  787. "dup v5.8b,v20.b[5]\n"
  788. "dup v6.8b,v20.b[6]\n"
  789. "dup v7.8b,v20.b[7]\n"
  790. "dup v8.8b,v20.b[8]\n"
  791. "smlal v24.8h, v0.8b, v16.8b\n"
  792. "dup v9.8b,v20.b[9]\n"
  793. "smlal v25.8h, v1.8b, v16.8b\n"
  794. "dup v10.8b,v20.b[10]\n"
  795. "smlal v26.8h, v2.8b, v16.8b\n"
  796. "dup v11.8b,v20.b[11]\n"
  797. "smlal v27.8h, v3.8b, v16.8b\n"
  798. "dup v12.8b,v20.b[12]\n"
  799. "smlal v28.8h, v4.8b, v16.8b\n"
  800. "dup v13.8b,v20.b[13]\n"
  801. "smlal v29.8h, v5.8b, v16.8b\n"
  802. "dup v14.8b,v20.b[14]\n"
  803. "smlal v30.8h, v6.8b, v16.8b\n"
  804. "dup v15.8b,v20.b[15]\n"
  805. "smlal v31.8h, v7.8b, v16.8b\n"
  806. "ld1 {v16.8b}, [%[b_ptr]], 8\n"
  807. "dup v0.8b,v21.b[0]\n"
  808. "smlal v24.8h, v8.8b, v17.8b\n"
  809. "dup v1.8b,v21.b[1]\n"
  810. "smlal v25.8h, v9.8b, v17.8b\n"
  811. "dup v2.8b,v21.b[2]\n"
  812. "smlal v26.8h, v10.8b, v17.8b\n"
  813. "dup v3.8b,v21.b[3]\n"
  814. "smlal v27.8h, v11.8b, v17.8b\n"
  815. "dup v4.8b,v21.b[4]\n"
  816. "smlal v28.8h, v12.8b, v17.8b\n"
  817. "dup v5.8b,v21.b[5]\n"
  818. "smlal v29.8h, v13.8b, v17.8b\n"
  819. "dup v6.8b,v21.b[6]\n"
  820. "smlal v30.8h, v14.8b, v17.8b\n"
  821. "dup v7.8b,v21.b[7]\n"
  822. "smlal v31.8h, v15.8b, v17.8b\n"
  823. "ld1 {v17.8b}, [%[b_ptr]], 8\n"
  824. "dup v8.8b,v21.b[8]\n"
  825. "smlal v24.8h, v0.8b, v16.8b\n"
  826. "dup v9.8b,v21.b[9]\n"
  827. "smlal v25.8h, v1.8b, v16.8b\n"
  828. "dup v10.8b,v21.b[10]\n"
  829. "smlal v26.8h, v2.8b, v16.8b\n"
  830. "dup v11.8b,v21.b[11]\n"
  831. "smlal v27.8h, v3.8b, v16.8b\n"
  832. "dup v12.8b,v21.b[12]\n"
  833. "smlal v28.8h, v4.8b, v16.8b\n"
  834. "dup v13.8b,v21.b[13]\n"
  835. "smlal v29.8h, v5.8b, v16.8b\n"
  836. "dup v14.8b,v21.b[14]\n"
  837. "smlal v30.8h, v6.8b, v16.8b\n"
  838. "dup v15.8b,v21.b[15]\n"
  839. "smlal v31.8h, v7.8b, v16.8b\n"
  840. "ld1 {v16.8b}, [%[b_ptr]], 8\n"
  841. "dup v0.8b,v22.b[0]\n"
  842. "smlal v24.8h, v8.8b, v17.8b\n"
  843. "dup v1.8b,v22.b[1]\n"
  844. "smlal v25.8h, v9.8b, v17.8b\n"
  845. "dup v2.8b,v22.b[2]\n"
  846. "smlal v26.8h, v10.8b, v17.8b\n"
  847. "dup v3.8b,v22.b[3]\n"
  848. "smlal v27.8h, v11.8b, v17.8b\n"
  849. "dup v4.8b,v22.b[4]\n"
  850. "smlal v28.8h, v12.8b, v17.8b\n"
  851. "dup v5.8b,v22.b[5]\n"
  852. "smlal v29.8h, v13.8b, v17.8b\n"
  853. "dup v6.8b,v22.b[6]\n"
  854. "smlal v30.8h, v14.8b, v17.8b\n"
  855. "dup v7.8b,v22.b[7]\n"
  856. "smlal v31.8h, v15.8b, v17.8b\n"
  857. "ld1 {v17.8b}, [%[b_ptr]], 8\n"
  858. "dup v8.8b,v22.b[8]\n"
  859. "smlal v24.8h, v0.8b, v16.8b\n"
  860. "dup v9.8b,v22.b[9]\n"
  861. "smlal v25.8h, v1.8b, v16.8b\n"
  862. "dup v10.8b,v22.b[10]\n"
  863. "smlal v26.8h, v2.8b, v16.8b\n"
  864. "dup v11.8b,v22.b[11]\n"
  865. "smlal v27.8h, v3.8b, v16.8b\n"
  866. "dup v12.8b,v22.b[12]\n"
  867. "smlal v28.8h, v4.8b, v16.8b\n"
  868. "dup v13.8b,v22.b[13]\n"
  869. "smlal v29.8h, v5.8b, v16.8b\n"
  870. "dup v14.8b,v22.b[14]\n"
  871. "smlal v30.8h, v6.8b, v16.8b\n"
  872. "dup v15.8b,v22.b[15]\n"
  873. "smlal v31.8h, v7.8b, v16.8b\n"
  874. "ld1 {v16.8b}, [%[b_ptr]], 8\n"
  875. "dup v0.8b,v23.b[0]\n"
  876. "smlal v24.8h, v8.8b, v17.8b\n"
  877. "dup v1.8b,v23.b[1]\n"
  878. "smlal v25.8h, v9.8b, v17.8b\n"
  879. "dup v2.8b,v23.b[2]\n"
  880. "smlal v26.8h, v10.8b, v17.8b\n"
  881. "dup v3.8b,v23.b[3]\n"
  882. "smlal v27.8h, v11.8b, v17.8b\n"
  883. "dup v4.8b,v23.b[4]\n"
  884. "smlal v28.8h, v12.8b, v17.8b\n"
  885. "dup v5.8b,v23.b[5]\n"
  886. "smlal v29.8h, v13.8b, v17.8b\n"
  887. "dup v6.8b,v23.b[6]\n"
  888. "smlal v30.8h, v14.8b, v17.8b\n"
  889. "dup v7.8b,v23.b[7]\n"
  890. "smlal v31.8h, v15.8b, v17.8b\n"
  891. "ld1 {v17.8b}, [%[b_ptr]], 8\n"
  892. "dup v8.8b,v23.b[8]\n"
  893. "smlal v24.8h, v0.8b, v16.8b\n"
  894. "dup v9.8b,v23.b[9]\n"
  895. "smlal v25.8h, v1.8b, v16.8b\n"
  896. "dup v10.8b,v23.b[10]\n"
  897. "smlal v26.8h, v2.8b, v16.8b\n"
  898. "dup v11.8b,v23.b[11]\n"
  899. "smlal v27.8h, v3.8b, v16.8b\n"
  900. "dup v12.8b,v23.b[12]\n"
  901. "smlal v28.8h, v4.8b, v16.8b\n"
  902. "dup v13.8b,v23.b[13]\n"
  903. "smlal v29.8h, v5.8b, v16.8b\n"
  904. "dup v14.8b,v23.b[14]\n"
  905. "smlal v30.8h, v6.8b, v16.8b\n"
  906. "dup v15.8b,v23.b[15]\n"
  907. "smlal v31.8h, v7.8b, v16.8b\n"
  908. "ld1 {v20.16b}, [%[a_ptr]],#16\n"
  909. "smlal v24.8h, v8.8b, v17.8b\n"
  910. "smlal v25.8h, v9.8b, v17.8b\n"
  911. "smlal v26.8h, v10.8b, v17.8b\n"
  912. "smlal v27.8h, v11.8b, v17.8b\n"
  913. "ld1 {v21.16b}, [%[a_ptr]],#16\n"
  914. "smlal v28.8h, v12.8b, v17.8b\n"
  915. "smlal v29.8h, v13.8b, v17.8b\n"
  916. "smlal v30.8h, v14.8b, v17.8b\n"
  917. "smlal v31.8h, v15.8b, v17.8b\n"
  918. "subs %w[K], %w[K], #1\n"
  919. "cbnz %w[K], 1b\n"
  920. "cmp %w[is_first_k], #1\n"
  921. "beq 2f\n" LOAD_C_4
  922. "b 3f \n"
  923. "2: \n"
  924. "eor v0.16b, v0.16b, v0.16b\n"
  925. "eor v1.16b, v1.16b, v1.16b\n"
  926. "eor v2.16b, v2.16b, v2.16b\n"
  927. "eor v3.16b, v3.16b, v3.16b\n"
  928. "eor v4.16b, v4.16b, v4.16b\n"
  929. "eor v5.16b, v5.16b, v5.16b\n"
  930. "eor v6.16b, v6.16b, v6.16b\n"
  931. "eor v7.16b, v7.16b, v7.16b\n"
  932. "3:\n"
  933. "zip1 v8.2d, v24.2d, v25.2d\n"
  934. "zip1 v10.2d, v26.2d, v27.2d\n"
  935. "add v0.8h, v0.8h, v8.8h\n"
  936. "zip1 v12.2d, v28.2d, v29.2d\n"
  937. "add v1.8h, v1.8h, v10.8h\n"
  938. "zip1 v14.2d, v30.2d, v31.2d\n"
  939. "add v2.8h, v2.8h, v12.8h\n"
  940. "add v3.8h, v3.8h, v14.8h\n"
  941. // Store back into memory
  942. STORE_C_4
  943. :
  944. [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr),
  945. [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC),
  946. [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain),
  947. [ n_remain ] "+r"(n_remain)
  948. :
  949. : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
  950. "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
  951. "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
  952. "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
  953. "v29", "v30", "v31");
  954. // clang-format on
  955. #undef LOAD_C_4
  956. #undef STORE_C_4
  957. }
  958. static void kern_4x8_remain(const int8_t* packA, const int8_t* packB, int K,
  959. int16_t* output, int LDC, bool is_first_k, int m_remain,
  960. int n_remain) {
  961. K /= 8;
  962. LDC = LDC * sizeof(int16_t);
  963. const int8_t* a_ptr = packB;//packA;
  964. const int8_t* b_ptr = packA;//packB;
  965. // clang-format off
  966. register int16_t* outptr asm("x0") = output;
  967. asm volatile(
  968. "eor v24.16b, v24.16b, v24.16b\n"
  969. "eor v25.16b, v25.16b, v25.16b\n"
  970. "eor v26.16b, v26.16b, v26.16b\n"
  971. "eor v27.16b, v27.16b, v27.16b\n"
  972. "eor v28.16b, v28.16b, v28.16b\n"
  973. "eor v29.16b, v29.16b, v29.16b\n"
  974. "eor v30.16b, v30.16b, v30.16b\n"
  975. "eor v31.16b, v31.16b, v31.16b\n"
  976. // General loop.
  977. "ld1 {v20.16b}, [%[a_ptr]],#16\n"
  978. "ld1 {v21.16b}, [%[a_ptr]],#16\n"
  979. "PRFM PLDL1KEEP, [%[a_ptr], #512]\n"
  980. "PRFM PLDL1KEEP, [%[b_ptr], #512]\n"
  981. "1:\n"
  982. "dup v0.8b,v20.b[0]\n"
  983. "ld1 {v22.16b}, [%[a_ptr]],#16\n"
  984. "dup v1.8b,v20.b[1]\n"
  985. "ld1 {v23.16b}, [%[a_ptr]],#16\n"
  986. "dup v2.8b,v20.b[2]\n"
  987. "ld1 {v16.8b}, [%[b_ptr]], 8\n"
  988. "dup v3.8b,v20.b[3]\n"
  989. "dup v4.8b,v20.b[4]\n"
  990. "ld1 {v17.8b}, [%[b_ptr]], 8\n"
  991. "dup v5.8b,v20.b[5]\n"
  992. "dup v6.8b,v20.b[6]\n"
  993. "dup v7.8b,v20.b[7]\n"
  994. "dup v8.8b,v20.b[8]\n"
  995. "smlal v24.8h, v0.8b, v16.8b\n"
  996. "dup v9.8b,v20.b[9]\n"
  997. "smlal v25.8h, v1.8b, v16.8b\n"
  998. "dup v10.8b,v20.b[10]\n"
  999. "smlal v26.8h, v2.8b, v16.8b\n"
  1000. "dup v11.8b,v20.b[11]\n"
  1001. "smlal v27.8h, v3.8b, v16.8b\n"
  1002. "dup v12.8b,v20.b[12]\n"
  1003. "smlal v28.8h, v4.8b, v16.8b\n"
  1004. "dup v13.8b,v20.b[13]\n"
  1005. "smlal v29.8h, v5.8b, v16.8b\n"
  1006. "dup v14.8b,v20.b[14]\n"
  1007. "smlal v30.8h, v6.8b, v16.8b\n"
  1008. "dup v15.8b,v20.b[15]\n"
  1009. "smlal v31.8h, v7.8b, v16.8b\n"
  1010. "ld1 {v16.8b}, [%[b_ptr]], 8\n"
  1011. "dup v0.8b,v21.b[0]\n"
  1012. "smlal v24.8h, v8.8b, v17.8b\n"
  1013. "dup v1.8b,v21.b[1]\n"
  1014. "smlal v25.8h, v9.8b, v17.8b\n"
  1015. "dup v2.8b,v21.b[2]\n"
  1016. "smlal v26.8h, v10.8b, v17.8b\n"
  1017. "dup v3.8b,v21.b[3]\n"
  1018. "smlal v27.8h, v11.8b, v17.8b\n"
  1019. "dup v4.8b,v21.b[4]\n"
  1020. "smlal v28.8h, v12.8b, v17.8b\n"
  1021. "dup v5.8b,v21.b[5]\n"
  1022. "smlal v29.8h, v13.8b, v17.8b\n"
  1023. "dup v6.8b,v21.b[6]\n"
  1024. "smlal v30.8h, v14.8b, v17.8b\n"
  1025. "dup v7.8b,v21.b[7]\n"
  1026. "smlal v31.8h, v15.8b, v17.8b\n"
  1027. "ld1 {v17.8b}, [%[b_ptr]], 8\n"
  1028. "dup v8.8b,v21.b[8]\n"
  1029. "smlal v24.8h, v0.8b, v16.8b\n"
  1030. "dup v9.8b,v21.b[9]\n"
  1031. "smlal v25.8h, v1.8b, v16.8b\n"
  1032. "dup v10.8b,v21.b[10]\n"
  1033. "smlal v26.8h, v2.8b, v16.8b\n"
  1034. "dup v11.8b,v21.b[11]\n"
  1035. "smlal v27.8h, v3.8b, v16.8b\n"
  1036. "dup v12.8b,v21.b[12]\n"
  1037. "smlal v28.8h, v4.8b, v16.8b\n"
  1038. "dup v13.8b,v21.b[13]\n"
  1039. "smlal v29.8h, v5.8b, v16.8b\n"
  1040. "dup v14.8b,v21.b[14]\n"
  1041. "smlal v30.8h, v6.8b, v16.8b\n"
  1042. "dup v15.8b,v21.b[15]\n"
  1043. "smlal v31.8h, v7.8b, v16.8b\n"
  1044. "ld1 {v16.8b}, [%[b_ptr]], 8\n"
  1045. "dup v0.8b,v22.b[0]\n"
  1046. "smlal v24.8h, v8.8b, v17.8b\n"
  1047. "dup v1.8b,v22.b[1]\n"
  1048. "smlal v25.8h, v9.8b, v17.8b\n"
  1049. "dup v2.8b,v22.b[2]\n"
  1050. "smlal v26.8h, v10.8b, v17.8b\n"
  1051. "dup v3.8b,v22.b[3]\n"
  1052. "smlal v27.8h, v11.8b, v17.8b\n"
  1053. "dup v4.8b,v22.b[4]\n"
  1054. "smlal v28.8h, v12.8b, v17.8b\n"
  1055. "dup v5.8b,v22.b[5]\n"
  1056. "smlal v29.8h, v13.8b, v17.8b\n"
  1057. "dup v6.8b,v22.b[6]\n"
  1058. "smlal v30.8h, v14.8b, v17.8b\n"
  1059. "dup v7.8b,v22.b[7]\n"
  1060. "smlal v31.8h, v15.8b, v17.8b\n"
  1061. "ld1 {v17.8b}, [%[b_ptr]], 8\n"
  1062. "dup v8.8b,v22.b[8]\n"
  1063. "smlal v24.8h, v0.8b, v16.8b\n"
  1064. "dup v9.8b,v22.b[9]\n"
  1065. "smlal v25.8h, v1.8b, v16.8b\n"
  1066. "dup v10.8b,v22.b[10]\n"
  1067. "smlal v26.8h, v2.8b, v16.8b\n"
  1068. "dup v11.8b,v22.b[11]\n"
  1069. "smlal v27.8h, v3.8b, v16.8b\n"
  1070. "dup v12.8b,v22.b[12]\n"
  1071. "smlal v28.8h, v4.8b, v16.8b\n"
  1072. "dup v13.8b,v22.b[13]\n"
  1073. "smlal v29.8h, v5.8b, v16.8b\n"
  1074. "dup v14.8b,v22.b[14]\n"
  1075. "smlal v30.8h, v6.8b, v16.8b\n"
  1076. "dup v15.8b,v22.b[15]\n"
  1077. "smlal v31.8h, v7.8b, v16.8b\n"
  1078. "ld1 {v16.8b}, [%[b_ptr]], 8\n"
  1079. "dup v0.8b,v23.b[0]\n"
  1080. "smlal v24.8h, v8.8b, v17.8b\n"
  1081. "dup v1.8b,v23.b[1]\n"
  1082. "smlal v25.8h, v9.8b, v17.8b\n"
  1083. "dup v2.8b,v23.b[2]\n"
  1084. "smlal v26.8h, v10.8b, v17.8b\n"
  1085. "dup v3.8b,v23.b[3]\n"
  1086. "smlal v27.8h, v11.8b, v17.8b\n"
  1087. "dup v4.8b,v23.b[4]\n"
  1088. "smlal v28.8h, v12.8b, v17.8b\n"
  1089. "dup v5.8b,v23.b[5]\n"
  1090. "smlal v29.8h, v13.8b, v17.8b\n"
  1091. "dup v6.8b,v23.b[6]\n"
  1092. "smlal v30.8h, v14.8b, v17.8b\n"
  1093. "dup v7.8b,v23.b[7]\n"
  1094. "smlal v31.8h, v15.8b, v17.8b\n"
  1095. "ld1 {v17.8b}, [%[b_ptr]], 8\n"
  1096. "dup v8.8b,v23.b[8]\n"
  1097. "smlal v24.8h, v0.8b, v16.8b\n"
  1098. "dup v9.8b,v23.b[9]\n"
  1099. "smlal v25.8h, v1.8b, v16.8b\n"
  1100. "dup v10.8b,v23.b[10]\n"
  1101. "smlal v26.8h, v2.8b, v16.8b\n"
  1102. "dup v11.8b,v23.b[11]\n"
  1103. "smlal v27.8h, v3.8b, v16.8b\n"
  1104. "dup v12.8b,v23.b[12]\n"
  1105. "smlal v28.8h, v4.8b, v16.8b\n"
  1106. "dup v13.8b,v23.b[13]\n"
  1107. "smlal v29.8h, v5.8b, v16.8b\n"
  1108. "dup v14.8b,v23.b[14]\n"
  1109. "smlal v30.8h, v6.8b, v16.8b\n"
  1110. "dup v15.8b,v23.b[15]\n"
  1111. "smlal v31.8h, v7.8b, v16.8b\n"
  1112. "ld1 {v20.16b}, [%[a_ptr]],#16\n"
  1113. "smlal v24.8h, v8.8b, v17.8b\n"
  1114. "smlal v25.8h, v9.8b, v17.8b\n"
  1115. "smlal v26.8h, v10.8b, v17.8b\n"
  1116. "smlal v27.8h, v11.8b, v17.8b\n"
  1117. "ld1 {v21.16b}, [%[a_ptr]],#16\n"
  1118. "smlal v28.8h, v12.8b, v17.8b\n"
  1119. "smlal v29.8h, v13.8b, v17.8b\n"
  1120. "smlal v30.8h, v14.8b, v17.8b\n"
  1121. "smlal v31.8h, v15.8b, v17.8b\n"
  1122. "subs %w[K], %w[K], #1 \n"
  1123. "cbnz %w[K], 1b \n"
  1124. "cmp %w[is_first_k], #1 \n"
  1125. "beq 2f \n"
  1126. "cmp %w[n_remain],#7 \n"
  1127. "beq 200f \n"
  1128. "cmp %w[n_remain],#6 \n"
  1129. "beq 201f \n"
  1130. "cmp %w[n_remain],#5 \n"
  1131. "beq 202f \n"
  1132. "cmp %w[n_remain],#4 \n"
  1133. "beq 203f \n"
  1134. "cmp %w[n_remain],#3 \n"
  1135. "beq 204f \n"
  1136. "cmp %w[n_remain],#2 \n"
  1137. "beq 205f \n"
  1138. "cmp %w[n_remain],#1 \n"
  1139. "beq 206f \n"
  1140. "200: \n"
  1141. "ld1 {v0.8h}, [x0],#16 \n"
  1142. "ld1 {v1.8h}, [x0],#16 \n"
  1143. "ld1 {v2.8h}, [x0],#16 \n"
  1144. "ld1 {v3.d}[0], [x0],#8 \n"
  1145. "b 3f \n"
  1146. "201: \n"
  1147. "ld1 {v0.8h}, [x0],#16 \n"
  1148. "ld1 {v1.8h}, [x0],#16 \n"
  1149. "ld1 {v2.8h}, [x0],#16 \n"
  1150. "b 3f \n"
  1151. "202: \n"
  1152. "ld1 {v0.8h}, [x0],#16 \n"
  1153. "ld1 {v1.8h}, [x0],#16 \n"
  1154. "ld1 {v2.d}[0], [x0],#8 \n"
  1155. "b 3f \n"
  1156. "203: \n"
  1157. "ld1 {v0.8h}, [x0],#16 \n"
  1158. "ld1 {v1.8h}, [x0],#16 \n"
  1159. "b 3f \n"
  1160. "204: \n"
  1161. "ld1 {v0.8h}, [x0],#16 \n"
  1162. "ld1 {v1.d}[0], [x0],#8 \n"
  1163. "b 3f \n"
  1164. "205: \n"
  1165. "ld1 {v0.8h}, [x0],#16 \n"
  1166. "b 3f \n"
  1167. "206: \n"
  1168. "ld1 {v0.d}[0], [x0],#8 \n"
  1169. "b 3f \n"
  1170. "2: \n"
  1171. "eor v0.16b, v0.16b, v0.16b\n"
  1172. "eor v1.16b, v1.16b, v1.16b\n"
  1173. "eor v2.16b, v2.16b, v2.16b\n"
  1174. "eor v3.16b, v3.16b, v3.16b\n"
  1175. "eor v4.16b, v4.16b, v4.16b\n"
  1176. "eor v5.16b, v5.16b, v5.16b\n"
  1177. "eor v6.16b, v6.16b, v6.16b\n"
  1178. "eor v7.16b, v7.16b, v7.16b\n"
  1179. "3: \n"
  1180. "zip1 v8.2d, v24.2d, v25.2d\n"
  1181. "zip1 v10.2d, v26.2d, v27.2d\n"
  1182. "add v0.8h, v0.8h, v8.8h \n"
  1183. "zip1 v12.2d, v28.2d, v29.2d\n"
  1184. "add v1.8h, v1.8h, v10.8h\n"
  1185. "zip1 v14.2d, v30.2d, v31.2d\n"
  1186. "add v2.8h, v2.8h, v12.8h\n"
  1187. "add v3.8h, v3.8h, v14.8h\n"
  1188. // Store back into memory
  1189. "cmp %w[n_remain],#7 \n"
  1190. "beq 100f \n"
  1191. "cmp %w[n_remain],#6 \n"
  1192. "beq 101f \n"
  1193. "cmp %w[n_remain],#5 \n"
  1194. "beq 102f \n"
  1195. "cmp %w[n_remain],#4 \n"
  1196. "beq 103f \n"
  1197. "cmp %w[n_remain],#3 \n"
  1198. "beq 104f \n"
  1199. "cmp %w[n_remain],#2 \n"
  1200. "beq 105f \n"
  1201. "cmp %w[n_remain],#1 \n"
  1202. "beq 106f \n"
  1203. "100: \n"
  1204. "st1 {v0.8h}, [x0],#16 \n"
  1205. "st1 {v1.8h}, [x0],#16 \n"
  1206. "st1 {v2.8h}, [x0],#16 \n"
  1207. "st1 {v3.d}[0], [x0],#8 \n"
  1208. "b 1000f \n"
  1209. "101: \n"
  1210. "st1 {v0.8h}, [x0],#16 \n"
  1211. "st1 {v1.8h}, [x0],#16 \n"
  1212. "st1 {v2.8h}, [x0],#16 \n"
  1213. "b 1000f \n"
  1214. "102: \n"
  1215. "st1 {v0.8h}, [x0],#16 \n"
  1216. "st1 {v1.8h}, [x0],#16 \n"
  1217. "st1 {v2.d}[0], [x0],#8 \n"
  1218. "b 1000f \n"
  1219. "103: \n"
  1220. "st1 {v0.8h}, [x0],#16 \n"
  1221. "st1 {v1.8h}, [x0],#16 \n"
  1222. "b 1000f \n"
  1223. "104: \n"
  1224. "st1 {v0.8h}, [x0],#16 \n"
  1225. "st1 {v1.d}[0], [x0],#8 \n"
  1226. "b 1000f \n"
  1227. "105: \n"
  1228. "st1 {v0.8h}, [x0],#16 \n"
  1229. "b 1000f \n"
  1230. "106: \n"
  1231. "st1 {v0.d}[0], [x0],#8 \n"
  1232. "b 1000f \n"
  1233. "1000: \n"
  1234. :
  1235. [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr),
  1236. [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC),
  1237. [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain),
  1238. [ n_remain ] "+r"(n_remain)
  1239. :
  1240. : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
  1241. "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
  1242. "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
  1243. "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
  1244. "v29", "v30", "v31");
  1245. // clang-format on
  1246. #undef LOAD_C_4
  1247. #undef STORE_C_4
  1248. }
  1249. //! pack to icxoc
  1250. //! (M/4,K/4,4(K),4(M)) pack to (M/8,k/8,8(K_ic_0~3_ic_4~7),8(M_oc0~3_OC_4~7))
  1251. //! if M K is not times of 8,pack 0 instead
  1252. static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr,
  1253. const dt_int8* inptr, int ldin,
  1254. int m0, int mmax, int k0, int kmax) {
  1255. megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4");
  1256. megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
  1257. constexpr int pack_m = 8;
  1258. constexpr int pack_k = 8;
  1259. constexpr int pack_size = 4;
  1260. int8_t tmpbuff0[pack_m * pack_size] = {0};
  1261. int8_t tmpbuff1[pack_m * pack_size] = {0};
  1262. int8_t zerobuff[pack_m * pack_size] = {0};
  1263. const int m_size = mmax - m0;
  1264. const int m_end = m_size / pack_m * pack_m + m0;
  1265. int remain_m = mmax - m_end;
  1266. for (int m_idx = m0; m_idx < m_end; m_idx += pack_m) {
  1267. const int8_t* inptr0 = inptr + m_idx / pack_size * ldin + k0;
  1268. const int8_t* inptr1 = inptr0 + ldin;
  1269. prefetch_2x(inptr0);
  1270. prefetch_2x(inptr1);
  1271. int k_idx = k0;
  1272. for ( ; k_idx + 7 < kmax; k_idx += pack_k) {
  1273. interleave_8x8_mk4_b(inptr0,inptr1,outptr);
  1274. }
  1275. if (k_idx < kmax) {
  1276. memcpy(tmpbuff0, inptr0, sizeof(int8_t) * (kmax - k_idx) * pack_size);
  1277. memcpy(tmpbuff1, inptr1, sizeof(int8_t) * (kmax - k_idx) * pack_size);
  1278. inptr0 = tmpbuff0;
  1279. inptr1 = tmpbuff1;
  1280. interleave_8x8_mk4_b(inptr0, inptr1, outptr);
  1281. }
  1282. }
  1283. int m_idx = m_end;
  1284. if (remain_m == 4) {
  1285. const int8_t* inptr0 = inptr + m_idx / pack_size * ldin + k0;
  1286. const int8_t* inptr1 = inptr0 + ldin;
  1287. prefetch_2x(inptr0);
  1288. prefetch_2x(inptr1);
  1289. int k_idx = k0;
  1290. for ( ; k_idx + 7 < kmax; k_idx += pack_k) {
  1291. inptr1 = zerobuff;
  1292. interleave_8x8_mk4_b(inptr0,inptr1,outptr);
  1293. }
  1294. if (k_idx < kmax) {
  1295. memcpy(tmpbuff0, inptr0, sizeof(int8_t) * (kmax - k_idx) * pack_size);
  1296. inptr0 = tmpbuff0;
  1297. inptr1 = zerobuff;
  1298. interleave_8x8_mk4_b(inptr0, inptr1, outptr);
  1299. }
  1300. }
  1301. }
  1302. //! pack to nxic
  1303. //! (K/4,N,4) pack to K/8,N,8(ic0~7) ,K is not times of 8 ,pack 0 instead.
  1304. static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in,
  1305. int ldin, int n0, int nmax, int k0,
  1306. int kmax) {
  1307. megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
  1308. constexpr int pack_n = 8;
  1309. constexpr int pack_k = 8;
  1310. constexpr int pack_size = 4;
  1311. int8_t tmpbuff0[pack_n * pack_size] = {0};
  1312. int8_t tmpbuff1[pack_n * pack_size] = {0};
  1313. int8_t zerobuff[pack_n * pack_size] = {0};
  1314. const int ksize = round_up<int>((kmax - k0),8);
  1315. const int nsize = nmax - n0;
  1316. const int n_end = nsize / pack_n * pack_n + n0;
  1317. const int remain_n = nsize % pack_n;
  1318. int output_stride = ksize * pack_n;
  1319. int8_t* outptr_base = out;
  1320. int k_idx = k0;
  1321. for ( ; k_idx + 7 < kmax; k_idx += pack_k) {
  1322. const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size;
  1323. const int8_t* inptr1 = inptr0 + ldin;
  1324. prefetch_3x(inptr0);
  1325. prefetch_3x(inptr1);
  1326. auto outptr = outptr_base;
  1327. for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) {
  1328. transpose_8x8_mk4_b(inptr0, inptr1, outptr);
  1329. outptr += output_stride;
  1330. }
  1331. if (remain_n > 0) {
  1332. memcpy(tmpbuff0, inptr0, sizeof(int8_t) * remain_n * pack_size);
  1333. memcpy(tmpbuff1, inptr1, sizeof(int8_t) * remain_n * pack_size);
  1334. inptr0 = tmpbuff0;
  1335. inptr1 = tmpbuff1;
  1336. transpose_8x8_mk4_b(inptr0, inptr1, outptr);
  1337. outptr += output_stride;
  1338. }
  1339. outptr_base += pack_n * pack_k;
  1340. }
  1341. if(k_idx < kmax){
  1342. const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size;
  1343. const int8_t* inptr1 = nullptr;
  1344. prefetch_3x(inptr0);
  1345. auto outptr = outptr_base;
  1346. for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) {
  1347. inptr1 = zerobuff;
  1348. transpose_8x8_mk4_b(inptr0, inptr1, outptr);
  1349. outptr += output_stride;
  1350. }
  1351. if (remain_n > 0) {
  1352. memcpy(tmpbuff0, inptr0, sizeof(int8_t) * remain_n * pack_size);
  1353. inptr1 = zerobuff;
  1354. inptr0 = tmpbuff0;
  1355. transpose_8x8_mk4_b(inptr0, inptr1, outptr);
  1356. outptr += output_stride;
  1357. }
  1358. outptr_base += pack_n * pack_size;
  1359. }
  1360. }
} // namespace matmul_mk4_8x8x8
  1362. } // namespace aarch64
  1363. } // namespace megdnn
  1364. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。