You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

strategy_mk8_8x8.cpp 28 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745
  1. /**
  2. * \file dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/aarch64/matrix_mul/asm/common.h"
  12. #include "src/aarch64/matrix_mul/int16/strategy.h"
  13. #include "src/arm_common/simd_macro/marm_neon.h"
  14. #include "src/common/utils.h"
  15. using namespace megdnn;
  16. using namespace aarch64;
  17. using namespace aarch64::matmul;
  18. namespace {
  // Computes one block of 8 int32 results (two 4-lane vectors) and stores it to
  // `output`.
  //
  // For every group of 8 along K the kernel reads 128 bytes from a_ptr (eight
  // vectors of 8 x int16, v24-v31) and 16 bytes from b_ptr (eight int16 values,
  // v0). The i-th vector from a_ptr is multiplied by lane i of v0 with widening
  // to int32 (smull/smlal for lanes 0-3, smull2/smlal2 for lanes 4-7) and
  // accumulated into v16-v23. After the loop the partial sums are added
  // together into v16/v17 and 32 bytes are written to `output`.
  //
  // LDB is passed in int16 elements and is the distance b_ptr advances after
  // each 8-element load. K is decremented by 8 per step and the loop exits only
  // when it reaches exactly zero, so K must be a positive multiple of 8.
  19. void kern_8x1(
  20. const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
  21. dt_int32* output) {
  22. //! Convert LDB from int16 elements to a byte stride. Only one ld1 reads
  23. //! from b_ptr per K step here, so no extra offset has to be subtracted.
  // NOTE(review): LDB is a 32-bit int but is referenced as a 64-bit register
  // (%x[LDB]) inside the asm - confirm the upper 32 bits are well defined on
  // all supported compilers.
  24. LDB *= sizeof(dt_int16);
  25. asm volatile(
  // Prologue: load the first eight a_ptr vectors and the first b_ptr vector,
  // then initialise the accumulators with lanes 0-3 of v0. Lanes 4-7 are
  // handled at the top of the loop or in the epilogue.
  26. "subs %w[K], %w[K], #8\n"
  27. "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[a_ptr]], 64\n"
  28. "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[a_ptr]], 64\n"
  29. "ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n"
  30. "smull v16.4s, v24.4h, v0.h[0]\n"
  31. "smull2 v17.4s, v24.8h, v0.h[0]\n"
  32. "smull v18.4s, v25.4h, v0.h[1]\n"
  33. "smull2 v19.4s, v25.8h, v0.h[1]\n"
  34. "smull v20.4s, v26.4h, v0.h[2]\n"
  35. "smull2 v21.4s, v26.8h, v0.h[2]\n"
  36. "smull v22.4s, v27.4h, v0.h[3]\n"
  37. "smull2 v23.4s, v27.8h, v0.h[3]\n"
  // Skip the loop when the first group was the only one (K reached zero).
  38. "beq 2f\n"
  // Main loop: finish lanes 4-7 of the previous group while the next group is
  // being loaded, then accumulate lanes 0-3 of the new group.
  39. "1:\n"
  40. "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[a_ptr]], 64\n"
  41. "smlal v16.4s, v28.4h, v0.h[4]\n"
  42. "smlal2 v17.4s, v28.8h, v0.h[4]\n"
  43. "smlal v18.4s, v29.4h, v0.h[5]\n"
  44. "smlal2 v19.4s, v29.8h, v0.h[5]\n"
  45. "smlal v20.4s, v30.4h, v0.h[6]\n"
  46. "smlal2 v21.4s, v30.8h, v0.h[6]\n"
  47. "smlal v22.4s, v31.4h, v0.h[7]\n"
  48. "smlal2 v23.4s, v31.8h, v0.h[7]\n"
  49. "ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n"
  50. "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[a_ptr]], 64\n"
  51. "smlal v16.4s, v24.4h, v0.h[0]\n"
  52. "smlal2 v17.4s, v24.8h, v0.h[0]\n"
  53. "smlal v18.4s, v25.4h, v0.h[1]\n"
  54. "smlal2 v19.4s, v25.8h, v0.h[1]\n"
  55. "smlal v20.4s, v26.4h, v0.h[2]\n"
  56. "smlal2 v21.4s, v26.8h, v0.h[2]\n"
  57. "smlal v22.4s, v27.4h, v0.h[3]\n"
  58. "smlal2 v23.4s, v27.8h, v0.h[3]\n"
  59. "subs %w[K], %w[K], #8\n"
  60. "bne 1b\n"
  // Epilogue: finish lanes 4-7 of the last group.
  61. "2:\n"
  62. "smlal v16.4s, v28.4h, v0.h[4]\n"
  63. "smlal2 v17.4s, v28.8h, v0.h[4]\n"
  64. "smlal v18.4s, v29.4h, v0.h[5]\n"
  65. "smlal2 v19.4s, v29.8h, v0.h[5]\n"
  66. "smlal v20.4s, v30.4h, v0.h[6]\n"
  67. "smlal2 v21.4s, v30.8h, v0.h[6]\n"
  68. "smlal v22.4s, v31.4h, v0.h[7]\n"
  69. "smlal2 v23.4s, v31.8h, v0.h[7]\n"
  // Reduce the four partial sums of each half into v16 (from the smull/smlal
  // chain) and v17 (from the smull2/smlal2 chain), then store 8 int32 values.
  70. "add v16.4s, v16.4s, v18.4s\n"
  71. "add v20.4s, v20.4s, v22.4s\n"
  72. "add v17.4s, v17.4s, v19.4s\n"
  73. "add v21.4s, v21.4s, v23.4s\n"
  74. "add v16.4s, v16.4s, v20.4s\n"
  75. "add v17.4s, v17.4s, v21.4s\n"
  76. "st1 {v16.4s, v17.4s}, [%[output]], 32\n"
  // All five operands are read-write: the asm post-increments the pointers and
  // counts K down in place.
  77. : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
  78. [output] "+r"(output), [LDB] "+r"(LDB)
  79. :
  80. : "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
  81. "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory");
  82. }
  83. // Overview of the register layout used by kern_8x4:
  84. //
  85. // Data loaded from a_ptr: eight vectors of 8 x int16 in v24-v31.
  86. // Data loaded from b_ptr: four vectors of 8 x int16 in v0-v3.
  87. // Accumulators: eight vectors of 4 x int32 in v16-v23 (v16/v17 pair with v0, v18/v19 with v1, ...).
  88. //
  89. //   b_ptr   +--------+
  90. //           |v0[0-7] |
  91. //           |v1[0-7] |
  92. //           |v2[0-7] |
  93. //           |v3[0-7] |
  94. //           +--------+
  95. //   a_ptr
  96. //   +---------+ - - - - -+--------+
  97. //   | v24[0-7]|          |v16[0-3]|
  98. //   | v25[0-7]|          |v17[0-3]|
  99. //   | v26[0-7]|          |v18[0-3]|
  100. //   | v27[0-7]|          |v19[0-3]|
  101. //   | v28[0-7]|          |v20[0-3]|
  102. //   | v29[0-7]|          |v21[0-3]|
  103. //   | v30[0-7]|          |v22[0-3]|
  104. //   | v31[0-7]|          |v23[0-3]|
  105. //   +---------+          +--------+
  106. //                        Accumulator
  //
  // Computes one block of 32 int32 results (eight 4-lane vectors) and stores it
  // to `output`.
  //
  // For every group of 8 along K the kernel reads eight 16-byte vectors from
  // a_ptr and four 16-byte vectors from b_ptr. Vector i from a_ptr is
  // multiplied by lane i of each of v0-v3 with widening to int32 and
  // accumulated; v16/v17 collect the products with v0, v18/v19 with v1,
  // v20/v21 with v2 and v22/v23 with v3. 128 bytes are written to `output`.
  //
  // LDB is passed in int16 elements. K is decremented by 8 per step and the
  // loop exits only when it reaches exactly zero, so K must be a positive
  // multiple of 8.
  107. void kern_8x4(
  108. const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
  109. dt_int32* output) {
  110. //! Each K step loads 32 int16 values from B with four ld1s. The first three already
  111. //! advance b_ptr by 24 elements (24 * 2 bytes), so the last one adds the remaining LDB - 24.
  // NOTE(review): LDB is a 32-bit int but is referenced as a 64-bit register
  // (%x[LDB]) inside the asm - confirm the upper 32 bits are well defined on
  // all supported compilers.
  112. LDB = (LDB - 24) * sizeof(dt_int16);
  113. asm volatile(
  // Prologue: process the first group of 8. Loads from a_ptr are interleaved
  // with the multiplies; lanes 0-6 are accumulated here, lane 7 (v31) is
  // deferred to the top of the loop or to the epilogue.
  114. "subs %w[K], %w[K], #8\n"
  115. "ld1 {v24.4s}, [%[a_ptr]], 16\n"
  116. "ld1 {v0.4s}, [%[b_ptr]], 16\n"
  117. "ld1 {v1.4s}, [%[b_ptr]], 16\n"
  118. "ld1 {v2.4s}, [%[b_ptr]], 16\n"
  119. "ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n"
  120. "smull v16.4s, v24.4h, v0.h[0]\n"
  121. "smull2 v17.4s, v24.8h, v0.h[0]\n"
  122. "smull v18.4s, v24.4h, v1.h[0]\n"
  123. "smull2 v19.4s, v24.8h, v1.h[0]\n"
  124. "ld1 {v25.4s}, [%[a_ptr]], 16\n"
  125. "smull v20.4s, v24.4h, v2.h[0]\n"
  126. "smull2 v21.4s, v24.8h, v2.h[0]\n"
  127. "smull v22.4s, v24.4h, v3.h[0]\n"
  128. "smull2 v23.4s, v24.8h, v3.h[0]\n"
  129. "smlal v16.4s, v25.4h, v0.h[1]\n"
  130. "smlal2 v17.4s, v25.8h, v0.h[1]\n"
  131. "smlal v18.4s, v25.4h, v1.h[1]\n"
  132. "smlal2 v19.4s, v25.8h, v1.h[1]\n"
  133. "ld1 {v26.4s}, [%[a_ptr]], 16\n"
  134. "smlal v20.4s, v25.4h, v2.h[1]\n"
  135. "smlal2 v21.4s, v25.8h, v2.h[1]\n"
  136. "smlal v22.4s, v25.4h, v3.h[1]\n"
  137. "smlal2 v23.4s, v25.8h, v3.h[1]\n"
  138. "smlal v16.4s, v26.4h, v0.h[2]\n"
  139. "smlal2 v17.4s, v26.8h, v0.h[2]\n"
  140. "smlal v18.4s, v26.4h, v1.h[2]\n"
  141. "smlal2 v19.4s, v26.8h, v1.h[2]\n"
  142. "ld1 {v27.4s}, [%[a_ptr]], 16\n"
  143. "smlal v20.4s, v26.4h, v2.h[2]\n"
  144. "smlal2 v21.4s, v26.8h, v2.h[2]\n"
  145. "smlal v22.4s, v26.4h, v3.h[2]\n"
  146. "smlal2 v23.4s, v26.8h, v3.h[2]\n"
  147. "smlal v16.4s, v27.4h, v0.h[3]\n"
  148. "smlal2 v17.4s, v27.8h, v0.h[3]\n"
  149. "smlal v18.4s, v27.4h, v1.h[3]\n"
  150. "smlal2 v19.4s, v27.8h, v1.h[3]\n"
  151. "ld1 {v28.4s}, [%[a_ptr]], 16\n"
  152. "smlal v20.4s, v27.4h, v2.h[3]\n"
  153. "smlal2 v21.4s, v27.8h, v2.h[3]\n"
  154. "smlal v22.4s, v27.4h, v3.h[3]\n"
  155. "smlal2 v23.4s, v27.8h, v3.h[3]\n"
  156. "smlal v16.4s, v28.4h, v0.h[4]\n"
  157. "smlal2 v17.4s, v28.8h, v0.h[4]\n"
  158. "smlal v18.4s, v28.4h, v1.h[4]\n"
  159. "smlal2 v19.4s, v28.8h, v1.h[4]\n"
  160. "ld1 {v29.4s}, [%[a_ptr]], 16\n"
  161. "smlal v20.4s, v28.4h, v2.h[4]\n"
  162. "smlal2 v21.4s, v28.8h, v2.h[4]\n"
  163. "smlal v22.4s, v28.4h, v3.h[4]\n"
  164. "smlal2 v23.4s, v28.8h, v3.h[4]\n"
  165. "smlal v16.4s, v29.4h, v0.h[5]\n"
  166. "smlal2 v17.4s, v29.8h, v0.h[5]\n"
  167. "smlal v18.4s, v29.4h, v1.h[5]\n"
  168. "smlal2 v19.4s, v29.8h, v1.h[5]\n"
  169. "ld1 {v30.4s}, [%[a_ptr]], 16\n"
  170. "smlal v20.4s, v29.4h, v2.h[5]\n"
  171. "smlal2 v21.4s, v29.8h, v2.h[5]\n"
  172. "smlal v22.4s, v29.4h, v3.h[5]\n"
  173. "smlal2 v23.4s, v29.8h, v3.h[5]\n"
  174. "smlal v16.4s, v30.4h, v0.h[6]\n"
  175. "smlal2 v17.4s, v30.8h, v0.h[6]\n"
  176. "smlal v18.4s, v30.4h, v1.h[6]\n"
  177. "smlal2 v19.4s, v30.8h, v1.h[6]\n"
  178. "ld1 {v31.4s}, [%[a_ptr]], 16\n"
  179. "smlal v20.4s, v30.4h, v2.h[6]\n"
  180. "smlal2 v21.4s, v30.8h, v2.h[6]\n"
  181. "smlal v22.4s, v30.4h, v3.h[6]\n"
  182. "smlal2 v23.4s, v30.8h, v3.h[6]\n"
  // Skip the loop when the first group was the only one (K reached zero).
  183. "beq 2f\n"
  // Main loop: finish lane 7 of the previous group while reloading v0-v3 for
  // the next group, then accumulate lanes 0-6 of the new group.
  184. "1:\n"
  185. "ld1 {v24.4s}, [%[a_ptr]], 16\n"
  186. "smlal v16.4s, v31.4h, v0.h[7]\n"
  187. "smlal2 v17.4s, v31.8h, v0.h[7]\n"
  188. "ld1 {v0.4s}, [%[b_ptr]], 16\n"
  189. "smlal v18.4s, v31.4h, v1.h[7]\n"
  190. "smlal2 v19.4s, v31.8h, v1.h[7]\n"
  191. "ld1 {v1.4s}, [%[b_ptr]], 16\n"
  192. "smlal v20.4s, v31.4h, v2.h[7]\n"
  193. "smlal2 v21.4s, v31.8h, v2.h[7]\n"
  194. "ld1 {v2.4s}, [%[b_ptr]], 16\n"
  195. "smlal v22.4s, v31.4h, v3.h[7]\n"
  196. "smlal2 v23.4s, v31.8h, v3.h[7]\n"
  197. "ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n"
  198. "smlal v16.4s, v24.4h, v0.h[0]\n"
  199. "smlal2 v17.4s, v24.8h, v0.h[0]\n"
  200. "smlal v18.4s, v24.4h, v1.h[0]\n"
  201. "smlal2 v19.4s, v24.8h, v1.h[0]\n"
  202. "ld1 {v25.4s}, [%[a_ptr]], 16\n"
  203. "smlal v20.4s, v24.4h, v2.h[0]\n"
  204. "smlal2 v21.4s, v24.8h, v2.h[0]\n"
  205. "smlal v22.4s, v24.4h, v3.h[0]\n"
  206. "smlal2 v23.4s, v24.8h, v3.h[0]\n"
  207. "smlal v16.4s, v25.4h, v0.h[1]\n"
  208. "smlal2 v17.4s, v25.8h, v0.h[1]\n"
  209. "smlal v18.4s, v25.4h, v1.h[1]\n"
  210. "smlal2 v19.4s, v25.8h, v1.h[1]\n"
  211. "ld1 {v26.4s}, [%[a_ptr]], 16\n"
  212. "smlal v20.4s, v25.4h, v2.h[1]\n"
  213. "smlal2 v21.4s, v25.8h, v2.h[1]\n"
  214. "smlal v22.4s, v25.4h, v3.h[1]\n"
  215. "smlal2 v23.4s, v25.8h, v3.h[1]\n"
  216. "smlal v16.4s, v26.4h, v0.h[2]\n"
  217. "smlal2 v17.4s, v26.8h, v0.h[2]\n"
  218. "smlal v18.4s, v26.4h, v1.h[2]\n"
  219. "smlal2 v19.4s, v26.8h, v1.h[2]\n"
  220. "ld1 {v27.4s}, [%[a_ptr]], 16\n"
  221. "smlal v20.4s, v26.4h, v2.h[2]\n"
  222. "smlal2 v21.4s, v26.8h, v2.h[2]\n"
  223. "smlal v22.4s, v26.4h, v3.h[2]\n"
  224. "smlal2 v23.4s, v26.8h, v3.h[2]\n"
  225. "smlal v16.4s, v27.4h, v0.h[3]\n"
  226. "smlal2 v17.4s, v27.8h, v0.h[3]\n"
  227. "smlal v18.4s, v27.4h, v1.h[3]\n"
  228. "smlal2 v19.4s, v27.8h, v1.h[3]\n"
  229. "ld1 {v28.4s}, [%[a_ptr]], 16\n"
  230. "smlal v20.4s, v27.4h, v2.h[3]\n"
  231. "smlal2 v21.4s, v27.8h, v2.h[3]\n"
  232. "smlal v22.4s, v27.4h, v3.h[3]\n"
  233. "smlal2 v23.4s, v27.8h, v3.h[3]\n"
  234. "smlal v16.4s, v28.4h, v0.h[4]\n"
  235. "smlal2 v17.4s, v28.8h, v0.h[4]\n"
  236. "smlal v18.4s, v28.4h, v1.h[4]\n"
  237. "smlal2 v19.4s, v28.8h, v1.h[4]\n"
  238. "ld1 {v29.4s}, [%[a_ptr]], 16\n"
  239. "smlal v20.4s, v28.4h, v2.h[4]\n"
  240. "smlal2 v21.4s, v28.8h, v2.h[4]\n"
  241. "smlal v22.4s, v28.4h, v3.h[4]\n"
  242. "smlal2 v23.4s, v28.8h, v3.h[4]\n"
  243. "smlal v16.4s, v29.4h, v0.h[5]\n"
  244. "smlal2 v17.4s, v29.8h, v0.h[5]\n"
  245. "smlal v18.4s, v29.4h, v1.h[5]\n"
  246. "smlal2 v19.4s, v29.8h, v1.h[5]\n"
  247. "ld1 {v30.4s}, [%[a_ptr]], 16\n"
  248. "smlal v20.4s, v29.4h, v2.h[5]\n"
  249. "smlal2 v21.4s, v29.8h, v2.h[5]\n"
  250. "smlal v22.4s, v29.4h, v3.h[5]\n"
  251. "smlal2 v23.4s, v29.8h, v3.h[5]\n"
  252. "smlal v16.4s, v30.4h, v0.h[6]\n"
  253. "smlal2 v17.4s, v30.8h, v0.h[6]\n"
  254. "smlal v18.4s, v30.4h, v1.h[6]\n"
  255. "smlal2 v19.4s, v30.8h, v1.h[6]\n"
  256. "ld1 {v31.4s}, [%[a_ptr]], 16\n"
  257. "smlal v20.4s, v30.4h, v2.h[6]\n"
  258. "smlal2 v21.4s, v30.8h, v2.h[6]\n"
  259. "smlal v22.4s, v30.4h, v3.h[6]\n"
  260. "smlal2 v23.4s, v30.8h, v3.h[6]\n"
  261. "subs %w[K], %w[K], #8\n"
  262. "bne 1b\n"
  // Epilogue: finish lane 7 of the last group, then store 32 int32 values.
  263. "2:\n"
  264. "smlal v16.4s, v31.4h, v0.h[7]\n"
  265. "smlal2 v17.4s, v31.8h, v0.h[7]\n"
  266. "smlal v18.4s, v31.4h, v1.h[7]\n"
  267. "smlal2 v19.4s, v31.8h, v1.h[7]\n"
  268. "smlal v20.4s, v31.4h, v2.h[7]\n"
  269. "smlal2 v21.4s, v31.8h, v2.h[7]\n"
  270. "smlal v22.4s, v31.4h, v3.h[7]\n"
  271. "smlal2 v23.4s, v31.8h, v3.h[7]\n"
  272. "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n"
  273. "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output]], 64\n"
  // All five operands are read-write: the asm post-increments the pointers and
  // counts K down in place.
  274. : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
  275. [output] "+r"(output), [LDB] "+r"(LDB)
  276. :
  277. : "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
  278. "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc",
  279. "memory");
  280. }
  281. // Overview of the register layout used by kern_8x8:
  282. //
  283. // Data loaded from a_ptr: eight vectors of 8 x int16 in v8-v15.
  284. // Data loaded from b_ptr: eight vectors of 8 x int16 in v0-v7.
  285. // Accumulators: sixteen vectors of 4 x int32 in v16-v31 (two per b_ptr vector).
  286. //
  287. //   a_ptr   +--------+
  288. //           | v8[0-7]|
  289. //           | v9[0-7]|
  290. //           |v10[0-7]|
  291. //           |v11[0-7]|
  292. //           |v12[0-7]|
  293. //           |v13[0-7]|
  294. //           |v14[0-7]|
  295. //           |v15[0-7]|
  296. //           +--------+
  297. //   b_ptr
  298. //   +--------+ - - - - -+--------+--------+
  299. //   | v0[0-7]|          |v16[0-3]|v17[0-3]|
  300. //   | v1[0-7]|          |v18[0-3]|v19[0-3]|
  301. //   | v2[0-7]|          |v20[0-3]|v21[0-3]|
  302. //   | v3[0-7]|          |v22[0-3]|v23[0-3]|
  303. //   | v4[0-7]|          |v24[0-3]|v25[0-3]|
  304. //   | v5[0-7]|          |v26[0-3]|v27[0-3]|
  305. //   | v6[0-7]|          |v28[0-3]|v29[0-3]|
  306. //   | v7[0-7]|          |v30[0-3]|v31[0-3]|
  307. //   +--------+          +--------+--------+
  308. //                       Accumulator
  //
  // Computes one block of 64 int32 results (sixteen 4-lane vectors) and stores
  // it to `output`.
  //
  // For every group of 8 along K the kernel reads eight 16-byte vectors from
  // a_ptr (v8-v15) and eight 16-byte vectors from b_ptr (v0-v7). Vector v(8+i)
  // is multiplied by lane i of each b_ptr vector with widening to int32; the
  // products with vj accumulate into v(16+2j) (smull/smlal) and v(17+2j)
  // (smull2/smlal2). 256 bytes are written to `output`.
  //
  // LDB is passed in int16 elements. K is decremented by 8 per step and the
  // loop exits only when it reaches exactly zero, so K must be a positive
  // multiple of 8.
  309. void kern_8x8(
  310. const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
  311. dt_int32* output) {
  312. //! Each K step loads 64 int16 values from B. The loads of v0-v5 already advance b_ptr by
  313. //! 48 elements (48 * 2 bytes), so the final load of v6/v7 adds the remaining LDB - 48.
  // NOTE(review): LDB is a 32-bit int but is referenced as a 64-bit register
  // (%x[LDB]) inside the asm - confirm the upper 32 bits are well defined on
  // all supported compilers.
  314. LDB = (LDB - 48) * sizeof(dt_int16);
  315. asm volatile(
  // Prologue: process the first group of 8. v16-v23 are completed here; for
  // v24-v27 lane 7 is deferred, and for v28-v31 lanes 6 and 7 are deferred to
  // the top of the loop or to the epilogue.
  316. "subs %w[K], %w[K], #8\n"
  317. "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[a_ptr]], 64\n"
  318. "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[a_ptr]], 64\n"
  319. "ld1 {v0.4s}, [%[b_ptr]], 16\n"
  320. "smull v16.4s, v8.4h, v0.h[0]\n"
  321. "ld1 {v1.4s}, [%[b_ptr]], 16\n"
  322. "smlal v16.4s, v9.4h, v0.h[1]\n"
  323. "smull v18.4s, v8.4h, v1.h[0]\n"
  324. "smull2 v17.4s, v8.8h, v0.h[0]\n"
  325. "smull2 v19.4s, v8.8h, v1.h[0]\n"
  326. "smlal v16.4s, v10.4h, v0.h[2]\n"
  327. "smlal v18.4s, v9.4h, v1.h[1]\n"
  328. "smlal2 v17.4s, v9.8h, v0.h[1]\n"
  329. "smlal2 v19.4s, v9.8h, v1.h[1]\n"
  330. "smlal v16.4s, v11.4h, v0.h[3]\n"
  331. "smlal v18.4s, v10.4h, v1.h[2]\n"
  332. "smlal2 v17.4s, v10.8h, v0.h[2]\n"
  333. "smlal2 v19.4s, v10.8h, v1.h[2]\n"
  334. "smlal v16.4s, v12.4h, v0.h[4]\n"
  335. "smlal v18.4s, v11.4h, v1.h[3]\n"
  336. "smlal2 v17.4s, v11.8h, v0.h[3]\n"
  337. "smlal2 v19.4s, v11.8h, v1.h[3]\n"
  338. "smlal v16.4s, v13.4h, v0.h[5]\n"
  339. "smlal v18.4s, v12.4h, v1.h[4]\n"
  340. "smlal2 v17.4s, v12.8h, v0.h[4]\n"
  341. "smlal2 v19.4s, v12.8h, v1.h[4]\n"
  342. "smlal2 v17.4s, v13.8h, v0.h[5]\n"
  343. "ld1 {v2.4s, v3.4s}, [%[b_ptr]], 32\n"
  344. "smlal v16.4s, v14.4h, v0.h[6]\n"
  345. "smlal v18.4s, v13.4h, v1.h[5]\n"
  346. "smlal2 v17.4s, v14.8h, v0.h[6]\n"
  347. "smlal2 v19.4s, v13.8h, v1.h[5]\n"
  348. "smull v20.4s, v8.4h, v2.h[0]\n"
  349. "smull v22.4s, v8.4h, v3.h[0]\n"
  350. "smull2 v21.4s, v8.8h, v2.h[0]\n"
  351. "smull2 v23.4s, v8.8h, v3.h[0]\n"
  352. "smlal v16.4s, v15.4h, v0.h[7]\n"
  353. "smlal v18.4s, v14.4h, v1.h[6]\n"
  354. "smlal2 v17.4s, v15.8h, v0.h[7]\n"
  355. "smlal2 v19.4s, v14.8h, v1.h[6]\n"
  356. "smlal v20.4s, v9.4h, v2.h[1]\n"
  357. "smlal v22.4s, v9.4h, v3.h[1]\n"
  358. "smlal2 v21.4s, v9.8h, v2.h[1]\n"
  359. "smlal2 v23.4s, v9.8h, v3.h[1]\n"
  360. "smlal v18.4s, v15.4h, v1.h[7]\n"
  361. "smlal v20.4s, v10.4h, v2.h[2]\n"
  362. "smlal v22.4s, v10.4h, v3.h[2]\n"
  363. "smlal2 v21.4s, v10.8h, v2.h[2]\n"
  364. "smlal2 v23.4s, v10.8h, v3.h[2]\n"
  365. "smlal2 v19.4s, v15.8h, v1.h[7]\n"
  366. "smlal v20.4s, v11.4h, v2.h[3]\n"
  367. "smlal v22.4s, v11.4h, v3.h[3]\n"
  368. "smlal2 v21.4s, v11.8h, v2.h[3]\n"
  369. "smlal2 v23.4s, v11.8h, v3.h[3]\n"
  370. "smlal v20.4s, v12.4h, v2.h[4]\n"
  371. "smlal v22.4s, v12.4h, v3.h[4]\n"
  372. "smlal2 v21.4s, v12.8h, v2.h[4]\n"
  373. "smlal2 v23.4s, v12.8h, v3.h[4]\n"
  374. "smlal v20.4s, v13.4h, v2.h[5]\n"
  375. "smlal v22.4s, v13.4h, v3.h[5]\n"
  376. "smlal2 v21.4s, v13.8h, v2.h[5]\n"
  377. "smlal2 v23.4s, v13.8h, v3.h[5]\n"
  378. "ld1 {v4.4s, v5.4s}, [%[b_ptr]], 32\n"
  379. "smlal v20.4s, v14.4h, v2.h[6]\n"
  380. "smlal v22.4s, v14.4h, v3.h[6]\n"
  381. "smlal2 v21.4s, v14.8h, v2.h[6]\n"
  382. "smlal2 v23.4s, v14.8h, v3.h[6]\n"
  383. "smull v24.4s, v8.4h, v4.h[0]\n"
  384. "smull v26.4s, v8.4h, v5.h[0]\n"
  385. "smull2 v25.4s, v8.8h, v4.h[0]\n"
  386. "smull2 v27.4s, v8.8h, v5.h[0]\n"
  387. "smlal v20.4s, v15.4h, v2.h[7]\n"
  388. "smlal v22.4s, v15.4h, v3.h[7]\n"
  389. "smlal2 v21.4s, v15.8h, v2.h[7]\n"
  390. "smlal2 v23.4s, v15.8h, v3.h[7]\n"
  391. "smlal v24.4s, v9.4h, v4.h[1]\n"
  392. "smlal v26.4s, v9.4h, v5.h[1]\n"
  393. "smlal2 v25.4s, v9.8h, v4.h[1]\n"
  394. "smlal2 v27.4s, v9.8h, v5.h[1]\n"
  395. "smlal v24.4s, v10.4h, v4.h[2]\n"
  396. "smlal v26.4s, v10.4h, v5.h[2]\n"
  397. "smlal2 v25.4s, v10.8h, v4.h[2]\n"
  398. "smlal2 v27.4s, v10.8h, v5.h[2]\n"
  399. "smlal v24.4s, v11.4h, v4.h[3]\n"
  400. "smlal v26.4s, v11.4h, v5.h[3]\n"
  401. "smlal2 v25.4s, v11.8h, v4.h[3]\n"
  402. "smlal2 v27.4s, v11.8h, v5.h[3]\n"
  403. "smlal v24.4s, v12.4h, v4.h[4]\n"
  404. "smlal v26.4s, v12.4h, v5.h[4]\n"
  405. "smlal2 v25.4s, v12.8h, v4.h[4]\n"
  406. "smlal2 v27.4s, v12.8h, v5.h[4]\n"
  407. "smlal v24.4s, v13.4h, v4.h[5]\n"
  408. "smlal v26.4s, v13.4h, v5.h[5]\n"
  409. "smlal2 v25.4s, v13.8h, v4.h[5]\n"
  410. "smlal2 v27.4s, v13.8h, v5.h[5]\n"
  411. "ld1 {v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n"
  412. "smlal v24.4s, v14.4h, v4.h[6]\n"
  413. "smlal v26.4s, v14.4h, v5.h[6]\n"
  414. "smlal2 v25.4s, v14.8h, v4.h[6]\n"
  415. "smlal2 v27.4s, v14.8h, v5.h[6]\n"
  416. "smull v28.4s, v8.4h, v6.h[0]\n"
  417. "smull v30.4s, v8.4h, v7.h[0]\n"
  418. "smull2 v29.4s, v8.8h, v6.h[0]\n"
  419. "smull2 v31.4s, v8.8h, v7.h[0]\n"
  420. "smlal v28.4s, v9.4h, v6.h[1]\n"
  421. "smlal v30.4s, v9.4h, v7.h[1]\n"
  422. "smlal2 v29.4s, v9.8h, v6.h[1]\n"
  423. "smlal2 v31.4s, v9.8h, v7.h[1]\n"
  424. "smlal v28.4s, v10.4h, v6.h[2]\n"
  425. "smlal v30.4s, v10.4h, v7.h[2]\n"
  426. "smlal2 v29.4s, v10.8h, v6.h[2]\n"
  427. "smlal2 v31.4s, v10.8h, v7.h[2]\n"
  428. "smlal v28.4s, v11.4h, v6.h[3]\n"
  429. "smlal v30.4s, v11.4h, v7.h[3]\n"
  430. "smlal2 v29.4s, v11.8h, v6.h[3]\n"
  431. "smlal2 v31.4s, v11.8h, v7.h[3]\n"
  432. "smlal v28.4s, v12.4h, v6.h[4]\n"
  433. "smlal v30.4s, v12.4h, v7.h[4]\n"
  434. "smlal2 v29.4s, v12.8h, v6.h[4]\n"
  435. "smlal2 v31.4s, v12.8h, v7.h[4]\n"
  436. "smlal v28.4s, v13.4h, v6.h[5]\n"
  437. "smlal v30.4s, v13.4h, v7.h[5]\n"
  438. "smlal2 v29.4s, v13.8h, v6.h[5]\n"
  439. "smlal2 v31.4s, v13.8h, v7.h[5]\n"
  // Skip the loop when the first group was the only one (K reached zero).
  440. "beq 2f\n"
  // Main loop: first finish the deferred products of the previous group
  // (lane 7 for v24-v27, lanes 6 and 7 for v28-v31) while the next a_ptr and
  // b_ptr vectors are being loaded, then accumulate the new group.
  441. "1:\n"
  442. "smlal v24.4s, v15.4h, v4.h[7]\n"
  443. "smlal v26.4s, v15.4h, v5.h[7]\n"
  444. "smlal2 v25.4s, v15.8h, v4.h[7]\n"
  445. "ld1 {v8.4s, v9.4s}, [%[a_ptr]], 32\n"
  446. "smlal2 v27.4s, v15.8h, v5.h[7]\n"
  447. "smlal v28.4s, v14.4h, v6.h[6]\n"
  448. "smlal v30.4s, v14.4h, v7.h[6]\n"
  449. "ld1 {v10.4s, v11.4s}, [%[a_ptr]], 32\n"
  450. "smlal2 v29.4s, v15.8h, v6.h[7]\n"
  451. "smlal2 v31.4s, v14.8h, v7.h[6]\n"
  452. "smlal v28.4s, v15.4h, v6.h[7]\n"
  453. "ld1 {v12.4s, v13.4s}, [%[a_ptr]], 32\n"
  454. "smlal v30.4s, v15.4h, v7.h[7]\n"
  455. "smlal2 v29.4s, v14.8h, v6.h[6]\n"
  456. "ld1 {v0.4s}, [%[b_ptr]], 16\n"
  457. "smlal2 v31.4s, v15.8h, v7.h[7]\n"
  458. "smlal v16.4s, v8.4h, v0.h[0]\n"
  459. "ld1 {v1.4s}, [%[b_ptr]], 16\n"
  460. "smlal v16.4s, v9.4h, v0.h[1]\n"
  461. "smlal2 v17.4s, v8.8h, v0.h[0]\n"
  462. "smlal v16.4s, v10.4h, v0.h[2]\n"
  463. "smlal v18.4s, v8.4h, v1.h[0]\n"
  464. "smlal2 v17.4s, v9.8h, v0.h[1]\n"
  465. "smlal2 v19.4s, v8.8h, v1.h[0]\n"
  // v14/v15 are reloaded only here, after the last use of their old contents.
  466. "ld1 {v14.4s, v15.4s}, [%[a_ptr]], 32\n"
  467. "smlal v16.4s, v11.4h, v0.h[3]\n"
  468. "smlal v18.4s, v9.4h, v1.h[1]\n"
  469. "smlal2 v17.4s, v10.8h, v0.h[2]\n"
  470. "smlal2 v19.4s, v9.8h, v1.h[1]\n"
  471. "smlal v16.4s, v12.4h, v0.h[4]\n"
  472. "smlal v18.4s, v10.4h, v1.h[2]\n"
  473. "smlal2 v17.4s, v11.8h, v0.h[3]\n"
  474. "smlal2 v19.4s, v10.8h, v1.h[2]\n"
  475. "smlal v16.4s, v13.4h, v0.h[5]\n"
  476. "smlal v18.4s, v11.4h, v1.h[3]\n"
  477. "smlal2 v17.4s, v12.8h, v0.h[4]\n"
  478. "smlal2 v19.4s, v11.8h, v1.h[3]\n"
  479. "smlal v16.4s, v14.4h, v0.h[6]\n"
  480. "smlal v18.4s, v12.4h, v1.h[4]\n"
  481. "smlal2 v17.4s, v13.8h, v0.h[5]\n"
  482. "smlal2 v19.4s, v12.8h, v1.h[4]\n"
  483. "smlal v16.4s, v15.4h, v0.h[7]\n"
  484. "smlal v18.4s, v13.4h, v1.h[5]\n"
  485. "smlal2 v17.4s, v14.8h, v0.h[6]\n"
  486. "smlal2 v19.4s, v13.8h, v1.h[5]\n"
  487. "ld1 {v2.4s, v3.4s}, [%[b_ptr]], 32\n"
  488. "smlal v18.4s, v14.4h, v1.h[6]\n"
  489. "smlal2 v17.4s, v15.8h, v0.h[7]\n"
  490. "smlal2 v19.4s, v14.8h, v1.h[6]\n"
  491. "smlal v20.4s, v8.4h, v2.h[0]\n"
  492. "smlal v22.4s, v8.4h, v3.h[0]\n"
  493. "smlal2 v21.4s, v8.8h, v2.h[0]\n"
  494. "smlal2 v23.4s, v8.8h, v3.h[0]\n"
  495. "smlal v18.4s, v15.4h, v1.h[7]\n"
  496. "smlal v20.4s, v9.4h, v2.h[1]\n"
  497. "smlal v22.4s, v9.4h, v3.h[1]\n"
  498. "smlal2 v21.4s, v9.8h, v2.h[1]\n"
  499. "smlal2 v23.4s, v9.8h, v3.h[1]\n"
  500. "smlal2 v19.4s, v15.8h, v1.h[7]\n"
  501. "smlal v20.4s, v10.4h, v2.h[2]\n"
  502. "smlal v22.4s, v10.4h, v3.h[2]\n"
  503. "smlal2 v21.4s, v10.8h, v2.h[2]\n"
  504. "smlal2 v23.4s, v10.8h, v3.h[2]\n"
  505. "smlal v20.4s, v11.4h, v2.h[3]\n"
  506. "smlal v22.4s, v11.4h, v3.h[3]\n"
  507. "smlal2 v21.4s, v11.8h, v2.h[3]\n"
  508. "smlal2 v23.4s, v11.8h, v3.h[3]\n"
  509. "smlal v20.4s, v12.4h, v2.h[4]\n"
  510. "smlal v22.4s, v12.4h, v3.h[4]\n"
  511. "smlal2 v21.4s, v12.8h, v2.h[4]\n"
  512. "smlal2 v23.4s, v12.8h, v3.h[4]\n"
  513. "smlal v20.4s, v13.4h, v2.h[5]\n"
  514. "smlal v22.4s, v13.4h, v3.h[5]\n"
  515. "smlal2 v21.4s, v13.8h, v2.h[5]\n"
  516. "smlal2 v23.4s, v13.8h, v3.h[5]\n"
  517. "ld1 {v4.4s, v5.4s}, [%[b_ptr]], 32\n"
  518. "smlal v20.4s, v14.4h, v2.h[6]\n"
  519. "smlal v22.4s, v14.4h, v3.h[6]\n"
  520. "smlal2 v21.4s, v14.8h, v2.h[6]\n"
  521. "smlal2 v23.4s, v14.8h, v3.h[6]\n"
  522. "smlal v24.4s, v8.4h, v4.h[0]\n"
  523. "smlal v26.4s, v8.4h, v5.h[0]\n"
  524. "smlal2 v25.4s, v8.8h, v4.h[0]\n"
  525. "smlal2 v27.4s, v8.8h, v5.h[0]\n"
  526. "smlal v20.4s, v15.4h, v2.h[7]\n"
  527. "smlal2 v21.4s, v15.8h, v2.h[7]\n"
  528. "smlal v22.4s, v15.4h, v3.h[7]\n"
  529. "smlal2 v23.4s, v15.8h, v3.h[7]\n"
  530. "smlal v24.4s, v9.4h, v4.h[1]\n"
  531. "smlal v26.4s, v9.4h, v5.h[1]\n"
  532. "smlal2 v25.4s, v9.8h, v4.h[1]\n"
  533. "smlal2 v27.4s, v9.8h, v5.h[1]\n"
  534. "smlal v24.4s, v10.4h, v4.h[2]\n"
  535. "smlal v26.4s, v10.4h, v5.h[2]\n"
  536. "smlal2 v25.4s, v10.8h, v4.h[2]\n"
  537. "smlal2 v27.4s, v10.8h, v5.h[2]\n"
  538. "smlal v24.4s, v11.4h, v4.h[3]\n"
  539. "smlal v26.4s, v11.4h, v5.h[3]\n"
  540. "smlal2 v25.4s, v11.8h, v4.h[3]\n"
  541. "smlal2 v27.4s, v11.8h, v5.h[3]\n"
  542. "smlal v24.4s, v12.4h, v4.h[4]\n"
  543. "smlal v26.4s, v12.4h, v5.h[4]\n"
  544. "smlal2 v25.4s, v12.8h, v4.h[4]\n"
  545. "smlal2 v27.4s, v12.8h, v5.h[4]\n"
  546. "smlal v24.4s, v13.4h, v4.h[5]\n"
  547. "smlal v26.4s, v13.4h, v5.h[5]\n"
  548. "smlal2 v25.4s, v13.8h, v4.h[5]\n"
  549. "smlal2 v27.4s, v13.8h, v5.h[5]\n"
  550. "ld1 {v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n"
  551. "smlal v24.4s, v14.4h, v4.h[6]\n"
  552. "smlal v26.4s, v14.4h, v5.h[6]\n"
  553. "smlal2 v25.4s, v14.8h, v4.h[6]\n"
  554. "smlal2 v27.4s, v14.8h, v5.h[6]\n"
  555. "smlal v28.4s, v8.4h, v6.h[0]\n"
  556. "smlal v30.4s, v8.4h, v7.h[0]\n"
  557. "smlal2 v29.4s, v8.8h, v6.h[0]\n"
  558. "smlal2 v31.4s, v8.8h, v7.h[0]\n"
  559. "smlal v28.4s, v9.4h, v6.h[1]\n"
  560. "smlal v30.4s, v9.4h, v7.h[1]\n"
  561. "smlal2 v29.4s, v9.8h, v6.h[1]\n"
  562. "smlal2 v31.4s, v9.8h, v7.h[1]\n"
  563. "smlal v28.4s, v10.4h, v6.h[2]\n"
  564. "smlal v30.4s, v10.4h, v7.h[2]\n"
  565. "smlal2 v29.4s, v10.8h, v6.h[2]\n"
  566. "smlal2 v31.4s, v10.8h, v7.h[2]\n"
  567. "smlal v28.4s, v11.4h, v6.h[3]\n"
  568. "smlal v30.4s, v11.4h, v7.h[3]\n"
  569. "smlal2 v29.4s, v11.8h, v6.h[3]\n"
  570. "smlal2 v31.4s, v11.8h, v7.h[3]\n"
  571. "smlal v28.4s, v12.4h, v6.h[4]\n"
  572. "smlal v30.4s, v12.4h, v7.h[4]\n"
  573. "smlal2 v29.4s, v12.8h, v6.h[4]\n"
  574. "smlal2 v31.4s, v12.8h, v7.h[4]\n"
  575. "smlal v28.4s, v13.4h, v6.h[5]\n"
  576. "smlal v30.4s, v13.4h, v7.h[5]\n"
  577. "smlal2 v29.4s, v13.8h, v6.h[5]\n"
  578. "smlal2 v31.4s, v13.8h, v7.h[5]\n"
  579. "subs %w[K], %w[K], #8\n"
  580. "bne 1b\n"
  // Epilogue: v16-v23 are already complete and are stored while the deferred
  // products for v24-v31 are being finished; then the rest is stored.
  581. "2:\n"
  582. "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output]], 64\n"
  583. "smlal v24.4s, v15.4h, v4.h[7]\n"
  584. "smlal v28.4s, v14.4h, v6.h[6]\n"
  585. "smlal v30.4s, v14.4h, v7.h[6]\n"
  586. "smlal v26.4s, v15.4h, v5.h[7]\n"
  587. "smlal2 v25.4s, v15.8h, v4.h[7]\n"
  588. "smlal2 v27.4s, v15.8h, v5.h[7]\n"
  589. "smlal2 v29.4s, v14.8h, v6.h[6]\n"
  590. "smlal2 v31.4s, v14.8h, v7.h[6]\n"
  591. "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output]], 64\n"
  592. "smlal v28.4s, v15.4h, v6.h[7]\n"
  593. "smlal v30.4s, v15.4h, v7.h[7]\n"
  594. "smlal2 v29.4s, v15.8h, v6.h[7]\n"
  595. "smlal2 v31.4s, v15.8h, v7.h[7]\n"
  596. "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n"
  597. "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output]], 64\n"
  // All five operands are read-write: the asm post-increments the pointers and
  // counts K down in place.
  598. : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
  599. [output] "+r"(output), [LDB] "+r"(LDB)
  600. :
  601. : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
  602. "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
  603. "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
  604. "cc", "memory");
  605. }
  606. } // anonymous namespace
  607. MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_s16_8x8);
  608. void gemm_nopack_s16_8x8::kern(
  609. const dt_int16* A, size_t LDA, const dt_int16* B, size_t LDB, dt_int32* C,
  610. size_t LDC, size_t M, size_t K, size_t N, const dt_int32*, void*, bool trA,
  611. bool trB) const {
  612. constexpr static size_t MB = 8;
  613. constexpr static size_t KB = 8;
  614. constexpr static size_t NB = 8;
  615. constexpr static size_t CALCBLK = 4;
  616. megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0);
  617. //! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8)
  618. for (size_t m = 0; m < M; m += MB) {
  619. dt_int32* output = C + (m / MB) * LDC;
  620. const dt_int16* cur_B = B;
  621. size_t n = 0;
  622. for (; n + NB - 1 < N; n += NB) {
  623. kern_8x8(A, cur_B, LDB, K, output);
  624. cur_B += KB * NB;
  625. output += MB * NB;
  626. }
  627. if (N - n >= 4) {
  628. kern_8x4(A, cur_B, LDB, K, output);
  629. cur_B += KB * CALCBLK;
  630. output += MB * CALCBLK;
  631. n += 4;
  632. }
  633. while (n < N) {
  634. kern_8x1(A, cur_B, LDB, K, output);
  635. cur_B += KB;
  636. output += MB;
  637. n++;
  638. }
  639. A += LDA;
  640. }
  641. }
  642. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台