You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel_mk4_4x4x16.h 34 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892
  1. /**
  2. * \file dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include <cstring>
  13. #if !(__ARM_FEATURE_DOTPROD)
  14. #include "src/aarch64/matrix_mul/asm/common.h"
  15. #include "src/arm_common/simd_macro/marm_neon.h"
  16. namespace megdnn {
  17. namespace aarch64 {
  18. namespace matmul_mk4_4x4x16 {
  19. /**
  20. * Overview of register layout:
  21. *
  22. * A 16x4 cell of Rhs is stored in 8bit in v0-q3.
  23. * B 16x4 cell of Lhs is stored in 8bit in q4-q7
  24. * C 8x16 block of accumulators is stored in 8bit in q8--q31.
  25. *
  26. * \warning Fast kernel operating on int8 operands.
  27. * It is assumed that one of the two int8 operands only takes values
  28. * in [-127, 127], while the other may freely range in [-128, 127].
  29. * The issue with both operands taking the value -128 is that:
  30. * -128*-128 + -128*-128 == -32768 overflows int16.
  31. * Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16
  32. * range. That is the basic idea of this kernel.
  33. *
  34. *
  35. * +--------+--------+---------+---------+
  36. * |v4[0-16]|v5[0-16]| v6[0-16]| v7[0-16]|
  37. * Rhs +--------+--------+---------+---------+
  38. * | | | | |
  39. *
  40. * Lhs | | | | |
  41. *
  42. * +--------+ - - - - +-------------------------------------+
  43. * |v0[0-16]| |v16[0-4]|v17[0-4]| v18[0-4]| v19[0-4]|
  44. * |v1[0-16]| |v20[0-4]|v21[0-4]| v22[0-4]| v23[0-4]|
  45. * |v2[0-16]| |v24[0-4]|v25[0-4]| v26[0-4]| v27[0-4]|
  46. * |v3[0-16]| |v28[0-4]|v29[0-4]| v30[0-4]| v31[0-4]|
  47. * +--------+ - - - - +-------------------------------------+
  48. *
  49. * Accumulator
  50. */
  51. static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
  52. int32_t* output, bool is_first_k) {
  53. K = div_ceil(K, 16);
  54. const int8_t* a_ptr = packA;
  55. const int8_t* b_ptr = packB;
  56. asm volatile(
  57. // load accumulator C
  58. "ld1 {v0.16b}, [%[a_ptr]], #16\n"
  59. "eor v16.16b, v16.16b, v16.16b\n"
  60. "eor v17.16b, v17.16b, v17.16b\n"
  61. "eor v18.16b, v18.16b, v18.16b\n"
  62. "ld1 {v1.16b}, [%[a_ptr]], #16\n"
  63. "eor v19.16b, v19.16b, v19.16b\n"
  64. "eor v20.16b, v19.16b, v19.16b\n"
  65. "eor v21.16b, v19.16b, v19.16b\n"
  66. "ld1 {v4.16b, v5.16b}, [%[b_ptr]], #32\n"
  67. "eor v22.16b, v19.16b, v19.16b\n"
  68. "PRFM PLDL1KEEP, [%[a_ptr], #32]\n"
  69. "eor v23.16b, v19.16b, v19.16b\n"
  70. "eor v24.16b, v19.16b, v19.16b\n"
  71. "PRFM PLDL1KEEP, [%[b_ptr], #32]\n"
  72. "eor v25.16b, v19.16b, v19.16b\n"
  73. "eor v26.16b, v19.16b, v19.16b\n"
  74. "PRFM PLDL1KEEP, [%[b_ptr], #64]\n"
  75. "eor v27.16b, v19.16b, v19.16b\n"
  76. "eor v28.16b, v19.16b, v19.16b\n"
  77. "PRFM PLDL1KEEP, [%[a_ptr], #64]\n"
  78. "eor v29.16b, v19.16b, v19.16b\n"
  79. "eor v30.16b, v19.16b, v19.16b\n"
  80. "PRFM PLDL1KEEP, [%[b_ptr], #128]\n"
  81. "eor v31.16b, v19.16b, v19.16b\n"
  82. //! if K==1 jump to compute last K
  83. "cmp %w[k], #2\n"
  84. "beq 2f\n"
  85. "blt 3f\n"
  86. //! K>2
  87. "1:\n"
  88. //! First k
  89. "smull v8.8h, v0.8b, v4.8b\n"
  90. "smull v9.8h, v0.8b, v5.8b\n"
  91. "ld1 {v6.16b}, [%[b_ptr]], #16\n"
  92. "smull v12.8h, v1.8b, v4.8b\n"
  93. "smull v13.8h, v1.8b, v5.8b\n"
  94. "ld1 {v7.16b}, [%[b_ptr]], #16\n"
  95. "smlal2 v8.8h, v0.16b, v4.16b\n"
  96. "smlal2 v9.8h, v0.16b, v5.16b\n"
  97. "smlal2 v12.8h, v1.16b, v4.16b\n"
  98. "smlal2 v13.8h, v1.16b, v5.16b\n"
  99. "smull v10.8h, v0.8b, v6.8b\n"
  100. "ld1 {v2.16b}, [%[a_ptr]], #16\n"
  101. "smull v11.8h, v0.8b, v7.8b\n"
  102. "smull v14.8h, v1.8b, v6.8b\n"
  103. "ld1 {v3.16b}, [%[a_ptr]], #16\n"
  104. "smull v15.8h, v1.8b, v7.8b\n"
  105. "sadalp v16.4s, v8.8h\n"
  106. "smlal2 v10.8h, v0.16b, v6.16b\n"
  107. "sadalp v17.4s, v9.8h\n"
  108. "smlal2 v11.8h, v0.16b, v7.16b\n"
  109. "sadalp v20.4s, v12.8h\n"
  110. "smlal2 v14.8h, v1.16b, v6.16b\n"
  111. "sadalp v21.4s, v13.8h\n"
  112. "smlal2 v15.8h, v1.16b, v7.16b\n"
  113. "smull v8.8h, v2.8b, v4.8b\n"
  114. "smull v9.8h, v2.8b, v5.8b\n"
  115. "ld1 {v0.16b}, [%[a_ptr]], #16\n"
  116. "smull v12.8h, v3.8b, v4.8b\n"
  117. "ld1 {v1.16b}, [%[a_ptr]], #16\n"
  118. "smull v13.8h, v3.8b, v5.8b\n"
  119. "sadalp v18.4s, v10.8h\n"
  120. "smlal2 v8.8h, v2.16b, v4.16b\n"
  121. "sadalp v19.4s, v11.8h\n"
  122. "smlal2 v9.8h, v2.16b, v5.16b\n"
  123. "sadalp v22.4s, v14.8h\n"
  124. "smlal2 v12.8h, v3.16b, v4.16b\n"
  125. "sadalp v23.4s, v15.8h\n"
  126. "smlal2 v13.8h, v3.16b, v5.16b\n"
  127. "smull v10.8h, v2.8b, v6.8b\n"
  128. "smull v11.8h, v2.8b, v7.8b\n"
  129. "ld1 {v4.16b}, [%[b_ptr]], #16\n"
  130. "smull v14.8h, v3.8b, v6.8b\n"
  131. "ld1 {v5.16b}, [%[b_ptr]], #16\n"
  132. "smull v15.8h, v3.8b, v7.8b\n"
  133. "sadalp v24.4s, v8.8h\n"
  134. "smlal2 v10.8h, v2.16b, v6.16b\n"
  135. "sadalp v25.4s, v9.8h\n"
  136. "smlal2 v11.8h, v2.16b, v7.16b\n"
  137. "sadalp v28.4s, v12.8h\n"
  138. "smlal2 v14.8h, v3.16b, v6.16b\n"
  139. "sadalp v29.4s, v13.8h\n"
  140. "smlal2 v15.8h, v3.16b, v7.16b\n"
  141. //! Second k
  142. "smull v8.8h, v0.8b, v4.8b\n"
  143. "smull v9.8h, v0.8b, v5.8b\n"
  144. "ld1 {v6.16b}, [%[b_ptr]], #16\n"
  145. "smull v12.8h, v1.8b, v4.8b\n"
  146. "smull v13.8h, v1.8b, v5.8b\n"
  147. "ld1 {v7.16b}, [%[b_ptr]], #16\n"
  148. "smlal2 v8.8h, v0.16b, v4.16b\n"
  149. "sadalp v26.4s, v10.8h\n"
  150. "smlal2 v9.8h, v0.16b, v5.16b\n"
  151. "sadalp v27.4s, v11.8h\n"
  152. "smlal2 v12.8h, v1.16b, v4.16b\n"
  153. "sadalp v30.4s, v14.8h\n"
  154. "smlal2 v13.8h, v1.16b, v5.16b\n"
  155. "sadalp v31.4s, v15.8h\n"
  156. "smull v10.8h, v0.8b, v6.8b\n"
  157. "ld1 {v2.16b}, [%[a_ptr]], #16\n"
  158. "smull v11.8h, v0.8b, v7.8b\n"
  159. "smull v14.8h, v1.8b, v6.8b\n"
  160. "ld1 {v3.16b}, [%[a_ptr]], #16\n"
  161. "smull v15.8h, v1.8b, v7.8b\n"
  162. "sadalp v16.4s, v8.8h\n"
  163. "smlal2 v10.8h, v0.16b, v6.16b\n"
  164. "sadalp v17.4s, v9.8h\n"
  165. "smlal2 v11.8h, v0.16b, v7.16b\n"
  166. "sadalp v20.4s, v12.8h\n"
  167. "smlal2 v14.8h, v1.16b, v6.16b\n"
  168. "sadalp v21.4s, v13.8h\n"
  169. "smlal2 v15.8h, v1.16b, v7.16b\n"
  170. "smull v8.8h, v2.8b, v4.8b\n"
  171. "smull v9.8h, v2.8b, v5.8b\n"
  172. "ld1 {v0.16b}, [%[a_ptr]], #16\n"
  173. "smull v12.8h, v3.8b, v4.8b\n"
  174. "ld1 {v1.16b}, [%[a_ptr]], #16\n"
  175. "smull v13.8h, v3.8b, v5.8b\n"
  176. "sadalp v18.4s, v10.8h\n"
  177. "smlal2 v8.8h, v2.16b, v4.16b\n"
  178. "sadalp v19.4s, v11.8h\n"
  179. "smlal2 v9.8h, v2.16b, v5.16b\n"
  180. "sadalp v22.4s, v14.8h\n"
  181. "smlal2 v12.8h, v3.16b, v4.16b\n"
  182. "sadalp v23.4s, v15.8h\n"
  183. "smlal2 v13.8h, v3.16b, v5.16b\n"
  184. "sub %w[k], %w[k], #2\n"
  185. "cmp %w[k], #2\n"
  186. "smull v10.8h, v2.8b, v6.8b\n"
  187. "ld1 {v4.16b}, [%[b_ptr]], #16\n"
  188. "smull v11.8h, v2.8b, v7.8b\n"
  189. "ld1 {v5.16b}, [%[b_ptr]], #16\n"
  190. "smull v14.8h, v3.8b, v6.8b\n"
  191. "sadalp v24.4s, v8.8h\n"
  192. "smull v15.8h, v3.8b, v7.8b\n"
  193. "sadalp v25.4s, v9.8h\n"
  194. "smlal2 v10.8h, v2.16b, v6.16b\n"
  195. "sadalp v28.4s, v12.8h\n"
  196. "smlal2 v11.8h, v2.16b, v7.16b\n"
  197. "sadalp v29.4s, v13.8h\n"
  198. "smlal2 v14.8h, v3.16b, v6.16b\n"
  199. "sadalp v26.4s, v10.8h\n"
  200. "smlal2 v15.8h, v3.16b, v7.16b\n"
  201. "sadalp v27.4s, v11.8h\n"
  202. "sadalp v30.4s, v14.8h\n"
  203. "sadalp v31.4s, v15.8h\n"
  204. "bgt 1b\n"
  205. "blt 3f\n"
  206. //! K==2
  207. "2:\n"
  208. "smull v8.8h, v0.8b, v4.8b\n"
  209. "smull v9.8h, v0.8b, v5.8b\n"
  210. "ld1 {v6.16b}, [%[b_ptr]], #16\n"
  211. "smull v12.8h, v1.8b, v4.8b\n"
  212. "smull v13.8h, v1.8b, v5.8b\n"
  213. "ld1 {v7.16b}, [%[b_ptr]], #16\n"
  214. "smlal2 v8.8h, v0.16b, v4.16b\n"
  215. "smlal2 v9.8h, v0.16b, v5.16b\n"
  216. "smlal2 v12.8h, v1.16b, v4.16b\n"
  217. "smlal2 v13.8h, v1.16b, v5.16b\n"
  218. "smull v10.8h, v0.8b, v6.8b\n"
  219. "ld1 {v2.16b}, [%[a_ptr]], #16\n"
  220. "smull v11.8h, v0.8b, v7.8b\n"
  221. "smull v14.8h, v1.8b, v6.8b\n"
  222. "ld1 {v3.16b}, [%[a_ptr]], #16\n"
  223. "smull v15.8h, v1.8b, v7.8b\n"
  224. "sadalp v16.4s, v8.8h\n"
  225. "smlal2 v10.8h, v0.16b, v6.16b\n"
  226. "sadalp v17.4s, v9.8h\n"
  227. "smlal2 v11.8h, v0.16b, v7.16b\n"
  228. "sadalp v20.4s, v12.8h\n"
  229. "smlal2 v14.8h, v1.16b, v6.16b\n"
  230. "sadalp v21.4s, v13.8h\n"
  231. "smlal2 v15.8h, v1.16b, v7.16b\n"
  232. "smull v8.8h, v2.8b, v4.8b\n"
  233. "smull v9.8h, v2.8b, v5.8b\n"
  234. "ld1 {v0.16b}, [%[a_ptr]], #16\n"
  235. "smull v12.8h, v3.8b, v4.8b\n"
  236. "ld1 {v1.16b}, [%[a_ptr]], #16\n"
  237. "smull v13.8h, v3.8b, v5.8b\n"
  238. "sadalp v18.4s, v10.8h\n"
  239. "smlal2 v8.8h, v2.16b, v4.16b\n"
  240. "sadalp v19.4s, v11.8h\n"
  241. "smlal2 v9.8h, v2.16b, v5.16b\n"
  242. "sadalp v22.4s, v14.8h\n"
  243. "smlal2 v12.8h, v3.16b, v4.16b\n"
  244. "sadalp v23.4s, v15.8h\n"
  245. "smlal2 v13.8h, v3.16b, v5.16b\n"
  246. "smull v10.8h, v2.8b, v6.8b\n"
  247. "smull v11.8h, v2.8b, v7.8b\n"
  248. "ld1 {v4.16b}, [%[b_ptr]], #16\n"
  249. "smull v14.8h, v3.8b, v6.8b\n"
  250. "ld1 {v5.16b}, [%[b_ptr]], #16\n"
  251. "smull v15.8h, v3.8b, v7.8b\n"
  252. "sadalp v24.4s, v8.8h\n"
  253. "smlal2 v10.8h, v2.16b, v6.16b\n"
  254. "sadalp v25.4s, v9.8h\n"
  255. "smlal2 v11.8h, v2.16b, v7.16b\n"
  256. "sadalp v28.4s, v12.8h\n"
  257. "smlal2 v14.8h, v3.16b, v6.16b\n"
  258. "sadalp v29.4s, v13.8h\n"
  259. "smlal2 v15.8h, v3.16b, v7.16b\n"
  260. "sadalp v26.4s, v10.8h\n"
  261. "sadalp v27.4s, v11.8h\n"
  262. "sadalp v30.4s, v14.8h\n"
  263. "sadalp v31.4s, v15.8h\n"
  264. //! K==1
  265. "3:\n"
  266. "smull v8.8h, v0.8b, v4.8b\n"
  267. "smull v9.8h, v0.8b, v5.8b\n"
  268. "ld1 {v6.16b}, [%[b_ptr]], #16\n"
  269. "smull v12.8h, v1.8b, v4.8b\n"
  270. "smull v13.8h, v1.8b, v5.8b\n"
  271. "ld1 {v7.16b}, [%[b_ptr]], #16\n"
  272. "smlal2 v8.8h, v0.16b, v4.16b\n"
  273. "smlal2 v9.8h, v0.16b, v5.16b\n"
  274. "smlal2 v12.8h, v1.16b, v4.16b\n"
  275. "smlal2 v13.8h, v1.16b, v5.16b\n"
  276. "smull v10.8h, v0.8b, v6.8b\n"
  277. "ld1 {v2.16b}, [%[a_ptr]], #16\n"
  278. "smull v11.8h, v0.8b, v7.8b\n"
  279. "smull v14.8h, v1.8b, v6.8b\n"
  280. "ld1 {v3.16b}, [%[a_ptr]], #16\n"
  281. "smull v15.8h, v1.8b, v7.8b\n"
  282. "sadalp v16.4s, v8.8h\n"
  283. "smlal2 v10.8h, v0.16b, v6.16b\n"
  284. "sadalp v17.4s, v9.8h\n"
  285. "smlal2 v11.8h, v0.16b, v7.16b\n"
  286. "sadalp v20.4s, v12.8h\n"
  287. "smlal2 v14.8h, v1.16b, v6.16b\n"
  288. "sadalp v21.4s, v13.8h\n"
  289. "smlal2 v15.8h, v1.16b, v7.16b\n"
  290. "smull v8.8h, v2.8b, v4.8b\n"
  291. "smull v9.8h, v2.8b, v5.8b\n"
  292. "smull v12.8h, v3.8b, v4.8b\n"
  293. "smull v13.8h, v3.8b, v5.8b\n"
  294. "sadalp v18.4s, v10.8h\n"
  295. "smlal2 v8.8h, v2.16b, v4.16b\n"
  296. "sadalp v19.4s, v11.8h\n"
  297. "smlal2 v9.8h, v2.16b, v5.16b\n"
  298. "sadalp v22.4s, v14.8h\n"
  299. "smlal2 v12.8h, v3.16b, v4.16b\n"
  300. "sadalp v23.4s, v15.8h\n"
  301. "smlal2 v13.8h, v3.16b, v5.16b\n"
  302. "smull v10.8h, v2.8b, v6.8b\n"
  303. "sadalp v24.4s, v8.8h\n"
  304. "smull v11.8h, v2.8b, v7.8b\n"
  305. "sadalp v25.4s, v9.8h\n"
  306. "smull v14.8h, v3.8b, v6.8b\n"
  307. "sadalp v28.4s, v12.8h\n"
  308. "smull v15.8h, v3.8b, v7.8b\n"
  309. "sadalp v29.4s, v13.8h\n"
  310. "smlal2 v10.8h, v2.16b, v6.16b\n"
  311. "smlal2 v11.8h, v2.16b, v7.16b\n"
  312. "sadalp v26.4s, v10.8h\n"
  313. "smlal2 v14.8h, v3.16b, v6.16b\n"
  314. "sadalp v27.4s, v11.8h\n"
  315. "smlal2 v15.8h, v3.16b, v7.16b\n"
  316. "sadalp v30.4s, v14.8h\n"
  317. "sadalp v31.4s, v15.8h\n"
  318. "addp v4.4s, v16.4s, v20.4s\n"
  319. "addp v5.4s, v24.4s, v28.4s\n"
  320. "addp v6.4s, v17.4s, v21.4s\n"
  321. "addp v7.4s, v25.4s, v29.4s\n"
  322. "addp v8.4s, v18.4s, v22.4s\n"
  323. "addp v9.4s, v26.4s, v30.4s\n"
  324. "addp v10.4s, v19.4s, v23.4s\n"
  325. "addp v11.4s, v27.4s, v31.4s\n"
  326. "cmp %w[is_first_k], #1\n"
  327. "addp v0.4s, v4.4s, v5.4s\n"
  328. "addp v1.4s, v6.4s, v7.4s\n"
  329. "addp v2.4s, v8.4s, v9.4s\n"
  330. "addp v3.4s, v10.4s, v11.4s\n"
  331. "beq 6f\n"
  332. "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output]]\n"
  333. "add v0.4s, v0.4s, v8.4s\n"
  334. "add v1.4s, v1.4s, v9.4s\n"
  335. "add v2.4s, v2.4s, v10.4s\n"
  336. "add v3.4s, v3.4s, v11.4s\n"
  337. "6:\n"
  338. "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[output]], #64\n"
  339. : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
  340. [is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output)
  341. :
  342. : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
  343. "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
  344. "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
  345. "v29", "v30", "v31", "cc", "memory");
  346. }
  347. static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K,
  348. int32_t* output, bool is_first_k, size_t remain_n) {
  349. K = div_ceil(K, 16);
  350. const int8_t* a_ptr = packA;
  351. const int8_t* b_ptr = packB;
  352. asm volatile(
  353. // load accumulator C
  354. "ld1 {v0.16b}, [%[a_ptr]], #16\n"
  355. "eor v16.16b, v16.16b, v16.16b\n"
  356. "eor v17.16b, v17.16b, v17.16b\n"
  357. "eor v18.16b, v18.16b, v18.16b\n"
  358. "ld1 {v1.16b}, [%[a_ptr]], #16\n"
  359. "eor v19.16b, v19.16b, v19.16b\n"
  360. "eor v20.16b, v19.16b, v19.16b\n"
  361. "eor v21.16b, v19.16b, v19.16b\n"
  362. "ld1 {v4.16b, v5.16b}, [%[b_ptr]], #32\n"
  363. "eor v22.16b, v19.16b, v19.16b\n"
  364. "PRFM PLDL1KEEP, [%[a_ptr], #32]\n"
  365. "eor v23.16b, v19.16b, v19.16b\n"
  366. "eor v24.16b, v19.16b, v19.16b\n"
  367. "PRFM PLDL1KEEP, [%[b_ptr], #32]\n"
  368. "eor v25.16b, v19.16b, v19.16b\n"
  369. "eor v26.16b, v19.16b, v19.16b\n"
  370. "PRFM PLDL1KEEP, [%[b_ptr], #64]\n"
  371. "eor v27.16b, v19.16b, v19.16b\n"
  372. "eor v28.16b, v19.16b, v19.16b\n"
  373. "PRFM PLDL1KEEP, [%[a_ptr], #64]\n"
  374. "eor v29.16b, v19.16b, v19.16b\n"
  375. "eor v30.16b, v19.16b, v19.16b\n"
  376. "PRFM PLDL1KEEP, [%[b_ptr], #128]\n"
  377. "eor v31.16b, v19.16b, v19.16b\n"
  378. //! if K==1 jump to compute last K
  379. "cmp %w[k], #2\n"
  380. "beq 2f\n"
  381. "blt 3f\n"
  382. //! K>2
  383. "1:\n"
  384. //! First k
  385. "smull v8.8h, v0.8b, v4.8b\n"
  386. "smull v9.8h, v0.8b, v5.8b\n"
  387. "ld1 {v6.16b}, [%[b_ptr]], #16\n"
  388. "smull v12.8h, v1.8b, v4.8b\n"
  389. "smull v13.8h, v1.8b, v5.8b\n"
  390. "ld1 {v7.16b}, [%[b_ptr]], #16\n"
  391. "smlal2 v8.8h, v0.16b, v4.16b\n"
  392. "smlal2 v9.8h, v0.16b, v5.16b\n"
  393. "smlal2 v12.8h, v1.16b, v4.16b\n"
  394. "smlal2 v13.8h, v1.16b, v5.16b\n"
  395. "smull v10.8h, v0.8b, v6.8b\n"
  396. "ld1 {v2.16b}, [%[a_ptr]], #16\n"
  397. "smull v11.8h, v0.8b, v7.8b\n"
  398. "smull v14.8h, v1.8b, v6.8b\n"
  399. "ld1 {v3.16b}, [%[a_ptr]], #16\n"
  400. "smull v15.8h, v1.8b, v7.8b\n"
  401. "sadalp v16.4s, v8.8h\n"
  402. "smlal2 v10.8h, v0.16b, v6.16b\n"
  403. "sadalp v17.4s, v9.8h\n"
  404. "smlal2 v11.8h, v0.16b, v7.16b\n"
  405. "sadalp v20.4s, v12.8h\n"
  406. "smlal2 v14.8h, v1.16b, v6.16b\n"
  407. "sadalp v21.4s, v13.8h\n"
  408. "smlal2 v15.8h, v1.16b, v7.16b\n"
  409. "smull v8.8h, v2.8b, v4.8b\n"
  410. "smull v9.8h, v2.8b, v5.8b\n"
  411. "ld1 {v0.16b}, [%[a_ptr]], #16\n"
  412. "smull v12.8h, v3.8b, v4.8b\n"
  413. "ld1 {v1.16b}, [%[a_ptr]], #16\n"
  414. "smull v13.8h, v3.8b, v5.8b\n"
  415. "sadalp v18.4s, v10.8h\n"
  416. "smlal2 v8.8h, v2.16b, v4.16b\n"
  417. "sadalp v19.4s, v11.8h\n"
  418. "smlal2 v9.8h, v2.16b, v5.16b\n"
  419. "sadalp v22.4s, v14.8h\n"
  420. "smlal2 v12.8h, v3.16b, v4.16b\n"
  421. "sadalp v23.4s, v15.8h\n"
  422. "smlal2 v13.8h, v3.16b, v5.16b\n"
  423. "smull v10.8h, v2.8b, v6.8b\n"
  424. "smull v11.8h, v2.8b, v7.8b\n"
  425. "ld1 {v4.16b}, [%[b_ptr]], #16\n"
  426. "smull v14.8h, v3.8b, v6.8b\n"
  427. "ld1 {v5.16b}, [%[b_ptr]], #16\n"
  428. "smull v15.8h, v3.8b, v7.8b\n"
  429. "sadalp v24.4s, v8.8h\n"
  430. "smlal2 v10.8h, v2.16b, v6.16b\n"
  431. "sadalp v25.4s, v9.8h\n"
  432. "smlal2 v11.8h, v2.16b, v7.16b\n"
  433. "sadalp v28.4s, v12.8h\n"
  434. "smlal2 v14.8h, v3.16b, v6.16b\n"
  435. "sadalp v29.4s, v13.8h\n"
  436. "smlal2 v15.8h, v3.16b, v7.16b\n"
  437. //! Second k
  438. "smull v8.8h, v0.8b, v4.8b\n"
  439. "smull v9.8h, v0.8b, v5.8b\n"
  440. "ld1 {v6.16b}, [%[b_ptr]], #16\n"
  441. "smull v12.8h, v1.8b, v4.8b\n"
  442. "smull v13.8h, v1.8b, v5.8b\n"
  443. "ld1 {v7.16b}, [%[b_ptr]], #16\n"
  444. "smlal2 v8.8h, v0.16b, v4.16b\n"
  445. "sadalp v26.4s, v10.8h\n"
  446. "smlal2 v9.8h, v0.16b, v5.16b\n"
  447. "sadalp v27.4s, v11.8h\n"
  448. "smlal2 v12.8h, v1.16b, v4.16b\n"
  449. "sadalp v30.4s, v14.8h\n"
  450. "smlal2 v13.8h, v1.16b, v5.16b\n"
  451. "sadalp v31.4s, v15.8h\n"
  452. "smull v10.8h, v0.8b, v6.8b\n"
  453. "ld1 {v2.16b}, [%[a_ptr]], #16\n"
  454. "smull v11.8h, v0.8b, v7.8b\n"
  455. "smull v14.8h, v1.8b, v6.8b\n"
  456. "ld1 {v3.16b}, [%[a_ptr]], #16\n"
  457. "smull v15.8h, v1.8b, v7.8b\n"
  458. "sadalp v16.4s, v8.8h\n"
  459. "smlal2 v10.8h, v0.16b, v6.16b\n"
  460. "sadalp v17.4s, v9.8h\n"
  461. "smlal2 v11.8h, v0.16b, v7.16b\n"
  462. "sadalp v20.4s, v12.8h\n"
  463. "smlal2 v14.8h, v1.16b, v6.16b\n"
  464. "sadalp v21.4s, v13.8h\n"
  465. "smlal2 v15.8h, v1.16b, v7.16b\n"
  466. "smull v8.8h, v2.8b, v4.8b\n"
  467. "smull v9.8h, v2.8b, v5.8b\n"
  468. "ld1 {v0.16b}, [%[a_ptr]], #16\n"
  469. "smull v12.8h, v3.8b, v4.8b\n"
  470. "ld1 {v1.16b}, [%[a_ptr]], #16\n"
  471. "smull v13.8h, v3.8b, v5.8b\n"
  472. "sadalp v18.4s, v10.8h\n"
  473. "smlal2 v8.8h, v2.16b, v4.16b\n"
  474. "sadalp v19.4s, v11.8h\n"
  475. "smlal2 v9.8h, v2.16b, v5.16b\n"
  476. "sadalp v22.4s, v14.8h\n"
  477. "smlal2 v12.8h, v3.16b, v4.16b\n"
  478. "sadalp v23.4s, v15.8h\n"
  479. "smlal2 v13.8h, v3.16b, v5.16b\n"
  480. "sub %w[k], %w[k], #2\n"
  481. "cmp %w[k], #2\n"
  482. "smull v10.8h, v2.8b, v6.8b\n"
  483. "ld1 {v4.16b}, [%[b_ptr]], #16\n"
  484. "smull v11.8h, v2.8b, v7.8b\n"
  485. "ld1 {v5.16b}, [%[b_ptr]], #16\n"
  486. "smull v14.8h, v3.8b, v6.8b\n"
  487. "sadalp v24.4s, v8.8h\n"
  488. "smull v15.8h, v3.8b, v7.8b\n"
  489. "sadalp v25.4s, v9.8h\n"
  490. "smlal2 v10.8h, v2.16b, v6.16b\n"
  491. "sadalp v28.4s, v12.8h\n"
  492. "smlal2 v11.8h, v2.16b, v7.16b\n"
  493. "sadalp v29.4s, v13.8h\n"
  494. "smlal2 v14.8h, v3.16b, v6.16b\n"
  495. "sadalp v26.4s, v10.8h\n"
  496. "smlal2 v15.8h, v3.16b, v7.16b\n"
  497. "sadalp v27.4s, v11.8h\n"
  498. "sadalp v30.4s, v14.8h\n"
  499. "sadalp v31.4s, v15.8h\n"
  500. "bgt 1b\n"
  501. "blt 3f\n"
  502. //! K==2
  503. "2:\n"
  504. "smull v8.8h, v0.8b, v4.8b\n"
  505. "smull v9.8h, v0.8b, v5.8b\n"
  506. "ld1 {v6.16b}, [%[b_ptr]], #16\n"
  507. "smull v12.8h, v1.8b, v4.8b\n"
  508. "smull v13.8h, v1.8b, v5.8b\n"
  509. "ld1 {v7.16b}, [%[b_ptr]], #16\n"
  510. "smlal2 v8.8h, v0.16b, v4.16b\n"
  511. "smlal2 v9.8h, v0.16b, v5.16b\n"
  512. "smlal2 v12.8h, v1.16b, v4.16b\n"
  513. "smlal2 v13.8h, v1.16b, v5.16b\n"
  514. "smull v10.8h, v0.8b, v6.8b\n"
  515. "ld1 {v2.16b}, [%[a_ptr]], #16\n"
  516. "smull v11.8h, v0.8b, v7.8b\n"
  517. "smull v14.8h, v1.8b, v6.8b\n"
  518. "ld1 {v3.16b}, [%[a_ptr]], #16\n"
  519. "smull v15.8h, v1.8b, v7.8b\n"
  520. "sadalp v16.4s, v8.8h\n"
  521. "smlal2 v10.8h, v0.16b, v6.16b\n"
  522. "sadalp v17.4s, v9.8h\n"
  523. "smlal2 v11.8h, v0.16b, v7.16b\n"
  524. "sadalp v20.4s, v12.8h\n"
  525. "smlal2 v14.8h, v1.16b, v6.16b\n"
  526. "sadalp v21.4s, v13.8h\n"
  527. "smlal2 v15.8h, v1.16b, v7.16b\n"
  528. "smull v8.8h, v2.8b, v4.8b\n"
  529. "smull v9.8h, v2.8b, v5.8b\n"
  530. "ld1 {v0.16b}, [%[a_ptr]], #16\n"
  531. "smull v12.8h, v3.8b, v4.8b\n"
  532. "ld1 {v1.16b}, [%[a_ptr]], #16\n"
  533. "smull v13.8h, v3.8b, v5.8b\n"
  534. "sadalp v18.4s, v10.8h\n"
  535. "smlal2 v8.8h, v2.16b, v4.16b\n"
  536. "sadalp v19.4s, v11.8h\n"
  537. "smlal2 v9.8h, v2.16b, v5.16b\n"
  538. "sadalp v22.4s, v14.8h\n"
  539. "smlal2 v12.8h, v3.16b, v4.16b\n"
  540. "sadalp v23.4s, v15.8h\n"
  541. "smlal2 v13.8h, v3.16b, v5.16b\n"
  542. "smull v10.8h, v2.8b, v6.8b\n"
  543. "smull v11.8h, v2.8b, v7.8b\n"
  544. "ld1 {v4.16b}, [%[b_ptr]], #16\n"
  545. "smull v14.8h, v3.8b, v6.8b\n"
  546. "ld1 {v5.16b}, [%[b_ptr]], #16\n"
  547. "smull v15.8h, v3.8b, v7.8b\n"
  548. "sadalp v24.4s, v8.8h\n"
  549. "smlal2 v10.8h, v2.16b, v6.16b\n"
  550. "sadalp v25.4s, v9.8h\n"
  551. "smlal2 v11.8h, v2.16b, v7.16b\n"
  552. "sadalp v28.4s, v12.8h\n"
  553. "smlal2 v14.8h, v3.16b, v6.16b\n"
  554. "sadalp v29.4s, v13.8h\n"
  555. "smlal2 v15.8h, v3.16b, v7.16b\n"
  556. "sadalp v26.4s, v10.8h\n"
  557. "sadalp v27.4s, v11.8h\n"
  558. "sadalp v30.4s, v14.8h\n"
  559. "sadalp v31.4s, v15.8h\n"
  560. //! K==1
  561. "3:\n"
  562. "smull v8.8h, v0.8b, v4.8b\n"
  563. "smull v9.8h, v0.8b, v5.8b\n"
  564. "ld1 {v6.16b}, [%[b_ptr]], #16\n"
  565. "smull v12.8h, v1.8b, v4.8b\n"
  566. "smull v13.8h, v1.8b, v5.8b\n"
  567. "ld1 {v7.16b}, [%[b_ptr]], #16\n"
  568. "smlal2 v8.8h, v0.16b, v4.16b\n"
  569. "smlal2 v9.8h, v0.16b, v5.16b\n"
  570. "smlal2 v12.8h, v1.16b, v4.16b\n"
  571. "smlal2 v13.8h, v1.16b, v5.16b\n"
  572. "smull v10.8h, v0.8b, v6.8b\n"
  573. "ld1 {v2.16b}, [%[a_ptr]], #16\n"
  574. "smull v11.8h, v0.8b, v7.8b\n"
  575. "smull v14.8h, v1.8b, v6.8b\n"
  576. "ld1 {v3.16b}, [%[a_ptr]], #16\n"
  577. "smull v15.8h, v1.8b, v7.8b\n"
  578. "sadalp v16.4s, v8.8h\n"
  579. "smlal2 v10.8h, v0.16b, v6.16b\n"
  580. "sadalp v17.4s, v9.8h\n"
  581. "smlal2 v11.8h, v0.16b, v7.16b\n"
  582. "sadalp v20.4s, v12.8h\n"
  583. "smlal2 v14.8h, v1.16b, v6.16b\n"
  584. "sadalp v21.4s, v13.8h\n"
  585. "smlal2 v15.8h, v1.16b, v7.16b\n"
  586. "smull v8.8h, v2.8b, v4.8b\n"
  587. "smull v9.8h, v2.8b, v5.8b\n"
  588. "smull v12.8h, v3.8b, v4.8b\n"
  589. "smull v13.8h, v3.8b, v5.8b\n"
  590. "sadalp v18.4s, v10.8h\n"
  591. "smlal2 v8.8h, v2.16b, v4.16b\n"
  592. "sadalp v19.4s, v11.8h\n"
  593. "smlal2 v9.8h, v2.16b, v5.16b\n"
  594. "sadalp v22.4s, v14.8h\n"
  595. "smlal2 v12.8h, v3.16b, v4.16b\n"
  596. "sadalp v23.4s, v15.8h\n"
  597. "smlal2 v13.8h, v3.16b, v5.16b\n"
  598. "smull v10.8h, v2.8b, v6.8b\n"
  599. "sadalp v24.4s, v8.8h\n"
  600. "smull v11.8h, v2.8b, v7.8b\n"
  601. "sadalp v25.4s, v9.8h\n"
  602. "smull v14.8h, v3.8b, v6.8b\n"
  603. "sadalp v28.4s, v12.8h\n"
  604. "smull v15.8h, v3.8b, v7.8b\n"
  605. "sadalp v29.4s, v13.8h\n"
  606. "smlal2 v10.8h, v2.16b, v6.16b\n"
  607. "smlal2 v11.8h, v2.16b, v7.16b\n"
  608. "sadalp v26.4s, v10.8h\n"
  609. "smlal2 v14.8h, v3.16b, v6.16b\n"
  610. "sadalp v27.4s, v11.8h\n"
  611. "smlal2 v15.8h, v3.16b, v7.16b\n"
  612. "sadalp v30.4s, v14.8h\n"
  613. "sadalp v31.4s, v15.8h\n"
  614. "addp v4.4s, v16.4s, v20.4s\n"
  615. "addp v5.4s, v24.4s, v28.4s\n"
  616. "addp v6.4s, v17.4s, v21.4s\n"
  617. "addp v7.4s, v25.4s, v29.4s\n"
  618. "addp v8.4s, v18.4s, v22.4s\n"
  619. "addp v9.4s, v26.4s, v30.4s\n"
  620. "addp v10.4s, v19.4s, v23.4s\n"
  621. "addp v11.4s, v27.4s, v31.4s\n"
  622. "addp v0.4s, v4.4s, v5.4s\n"
  623. "addp v1.4s, v6.4s, v7.4s\n"
  624. "addp v2.4s, v8.4s, v9.4s\n"
  625. "addp v3.4s, v10.4s, v11.4s\n"
  626. "cmp %w[is_first_k], #1\n"
  627. "beq 6f\n"
  628. "cmp %w[remain_n], #3\n"
  629. "beq 1003f\n"
  630. "cmp %w[remain_n], #2\n"
  631. "beq 1002f\n"
  632. "cmp %w[remain_n], #1\n"
  633. "beq 1001f\n"
  634. "1003:\n"
  635. "ld1 {v8.4s, v9.4s, v10.4s}, [%[output]]\n"
  636. "add v0.4s, v0.4s, v8.4s\n"
  637. "add v1.4s, v1.4s, v9.4s\n"
  638. "add v2.4s, v2.4s, v10.4s\n"
  639. "b 6f\n"
  640. "1002:\n"
  641. "ld1 {v8.4s, v9.4s}, [%[output]]\n"
  642. "add v0.4s, v0.4s, v8.4s\n"
  643. "add v1.4s, v1.4s, v9.4s\n"
  644. "b 6f\n"
  645. "1001:\n"
  646. "ld1 {v8.4s}, [%[output]]\n"
  647. "add v0.4s, v0.4s, v8.4s\n"
  648. "6:\n"
  649. "cmp %w[remain_n], #3\n"
  650. "beq 10003f\n"
  651. "cmp %w[remain_n], #2\n"
  652. "beq 10002f\n"
  653. "cmp %w[remain_n], #1\n"
  654. "beq 10001f\n"
  655. "10003:\n"
  656. "str q2, [%[output], #32]\n"
  657. "10002:\n"
  658. "str q1, [%[output], #16]\n"
  659. "10001:\n"
  660. "str q0, [%[output]]\n"
  661. "7:\n"
  662. : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
  663. [remain_n] "+r"(remain_n), [is_first_k] "+r"(is_first_k),
  664. [k] "+r"(K), [output] "+r"(output)
  665. :
  666. : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
  667. "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
  668. "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
  669. "v29", "v30", "v31", "cc", "memory");
  670. }
  671. static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr,
  672. int ldin, int y0, int ymax, int k0,
  673. int kmax) {
  674. //! pack form {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)}
  675. int8_t zerobuff[4][64];
  676. std::memset(zerobuff, 0, sizeof(int8_t) * 64 * 4);
  677. megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0,
  678. "mk4 matmul with m is not times of 4");
  679. megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0,
  680. "mk4 matmul with k is not times of 4");
  681. size_t roundk = round_up(kmax - k0, 16);
  682. size_t out_offset = roundk * 4;
  683. int y = y0;
  684. int start_y = y0 / 4;
  685. for (; y + 15 < ymax; y += 16, start_y += 4) {
  686. const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4;
  687. const int8_t* inptr1 = inptr0 + ldin;
  688. const int8_t* inptr2 = inptr1 + ldin;
  689. const int8_t* inptr3 = inptr2 + ldin;
  690. int8_t* output = outptr + (y - y0) / 4 * out_offset;
  691. prefetch_2x(inptr0);
  692. prefetch_2x(inptr1);
  693. prefetch_2x(inptr2);
  694. prefetch_2x(inptr3);
  695. int K = kmax - k0;
  696. for (; K > 15; K -= 16) {
  697. transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output,
  698. out_offset);
  699. output += 64;
  700. }
  701. if (K > 0) {
  702. std::memcpy(zerobuff[0], inptr0, sizeof(int8_t) * K * 4);
  703. std::memcpy(zerobuff[1], inptr1, sizeof(int8_t) * K * 4);
  704. std::memcpy(zerobuff[2], inptr2, sizeof(int8_t) * K * 4);
  705. std::memcpy(zerobuff[3], inptr3, sizeof(int8_t) * K * 4);
  706. inptr0 = zerobuff[0];
  707. inptr1 = zerobuff[1];
  708. inptr2 = zerobuff[2];
  709. inptr3 = zerobuff[3];
  710. transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output,
  711. out_offset);
  712. output += 64;
  713. }
  714. }
  715. for (; y + 3 < ymax; y += 4, start_y++) {
  716. const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4;
  717. int8_t* output = outptr + (y - y0) / 4 * out_offset;
  718. prefetch_2x(inptr0);
  719. int K = kmax - k0;
  720. for (; K > 15; K -= 16) {
  721. transpose_interleave_1x4_4_b(inptr0, output);
  722. output += 64;
  723. }
  724. if (K > 0) {
  725. std::memcpy(zerobuff[0], inptr0, sizeof(int8_t) * K * 4);
  726. inptr0 = zerobuff[0];
  727. transpose_interleave_1x4_4_b(inptr0, output);
  728. output += 64;
  729. }
  730. }
  731. }
  732. static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin,
  733. int x0, int xmax, int k0, int kmax) {
  734. int32_t zerobuff[4];
  735. std::memset(zerobuff, 0, sizeof(int8_t) * 16);
  736. const int ksize = kmax - k0;
  737. const int ICB = (ksize) / 4;
  738. const int ksize4 = round_up<int>(ICB, 4) * 4;
  739. int32_t* outptr = reinterpret_cast<int32_t*>(out);
  740. megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0,
  741. "mk4 matmul with k is not times of 4");
  742. int k = k0 / 4;
  743. for (; k + 3 < ICB; k += 4) {
  744. const int32_t* inptr0 =
  745. reinterpret_cast<const int32_t*>(in + k * ldin + x0);
  746. const int32_t* inptr1 =
  747. reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0);
  748. const int32_t* inptr2 =
  749. reinterpret_cast<const int32_t*>(in + (k + 2) * ldin + x0);
  750. const int32_t* inptr3 =
  751. reinterpret_cast<const int32_t*>(in + (k + 3) * ldin + x0);
  752. int32_t* outptr_inner = outptr;
  753. int x = x0;
  754. for (; x + 3 < xmax; x += 4) {
  755. transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner);
  756. outptr_inner += ksize4;
  757. }
  758. if (x < xmax) {
  759. for (; x < xmax; x++) {
  760. *outptr_inner++ = *inptr0++;
  761. *outptr_inner++ = *inptr1++;
  762. *outptr_inner++ = *inptr2++;
  763. *outptr_inner++ = *inptr3++;
  764. }
  765. }
  766. outptr += 4 * 4;
  767. }
  768. if (k < ICB) {
  769. const int32_t* inptr0 =
  770. reinterpret_cast<const int32_t*>(in + k * ldin + x0);
  771. const int32_t* inptr1 =
  772. reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0);
  773. const int32_t* inptr2 =
  774. reinterpret_cast<const int32_t*>(in + (k + 2) * ldin + x0);
  775. const int32_t* inptr3 =
  776. reinterpret_cast<const int32_t*>(in + (k + 3) * ldin + x0);
  777. int32_t* outptr_inner = outptr;
  778. int x = x0;
  779. for (; x + 3 < xmax; x += 4) {
  780. if (k + 3 >= ICB) {
  781. switch (k + 3 - ICB) {
  782. case 2:
  783. inptr1 = zerobuff; MEGDNN_FALLTHRU
  784. case 1:
  785. inptr2 = zerobuff; MEGDNN_FALLTHRU
  786. case 0:
  787. inptr3 = zerobuff;
  788. break;
  789. default:
  790. megdnn_assert(0);
  791. }
  792. }
  793. transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner);
  794. outptr_inner += ksize4;
  795. }
  796. if (x < xmax) {
  797. if (k + 3 >= ICB) {
  798. switch (k + 3 - ICB) {
  799. case 2:
  800. inptr1 = zerobuff; MEGDNN_FALLTHRU
  801. case 1:
  802. inptr2 = zerobuff; MEGDNN_FALLTHRU
  803. case 0:
  804. inptr3 = zerobuff;
  805. break;
  806. default:
  807. megdnn_assert(0);
  808. }
  809. }
  810. for (; x < xmax; x++) {
  811. *outptr_inner++ = *inptr0++;
  812. *outptr_inner++ = *inptr1++;
  813. *outptr_inner++ = *inptr2++;
  814. *outptr_inner++ = *inptr3++;
  815. }
  816. }
  817. outptr += 4 * 4;
  818. }
  819. }
  820. } // namespace matmul_mk4_4x4x16
  821. } // namespace aarch64
  822. } // namespace megdnn
  823. #endif
  824. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台