You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel_mk4_8x12.h 38 kB


/**
 * \file dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
  12. #pragma once
  13. #include "src/aarch64/matrix_mul/asm/common.h"
  14. #include "src/arm_common/simd_macro/marm_neon.h"
  15. namespace megdnn {
  16. namespace aarch64 {
  17. struct matmul_mk4_8x12 {
  18. // Overview of register layout:
  19. //
  20. // A 1x12 cell of Rhs is stored in 32bit in v2-v7
  21. // A 8x1 cell of Lhs is stored in 32bit in (v0-v1)
  22. // A 8x12 block of accumulators is stored in 32bit in v8-v31.
  23. //
  24. // +--------+--------+--------+
  25. // | v2[0-3]| v3[0-3]| v4[0-3]|
  26. // | v5[0-3]| v6[0-3]| v7[0-3]|
  27. // Rhs +--------+--------+--------+
  28. //
  29. // | | | |
  30. //
  31. // Lhs | | | |
  32. //
  33. // +--+ --- - +--------+--------+--------+
  34. // |v0| | v8[0-3]| v9[0-3]|v10[0-3]|
  35. // |v0| |v11[0-3]|v12[0-3]|v13[0-3]|
  36. // |v0| |v14[0-3]|v15[0-3]|v16[0-3]|
  37. // |v0| |v17[0-3]|v18[0-3]|v19[0-3]|
  38. // |v1| |v20[0-3]|v21[0-3]|v22[0-3]|
  39. // |v1| |v23[0-3]|v24[0-3]|v25[0-3]|
  40. // |v1| |v26[0-3]|v27[0-3]|v28[0-3]|
  41. // |v1| |v29[0-3]|v30[0-3]|v31[0-3]|
  42. // +--+ --- - +--------+--------+--------+
  43. //
  44. // Accumulator
  45. static void kern_8x12(const float* packA, const float* packB, int K,
  46. float* output, int LDC, bool is_first_k) {
  47. const float* a_ptr = packA;
  48. const float* b_ptr = packB;
  49. float* output0 = output;
  50. float* output1 = output0 + LDC;
  51. int oddk = (K & 1);
  52. K = ((K + 1) / 2) - 1;
  53. asm volatile(
  54. "cmp %w[is_first_k], #1\n"
  55. "beq 1f\n"
  56. "mov x1, %[output0]\n"
  57. "mov x2, %[output1]\n"
  58. "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64\n"
  59. "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64\n"
  60. "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64\n"
  61. "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64\n"
  62. "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64\n"
  63. "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64\n"
  64. "ld1 {v0.4s}, [%[a_ptr]], #16\n"
  65. "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n"
  66. "b 2f\n"
  67. "1:\n"
  68. "eor v8.16b, v8.16b, v8.16b\n"
  69. "eor v9.16b, v9.16b, v9.16b\n"
  70. "eor v10.16b, v10.16b, v10.16b\n"
  71. "prfm pstl1keep, [%[output0]]\n"
  72. "eor v11.16b, v11.16b, v11.16b\n"
  73. "eor v12.16b, v12.16b, v12.16b\n"
  74. "eor v13.16b, v13.16b, v13.16b\n"
  75. "prfm pstl1keep, [%[output1]]\n"
  76. "eor v14.16b, v14.16b, v14.16b\n"
  77. "eor v15.16b, v15.16b, v15.16b\n"
  78. "ld1 {v2.4s}, [%[b_ptr]], #16\n"
  79. "eor v16.16b, v16.16b, v16.16b\n"
  80. "ld1 {v3.4s}, [%[b_ptr]], #16\n"
  81. "eor v17.16b, v17.16b, v17.16b\n"
  82. "ld1 {v4.4s}, [%[b_ptr]], #16\n"
  83. "eor v18.16b, v18.16b, v18.16b\n"
  84. "eor v19.16b, v19.16b, v19.16b\n"
  85. "eor v20.16b, v20.16b, v20.16b\n"
  86. "ld1 {v0.4s}, [%[a_ptr]], #16\n"
  87. "eor v21.16b, v21.16b, v21.16b\n"
  88. "eor v22.16b, v22.16b, v22.16b\n"
  89. "eor v23.16b, v23.16b, v23.16b\n"
  90. "eor v24.16b, v24.16b, v24.16b\n"
  91. "eor v25.16b, v25.16b, v25.16b\n"
  92. "eor v26.16b, v26.16b, v26.16b\n"
  93. "eor v27.16b, v27.16b, v27.16b\n"
  94. "eor v28.16b, v28.16b, v28.16b\n"
  95. "eor v29.16b, v29.16b, v29.16b\n"
  96. "eor v30.16b, v30.16b, v30.16b\n"
  97. "eor v31.16b, v31.16b, v31.16b\n"
  98. "2: \n"
  99. "cmp %w[K], #0\n"
  100. "beq 4f\n"
  101. "3:\n"
  102. "fmla v8.4s, v0.4s, v2.s[0]\n"
  103. "fmla v9.4s, v0.4s, v2.s[1]\n"
  104. "ld1 {v1.4s}, [%[a_ptr]], 16\n"
  105. "fmla v10.4s, v0.4s, v2.s[2]\n"
  106. "fmla v11.4s, v0.4s, v2.s[3]\n"
  107. "ld1 {v5.4s}, [%[b_ptr]], #16\n"
  108. "fmla v12.4s, v0.4s, v3.s[0]\n"
  109. "fmla v13.4s, v0.4s, v3.s[1]\n"
  110. "ld1 {v6.4s}, [%[b_ptr]], #16\n"
  111. "fmla v14.4s, v0.4s, v3.s[2]\n"
  112. "fmla v15.4s, v0.4s, v3.s[3]\n"
  113. "ld1 {v7.4s}, [%[b_ptr]], #16\n"
  114. "fmla v16.4s, v0.4s, v4.s[0]\n"
  115. "fmla v17.4s, v0.4s, v4.s[1]\n"
  116. "fmla v18.4s, v0.4s, v4.s[2]\n"
  117. "fmla v19.4s, v0.4s, v4.s[3]\n"
  118. "ld1 {v0.4s}, [%[a_ptr]], 16\n"
  119. "fmla v20.4s, v1.4s, v2.s[0]\n"
  120. "fmla v21.4s, v1.4s, v2.s[1]\n"
  121. "fmla v22.4s, v1.4s, v2.s[2]\n"
  122. "fmla v23.4s, v1.4s, v2.s[3]\n"
  123. "fmla v24.4s, v1.4s, v3.s[0]\n"
  124. "fmla v25.4s, v1.4s, v3.s[1]\n"
  125. "fmla v26.4s, v1.4s, v3.s[2]\n"
  126. "fmla v27.4s, v1.4s, v3.s[3]\n"
  127. "fmla v28.4s, v1.4s, v4.s[0]\n"
  128. "fmla v29.4s, v1.4s, v4.s[1]\n"
  129. "fmla v30.4s, v1.4s, v4.s[2]\n"
  130. "fmla v31.4s, v1.4s, v4.s[3]\n"
  131. "fmla v8.4s, v0.4s, v5.s[0]\n"
  132. "fmla v9.4s, v0.4s, v5.s[1]\n"
  133. "ld1 {v1.4s}, [%[a_ptr]], 16\n"
  134. "fmla v10.4s, v0.4s, v5.s[2]\n"
  135. "fmla v11.4s, v0.4s, v5.s[3]\n"
  136. "ld1 {v2.4s}, [%[b_ptr]], 16\n"
  137. "fmla v12.4s, v0.4s, v6.s[0]\n"
  138. "fmla v13.4s, v0.4s, v6.s[1]\n"
  139. "ld1 {v3.4s}, [%[b_ptr]], 16\n"
  140. "fmla v14.4s, v0.4s, v6.s[2]\n"
  141. "fmla v15.4s, v0.4s, v6.s[3]\n"
  142. "ld1 {v4.4s}, [%[b_ptr]], 16\n"
  143. "fmla v16.4s, v0.4s, v7.s[0]\n"
  144. "fmla v17.4s, v0.4s, v7.s[1]\n"
  145. "fmla v18.4s, v0.4s, v7.s[2]\n"
  146. "fmla v19.4s, v0.4s, v7.s[3]\n"
  147. "ld1 {v0.4s}, [%[a_ptr]], 16\n"
  148. "fmla v20.4s, v1.4s, v5.s[0]\n"
  149. "fmla v21.4s, v1.4s, v5.s[1]\n"
  150. "fmla v22.4s, v1.4s, v5.s[2]\n"
  151. "fmla v23.4s, v1.4s, v5.s[3]\n"
  152. "fmla v24.4s, v1.4s, v6.s[0]\n"
  153. "subs %w[K], %w[K], #1\n"
  154. "fmla v25.4s, v1.4s, v6.s[1]\n"
  155. "fmla v26.4s, v1.4s, v6.s[2]\n"
  156. "fmla v27.4s, v1.4s, v6.s[3]\n"
  157. "fmla v28.4s, v1.4s, v7.s[0]\n"
  158. "fmla v29.4s, v1.4s, v7.s[1]\n"
  159. "fmla v30.4s, v1.4s, v7.s[2]\n"
  160. "fmla v31.4s, v1.4s, v7.s[3]\n"
  161. "bne 3b\n"
  162. "4:\n"
  163. "cmp %w[oddk], #1\n"
  164. "beq 5f\n"
  165. // Even tail
  166. "fmla v8.4s, v0.4s, v2.s[0]\n"
  167. "fmla v9.4s, v0.4s, v2.s[1]\n"
  168. "ld1 {v1.4s}, [%[a_ptr]], 16\n"
  169. "fmla v10.4s, v0.4s, v2.s[2]\n"
  170. "fmla v11.4s, v0.4s, v2.s[3]\n"
  171. "fmla v12.4s, v0.4s, v3.s[0]\n"
  172. "fmla v13.4s, v0.4s, v3.s[1]\n"
  173. "fmla v14.4s, v0.4s, v3.s[2]\n"
  174. "fmla v15.4s, v0.4s, v3.s[3]\n"
  175. "fmla v16.4s, v0.4s, v4.s[0]\n"
  176. "fmla v17.4s, v0.4s, v4.s[1]\n"
  177. "fmla v18.4s, v0.4s, v4.s[2]\n"
  178. "fmla v19.4s, v0.4s, v4.s[3]\n"
  179. "fmla v20.4s, v1.4s, v2.s[0]\n"
  180. "ld1 {v5.4s}, [%[b_ptr]], #16\n"
  181. "fmla v21.4s, v1.4s, v2.s[1]\n"
  182. "fmla v22.4s, v1.4s, v2.s[2]\n"
  183. "ld1 {v6.4s}, [%[b_ptr]], #16\n"
  184. "fmla v23.4s, v1.4s, v2.s[3]\n"
  185. "fmla v24.4s, v1.4s, v3.s[0]\n"
  186. "ld1 {v7.4s}, [%[b_ptr]], #16\n"
  187. "fmla v25.4s, v1.4s, v3.s[1]\n"
  188. "ld1 {v0.4s}, [%[a_ptr]], 16\n"
  189. "fmla v26.4s, v1.4s, v3.s[2]\n"
  190. "fmla v27.4s, v1.4s, v3.s[3]\n"
  191. "fmla v28.4s, v1.4s, v4.s[0]\n"
  192. "fmla v29.4s, v1.4s, v4.s[1]\n"
  193. "fmla v30.4s, v1.4s, v4.s[2]\n"
  194. "fmla v31.4s, v1.4s, v4.s[3]\n"
  195. "fmla v8.4s, v0.4s, v5.s[0]\n"
  196. "fmla v9.4s, v0.4s, v5.s[1]\n"
  197. "fmla v10.4s, v0.4s, v5.s[2]\n"
  198. "fmla v11.4s, v0.4s, v5.s[3]\n"
  199. "ld1 {v1.4s}, [%[a_ptr]], 16\n"
  200. "fmla v12.4s, v0.4s, v6.s[0]\n"
  201. "fmla v13.4s, v0.4s, v6.s[1]\n"
  202. "fmla v14.4s, v0.4s, v6.s[2]\n"
  203. "fmla v15.4s, v0.4s, v6.s[3]\n"
  204. "st1 {v8.4s}, [%[output0]], #16\n"
  205. "fmla v16.4s, v0.4s, v7.s[0]\n"
  206. "st1 {v9.4s}, [%[output0]], #16\n"
  207. "fmla v17.4s, v0.4s, v7.s[1]\n"
  208. "st1 {v10.4s}, [%[output0]], #16\n"
  209. "fmla v18.4s, v0.4s, v7.s[2]\n"
  210. "st1 {v11.4s}, [%[output0]], #16\n"
  211. "fmla v19.4s, v0.4s, v7.s[3]\n"
  212. "st1 {v12.4s}, [%[output0]], #16\n"
  213. "fmla v20.4s, v1.4s, v5.s[0]\n"
  214. "st1 {v13.4s}, [%[output0]], #16\n"
  215. "fmla v21.4s, v1.4s, v5.s[1]\n"
  216. "st1 {v14.4s}, [%[output0]], #16\n"
  217. "fmla v22.4s, v1.4s, v5.s[2]\n"
  218. "st1 {v15.4s}, [%[output0]], #16\n"
  219. "fmla v23.4s, v1.4s, v5.s[3]\n"
  220. "st1 {v16.4s}, [%[output0]], #16\n"
  221. "fmla v24.4s, v1.4s, v6.s[0]\n"
  222. "st1 {v17.4s}, [%[output0]], #16\n"
  223. "fmla v25.4s, v1.4s, v6.s[1]\n"
  224. "st1 {v18.4s}, [%[output0]], #16\n"
  225. "fmla v26.4s, v1.4s, v6.s[2]\n"
  226. "st1 {v19.4s}, [%[output0]], #16\n"
  227. "fmla v27.4s, v1.4s, v6.s[3]\n"
  228. "st1 {v20.4s}, [%[output1]], #16\n"
  229. "fmla v28.4s, v1.4s, v7.s[0]\n"
  230. "st1 {v21.4s}, [%[output1]], #16\n"
  231. "fmla v29.4s, v1.4s, v7.s[1]\n"
  232. "st1 {v22.4s}, [%[output1]], #16\n"
  233. "fmla v30.4s, v1.4s, v7.s[2]\n"
  234. "st1 {v23.4s}, [%[output1]], #16\n"
  235. "fmla v31.4s, v1.4s, v7.s[3]\n"
  236. "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output1]], #64\n"
  237. "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output1]], #64\n"
  238. "b 6f\n"
  239. // odd tail
  240. "5:\n"
  241. "fmla v8.4s, v0.4s, v2.s[0]\n"
  242. "fmla v9.4s, v0.4s, v2.s[1]\n"
  243. "fmla v10.4s, v0.4s, v2.s[2]\n"
  244. "ld1 {v1.4s}, [%[a_ptr]], 16\n"
  245. "fmla v11.4s, v0.4s, v2.s[3]\n"
  246. "fmla v12.4s, v0.4s, v3.s[0]\n"
  247. "fmla v13.4s, v0.4s, v3.s[1]\n"
  248. "fmla v14.4s, v0.4s, v3.s[2]\n"
  249. "st1 {v8.4s}, [%[output0]], #16\n"
  250. "fmla v15.4s, v0.4s, v3.s[3]\n"
  251. "st1 {v9.4s}, [%[output0]], #16\n"
  252. "fmla v16.4s, v0.4s, v4.s[0]\n"
  253. "st1 {v10.4s}, [%[output0]], #16\n"
  254. "fmla v17.4s, v0.4s, v4.s[1]\n"
  255. "st1 {v11.4s}, [%[output0]], #16\n"
  256. "fmla v18.4s, v0.4s, v4.s[2]\n"
  257. "st1 {v12.4s}, [%[output0]], #16\n"
  258. "fmla v19.4s, v0.4s, v4.s[3]\n"
  259. "st1 {v13.4s}, [%[output0]], #16\n"
  260. "fmla v20.4s, v1.4s, v2.s[0]\n"
  261. "st1 {v14.4s}, [%[output0]], #16\n"
  262. "fmla v21.4s, v1.4s, v2.s[1]\n"
  263. "st1 {v15.4s}, [%[output0]], #16\n"
  264. "fmla v22.4s, v1.4s, v2.s[2]\n"
  265. "st1 {v16.4s}, [%[output0]], #16\n"
  266. "fmla v23.4s, v1.4s, v2.s[3]\n"
  267. "st1 {v17.4s}, [%[output0]], #16\n"
  268. "fmla v24.4s, v1.4s, v3.s[0]\n"
  269. "st1 {v18.4s}, [%[output0]], #16\n"
  270. "fmla v25.4s, v1.4s, v3.s[1]\n"
  271. "st1 {v19.4s}, [%[output0]], #16\n"
  272. "fmla v26.4s, v1.4s, v3.s[2]\n"
  273. "st1 {v20.4s}, [%[output1]], #16\n"
  274. "fmla v27.4s, v1.4s, v3.s[3]\n"
  275. "st1 {v21.4s}, [%[output1]], #16\n"
  276. "fmla v28.4s, v1.4s, v4.s[0]\n"
  277. "st1 {v22.4s}, [%[output1]], #16\n"
  278. "fmla v29.4s, v1.4s, v4.s[1]\n"
  279. "st1 {v23.4s}, [%[output1]], #16\n"
  280. "fmla v30.4s, v1.4s, v4.s[2]\n"
  281. "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output1]], #64\n"
  282. "fmla v31.4s, v1.4s, v4.s[3]\n"
  283. "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output1]], #64\n"
  284. "6:\n"
  285. : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
  286. [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
  287. [output0] "+r"(output0), [output1] "+r"(output1)
  288. :
  289. : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
  290. "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
  291. "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
  292. "v28", "v29", "v30", "v31", "x1", "x2", "cc", "memory");
  293. }
  294. // Overview of register layout:
  295. //
  296. // A 1x12 cell of Rhs is stored in 32bit in v2-v7
  297. // A 8x1 cell of Lhs is stored in 32bit in (v0-v1)
  298. // A 8x12 block of accumulators is stored in 32bit in v8-v31.
  299. //
  300. // +--------+
  301. // | v2[0-3]|
  302. // | v3[0-3]|
  303. // Rhs +--------+
  304. //
  305. // | |
  306. //
  307. // Lhs | |
  308. //
  309. // +--+ --- - +--------+
  310. // |v0| | v8[0-3]|
  311. // |v0| |v11[0-3]|
  312. // |v0| |v14[0-3]|
  313. // |v0| |v17[0-3]|
  314. // |v1| |v20[0-3]|
  315. // |v1| |v23[0-3]|
  316. // |v1| |v26[0-3]|
  317. // |v1| |v29[0-3]|
  318. // +--+ --- - +--------+
  319. //
  320. // Accumulator
  321. static void kern_8x4(const float* packA, const float* packB, int K,
  322. float* output, int LDC, bool is_first_k,
  323. int n_remain) {
  324. const float* a_ptr = packA;
  325. const float* b_ptr = packB;
  326. float* output0 = output;
  327. float* output1 = output0 + LDC;
  328. int oddk = (K & 1);
  329. K = ((K + 1) / 2) - 1;
  330. //clang-format off
  331. #define LOAD_C \
  332. "cmp %w[n_remain], #4\n" \
  333. "blt 11f\n" \
  334. "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \
  335. "ld1 {v12.4s, v13.4s, v14.4s, v15.4s},[%[output1]]\n" \
  336. "b 14f\n" \
  337. "11:\n" \
  338. "cmp %w[n_remain], #3\n" \
  339. "blt 12f\n" \
  340. "ld1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \
  341. "ld1 {v12.4s, v13.4s, v14.4s},[%[output1]]\n" \
  342. "b 14f\n" \
  343. "12:\n" \
  344. "cmp %w[n_remain], #2\n" \
  345. "blt 13f\n" \
  346. "ld1 {v8.4s, v9.4s}, [%[output0]]\n" \
  347. "ld1 {v12.4s, v13.4s},[%[output1]]\n" \
  348. "b 14f\n" \
  349. "13:\n" \
  350. "ld1 {v8.4s}, [%[output0]]\n" \
  351. "ld1 {v12.4s},[%[output1]]\n" \
  352. "14:\n"
  353. #define STORE_C \
  354. "cmp %w[n_remain], #4\n" \
  355. "blt 21f\n" \
  356. "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \
  357. "st1 {v12.4s, v13.4s, v14.4s, v15.4s},[%[output1]]\n" \
  358. "b 24f\n" \
  359. "21:\n" \
  360. "cmp %w[n_remain], #3\n" \
  361. "blt 22f\n" \
  362. "st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \
  363. "st1 {v12.4s, v13.4s, v14.4s},[%[output1]]\n" \
  364. "b 23f\n" \
  365. "22:\n" \
  366. "cmp %w[n_remain], #2\n" \
  367. "blt 23f\n" \
  368. "st1 {v8.4s, v9.4s}, [%[output0]]\n" \
  369. "st1 {v12.4s, v13.4s},[%[output1]]\n" \
  370. "b 24f\n" \
  371. "23:\n" \
  372. "st1 {v8.4s}, [%[output0]]\n" \
  373. "st1 {v12.4s},[%[output1]]\n" \
  374. "24:\n"
  375. //clang-format on
  376. asm volatile(
  377. // load accumulator C
  378. "cmp %w[is_first_k], #1\n"
  379. "beq 1f\n" LOAD_C
  380. "ld1 {v0.4s}, [%[a_ptr]], #16\n"
  381. "ld1 {v2.4s}, [%[b_ptr]], #16\n"
  382. "b 2f\n"
  383. "1:\n"
  384. "eor v8.16b, v8.16b, v8.16b\n"
  385. "ld1 {v0.4s}, [%[a_ptr]], #16\n"
  386. "eor v9.16b, v9.16b, v9.16b\n"
  387. "eor v10.16b, v10.16b, v10.16b\n"
  388. "prfm pstl1keep, [%[output0]]\n"
  389. "eor v11.16b, v11.16b, v11.16b\n"
  390. "eor v12.16b, v12.16b, v12.16b\n"
  391. "prfm pstl1keep, [%[output1]]\n"
  392. "eor v13.16b, v13.16b, v13.16b\n"
  393. "eor v14.16b, v14.16b, v14.16b\n"
  394. "eor v15.16b, v15.16b, v15.16b\n"
  395. "ld1 {v2.4s}, [%[b_ptr]], #16\n"
  396. "2: \n"
  397. "cmp %w[K], #0\n"
  398. "beq 4f\n"
  399. "3:\n"
  400. "fmla v8.4s, v0.4s, v2.s[0]\n"
  401. "ld1 {v1.4s}, [%[a_ptr]], #16\n"
  402. "fmla v9.4s, v0.4s, v2.s[1]\n"
  403. "fmla v10.4s, v0.4s, v2.s[2]\n"
  404. "ld1 {v3.4s}, [%[b_ptr]], #16\n"
  405. "fmla v11.4s, v0.4s, v2.s[3]\n"
  406. "fmla v12.4s, v1.4s, v2.s[0]\n"
  407. "ld1 {v0.4s}, [%[a_ptr]], #16\n"
  408. "fmla v13.4s, v1.4s, v2.s[1]\n"
  409. "fmla v14.4s, v1.4s, v2.s[2]\n"
  410. "fmla v15.4s, v1.4s, v2.s[3]\n"
  411. "fmla v8.4s, v0.4s, v3.s[0]\n"
  412. "ld1 {v1.4s}, [%[a_ptr]], #16\n"
  413. "fmla v9.4s, v0.4s, v3.s[1]\n"
  414. "fmla v10.4s, v0.4s, v3.s[2]\n"
  415. "fmla v11.4s, v0.4s, v3.s[3]\n"
  416. "ld1 {v2.4s}, [%[b_ptr]], #16\n"
  417. "fmla v12.4s, v1.4s, v3.s[0]\n"
  418. "subs %w[K], %w[K], #1\n"
  419. "fmla v13.4s, v1.4s, v3.s[1]\n"
  420. "ld1 {v0.4s}, [%[a_ptr]], #16\n"
  421. "fmla v14.4s, v1.4s, v3.s[2]\n"
  422. "fmla v15.4s, v1.4s, v3.s[3]\n"
  423. "bne 3b\n"
  424. "4:\n"
  425. "cmp %w[oddk], #1\n"
  426. "beq 5f\n"
  427. // Even tail
  428. "fmla v8.4s, v0.4s, v2.s[0]\n"
  429. "ld1 {v1.4s}, [%[a_ptr]], #16\n"
  430. "fmla v9.4s, v0.4s, v2.s[1]\n"
  431. "fmla v10.4s, v0.4s, v2.s[2]\n"
  432. "ld1 {v3.4s}, [%[b_ptr]], #16\n"
  433. "fmla v11.4s, v0.4s, v2.s[3]\n"
  434. "fmla v12.4s, v1.4s, v2.s[0]\n"
  435. "ld1 {v0.4s}, [%[a_ptr]], #16\n"
  436. "fmla v13.4s, v1.4s, v2.s[1]\n"
  437. "fmla v14.4s, v1.4s, v2.s[2]\n"
  438. "fmla v15.4s, v1.4s, v2.s[3]\n"
  439. "fmla v8.4s, v0.4s, v3.s[0]\n"
  440. "ld1 {v1.4s}, [%[a_ptr]], #16\n"
  441. "fmla v9.4s, v0.4s, v3.s[1]\n"
  442. "fmla v10.4s, v0.4s, v3.s[2]\n"
  443. "fmla v11.4s, v0.4s, v3.s[3]\n"
  444. "fmla v12.4s, v1.4s, v3.s[0]\n"
  445. "fmla v13.4s, v1.4s, v3.s[1]\n"
  446. "fmla v14.4s, v1.4s, v3.s[2]\n"
  447. "fmla v15.4s, v1.4s, v3.s[3]\n"
  448. "b 6f\n"
  449. // odd tail
  450. "5:\n"
  451. "fmla v8.4s, v0.4s, v2.s[0]\n"
  452. "ld1 {v1.4s}, [%[a_ptr]], #16\n"
  453. "fmla v9.4s, v0.4s, v2.s[1]\n"
  454. "fmla v10.4s, v0.4s, v2.s[2]\n"
  455. "fmla v11.4s, v0.4s, v2.s[3]\n"
  456. "fmla v12.4s, v1.4s, v2.s[0]\n"
  457. "fmla v13.4s, v1.4s, v2.s[1]\n"
  458. "fmla v14.4s, v1.4s, v2.s[2]\n"
  459. "fmla v15.4s, v1.4s, v2.s[3]\n"
  460. "6:\n" STORE_C
  461. : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
  462. [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
  463. [output0] "+r"(output0), [output1] "+r"(output1),
  464. [n_remain] "+r"(n_remain)
  465. :
  466. : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12",
  467. "v13", "v14", "v15", "cc", "memory");
  468. #undef LOAD_C
  469. #undef STORE_C
  470. }
  471. // Overview of register layout:
  472. //
  473. // A 1x12 cell of Rhs is stored in 32bit in v2-v7
  474. // A 8x1 cell of Lhs is stored in 32bit in (v0-v1)
  475. // A 8x12 block of accumulators is stored in 32bit in v8-v31.
  476. //
  477. // +--------+--------+--------+
  478. // | v2[0-3]| v3[0-3]| v4[0-3]|
  479. // | v5[0-3]| v6[0-3]| v7[0-3]|
  480. // Rhs +--------+--------+--------+
  481. //
  482. // | | | |
  483. //
  484. // Lhs | | | |
  485. //
  486. // +--+ --- - +--------+--------+--------+
  487. // |v0| | v8[0-3]| v9[0-3]|v10[0-3]|
  488. // |v0| |v11[0-3]|v12[0-3]|v13[0-3]|
  489. // |v0| |v14[0-3]|v15[0-3]|v16[0-3]|
  490. // |v0| |v17[0-3]|v18[0-3]|v19[0-3]|
  491. // +--+ --- - +--------+--------+--------+
  492. //
  493. // Accumulator
  494. static void kern_4x12(const float* packA, const float* packB, int K,
  495. float* output, int LDC, bool is_first_k) {
  496. MEGDNN_MARK_USED_VAR(LDC);
  497. const float* a_ptr = packA;
  498. const float* b_ptr = packB;
  499. float* output0 = output;
  500. int oddk = (K & 1);
  501. K = ((K + 1) / 2) - 1;
  502. asm volatile(
  503. "cmp %w[is_first_k], #1\n"
  504. "beq 1f\n"
  505. "mov x1, %[output0]\n"
  506. "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64\n"
  507. "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64\n"
  508. "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64\n"
  509. "ld1 {v0.4s}, [%[a_ptr]], #16\n"
  510. "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n"
  511. "b 2f\n"
  512. "1:\n"
  513. "eor v8.16b, v8.16b, v8.16b\n"
  514. "eor v9.16b, v9.16b, v9.16b\n"
  515. "eor v10.16b, v10.16b, v10.16b\n"
  516. "prfm pstl1keep, [%[output0]]\n"
  517. "eor v11.16b, v11.16b, v11.16b\n"
  518. "eor v12.16b, v12.16b, v12.16b\n"
  519. "eor v13.16b, v13.16b, v13.16b\n"
  520. "eor v14.16b, v14.16b, v14.16b\n"
  521. "eor v15.16b, v15.16b, v15.16b\n"
  522. "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n"
  523. "eor v16.16b, v16.16b, v16.16b\n"
  524. "eor v17.16b, v17.16b, v17.16b\n"
  525. "ld1 {v0.4s}, [%[a_ptr]], #16\n"
  526. "eor v18.16b, v18.16b, v18.16b\n"
  527. "eor v19.16b, v19.16b, v19.16b\n"
  528. "2: \n"
  529. "cmp %w[K], #0\n"
  530. "beq 4f\n"
  531. "3:\n"
  532. "fmla v8.4s, v0.4s, v2.s[0]\n"
  533. "fmla v9.4s, v0.4s, v2.s[1]\n"
  534. "ld1 {v1.4s}, [%[a_ptr]], 16\n"
  535. "fmla v10.4s, v0.4s, v2.s[2]\n"
  536. "fmla v11.4s, v0.4s, v2.s[3]\n"
  537. "fmla v12.4s, v0.4s, v3.s[0]\n"
  538. "fmla v13.4s, v0.4s, v3.s[1]\n"
  539. "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48\n"
  540. "fmla v14.4s, v0.4s, v3.s[2]\n"
  541. "fmla v15.4s, v0.4s, v3.s[3]\n"
  542. "fmla v16.4s, v0.4s, v4.s[0]\n"
  543. "fmla v17.4s, v0.4s, v4.s[1]\n"
  544. "fmla v18.4s, v0.4s, v4.s[2]\n"
  545. "fmla v19.4s, v0.4s, v4.s[3]\n"
  546. "fmla v8.4s, v1.4s, v5.s[0]\n"
  547. "fmla v9.4s, v1.4s, v5.s[1]\n"
  548. "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n"
  549. "fmla v10.4s, v1.4s, v5.s[2]\n"
  550. "fmla v11.4s, v1.4s, v5.s[3]\n"
  551. "ld1 {v0.4s}, [%[a_ptr]], 16\n"
  552. "fmla v12.4s, v1.4s, v6.s[0]\n"
  553. "fmla v13.4s, v1.4s, v6.s[1]\n"
  554. "subs %w[K], %w[K], #1\n"
  555. "fmla v14.4s, v1.4s, v6.s[2]\n"
  556. "fmla v15.4s, v1.4s, v6.s[3]\n"
  557. "fmla v16.4s, v1.4s, v7.s[0]\n"
  558. "fmla v17.4s, v1.4s, v7.s[1]\n"
  559. "fmla v18.4s, v1.4s, v7.s[2]\n"
  560. "fmla v19.4s, v1.4s, v7.s[3]\n"
  561. "bne 3b\n"
  562. "4:\n"
  563. "cmp %w[oddk], #1\n"
  564. "beq 5f\n"
  565. // Even tail
  566. "fmla v8.4s, v0.4s, v2.s[0]\n"
  567. "fmla v9.4s, v0.4s, v2.s[1]\n"
  568. "ld1 {v1.4s}, [%[a_ptr]], 16\n"
  569. "fmla v10.4s, v0.4s, v2.s[2]\n"
  570. "fmla v11.4s, v0.4s, v2.s[3]\n"
  571. "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48\n"
  572. "fmla v12.4s, v0.4s, v3.s[0]\n"
  573. "fmla v13.4s, v0.4s, v3.s[1]\n"
  574. "fmla v14.4s, v0.4s, v3.s[2]\n"
  575. "fmla v15.4s, v0.4s, v3.s[3]\n"
  576. "fmla v16.4s, v0.4s, v4.s[0]\n"
  577. "fmla v17.4s, v0.4s, v4.s[1]\n"
  578. "fmla v18.4s, v0.4s, v4.s[2]\n"
  579. "fmla v19.4s, v0.4s, v4.s[3]\n"
  580. "fmla v8.4s, v1.4s, v5.s[0]\n"
  581. "fmla v9.4s, v1.4s, v5.s[1]\n"
  582. "fmla v10.4s, v1.4s, v5.s[2]\n"
  583. "fmla v11.4s, v1.4s, v5.s[3]\n"
  584. "ld1 {v0.4s}, [%[a_ptr]], 16\n"
  585. "fmla v12.4s, v1.4s, v6.s[0]\n"
  586. "fmla v13.4s, v1.4s, v6.s[1]\n"
  587. "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64\n"
  588. "fmla v14.4s, v1.4s, v6.s[2]\n"
  589. "fmla v15.4s, v1.4s, v6.s[3]\n"
  590. "fmla v16.4s, v1.4s, v7.s[0]\n"
  591. "fmla v17.4s, v1.4s, v7.s[1]\n"
  592. "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64\n"
  593. "fmla v18.4s, v1.4s, v7.s[2]\n"
  594. "fmla v19.4s, v1.4s, v7.s[3]\n"
  595. "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64\n"
  596. "b 6f\n"
  597. // odd tail
  598. "5:\n"
  599. "fmla v8.4s, v0.4s, v2.s[0]\n"
  600. "fmla v9.4s, v0.4s, v2.s[1]\n"
  601. "fmla v10.4s, v0.4s, v2.s[2]\n"
  602. "fmla v11.4s, v0.4s, v2.s[3]\n"
  603. "fmla v12.4s, v0.4s, v3.s[0]\n"
  604. "fmla v13.4s, v0.4s, v3.s[1]\n"
  605. "fmla v14.4s, v0.4s, v3.s[2]\n"
  606. "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64\n"
  607. "fmla v15.4s, v0.4s, v3.s[3]\n"
  608. "fmla v16.4s, v0.4s, v4.s[0]\n"
  609. "fmla v17.4s, v0.4s, v4.s[1]\n"
  610. "fmla v18.4s, v0.4s, v4.s[2]\n"
  611. "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64\n"
  612. "fmla v19.4s, v0.4s, v4.s[3]\n"
  613. "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64\n"
  614. "6:\n"
  615. : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
  616. [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
  617. [output0] "+r"(output0)
  618. :
  619. : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
  620. "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
  621. "v19", "x1", "cc", "memory");
  622. }
  623. // Overview of register layout:
  624. //
  625. // A 2x4 cell of Rhs is stored in 32bit in v2 - v3
  626. // A 4x2 cell of Lhs is stored in 32bit in v0 - v1
  627. // A 4x4 block of accumulators is stored in 32bit in v4-v6
  628. //
  629. // +--------+
  630. // | v2[0-3]|
  631. // | v5[0-3]|
  632. // Rhs +--------+
  633. //
  634. // | |
  635. //
  636. // Lhs | |
  637. //
  638. // +--+ --- - +--------+
  639. // |v0| | v8[0-3]|
  640. // |v0| |v11[0-3]|
  641. // |v0| |v14[0-3]|
  642. // |v0| |v17[0-3]|
  643. // +--+ --- - +--------+
  644. //
  645. // Accumulator
  646. static void kern_4x4(const float* packA, const float* packB, int K,
  647. float* output, int LDC, bool is_first_k,
  648. int n_remain) {
  649. MEGDNN_MARK_USED_VAR(LDC);
  650. const float* a_ptr = packA;
  651. const float* b_ptr = packB;
  652. float* output0 = output;
  653. int oddk = (K & 1);
  654. K = ((K + 1) / 2) - 1;
  655. //clang-format off
  656. #define LOAD_C \
  657. "cmp %w[n_remain], #4\n" \
  658. "blt 11f\n" \
  659. "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \
  660. "b 14f\n" \
  661. "11:\n" \
  662. "cmp %w[n_remain], #3\n" \
  663. "blt 12f\n" \
  664. "ld1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \
  665. "b 14f\n" \
  666. "12:\n" \
  667. "cmp %w[n_remain], #2\n" \
  668. "blt 13f\n" \
  669. "ld1 {v8.4s, v9.4s}, [%[output0]]\n" \
  670. "b 14f\n" \
  671. "13:\n" \
  672. "ld1 {v8.4s}, [%[output0]]\n" \
  673. "14:\n"
  674. #define STORE_C \
  675. "cmp %w[n_remain], #4\n" \
  676. "blt 21f\n" \
  677. "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \
  678. "b 24f\n" \
  679. "21:\n" \
  680. "cmp %w[n_remain], #3\n" \
  681. "blt 22f\n" \
  682. "st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \
  683. "b 24f\n" \
  684. "22:\n" \
  685. "cmp %w[n_remain], #2\n" \
  686. "blt 23f\n" \
  687. "st1 {v8.4s, v9.4s}, [%[output0]]\n" \
  688. "b 24f\n" \
  689. "23:\n" \
  690. "st1 {v8.4s}, [%[output0]]\n" \
  691. "24:\n"
  692. //clang-format on
  693. asm volatile(
  694. // load accumulator C
  695. "cmp %w[is_first_k], #1\n"
  696. "beq 1f\n" LOAD_C
  697. "ld1 {v0.4s}, [%[a_ptr]], #16\n"
  698. "ld1 {v2.4s}, [%[b_ptr]], #16\n"
  699. "b 2f\n"
  700. "1:\n"
  701. "eor v8.16b, v8.16b, v8.16b\n"
  702. "ld1 {v2.4s}, [%[b_ptr]], #16\n"
  703. "eor v9.16b, v9.16b, v9.16b\n"
  704. "ld1 {v0.4s}, [%[a_ptr]], #16\n"
  705. "eor v10.16b, v10.16b, v10.16b\n"
  706. "prfm pstl1keep, [%[output0]]\n"
  707. "eor v11.16b, v11.16b, v11.16b\n"
  708. "2: \n"
  709. "cmp %w[K], #0\n"
  710. "beq 4f\n"
  711. "3:\n"
  712. "fmla v8.4s, v0.4s, v2.s[0]\n"
  713. "ld1 {v1.4s}, [%[a_ptr]], 16\n"
  714. "fmla v9.4s, v0.4s, v2.s[1]\n"
  715. "fmla v10.4s, v0.4s, v2.s[2]\n"
  716. "ld1 {v3.4s}, [%[b_ptr]], 16\n"
  717. "fmla v11.4s, v0.4s, v2.s[3]\n"
  718. "fmla v8.4s, v1.4s, v3.s[0]\n"
  719. "fmla v9.4s, v1.4s, v3.s[1]\n"
  720. "ld1 {v0.4s}, [%[a_ptr]], 16\n"
  721. "fmla v10.4s, v1.4s, v3.s[2]\n"
  722. "fmla v11.4s, v1.4s, v3.s[3]\n"
  723. "ld1 {v2.4s}, [%[b_ptr]], 16\n"
  724. "subs %w[K], %w[K], #1\n"
  725. "bne 3b\n"
  726. "4:\n"
  727. "cmp %w[oddk], #1\n"
  728. "beq 5f\n"
  729. // Even tail
  730. "fmla v8.4s, v0.4s, v2.s[0]\n"
  731. "ld1 {v1.4s}, [%[a_ptr]], 16\n"
  732. "fmla v9.4s, v0.4s, v2.s[1]\n"
  733. "fmla v10.4s, v0.4s, v2.s[2]\n"
  734. "ld1 {v3.4s}, [%[b_ptr]], 16\n"
  735. "fmla v11.4s, v0.4s, v2.s[3]\n"
  736. "fmla v8.4s, v1.4s, v3.s[0]\n"
  737. "fmla v9.4s, v1.4s, v3.s[1]\n"
  738. "fmla v10.4s, v1.4s, v3.s[2]\n"
  739. "fmla v11.4s, v1.4s, v3.s[3]\n"
  740. "b 6f\n"
  741. // odd tail
  742. "5:\n"
  743. "fmla v8.4s, v0.4s, v2.s[0]\n"
  744. "fmla v9.4s, v0.4s, v2.s[1]\n"
  745. "fmla v10.4s, v0.4s, v2.s[2]\n"
  746. "fmla v11.4s, v0.4s, v2.s[3]\n"
  747. "6:\n" STORE_C
  748. : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
  749. [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
  750. [output0] "+r"(output0), [n_remain] "+r"(n_remain)
  751. :
  752. : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc",
  753. "memory");
  754. #undef LOAD_C
  755. #undef STORE_C
  756. }
  757. static void sgemm_8x12_pack_A(float* outptr, const float* inptr, int ldin,
  758. int y0, int ymax, int k0, int kmax) {
  759. megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4");
  760. megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
  761. constexpr int PACK_SIZE_32 = 4 * 8;
  762. constexpr int PACK_SIZE_16 = 4 * 4;
  763. constexpr int PACK_C_SIZE = 4;
  764. int y = y0;
  765. for (; y + 7 < ymax; y += 8) {
  766. const float* inptr0 = inptr + y / PACK_C_SIZE * ldin + k0;
  767. const float* inptr1 = inptr0 + ldin;
  768. prefetch_2x(inptr0);
  769. prefetch_2x(inptr1);
  770. int k = (kmax - k0);
  771. for (; k > 3; k -= 4) {
  772. interleave_2x4_4_s(inptr0, inptr1, outptr);
  773. outptr += PACK_SIZE_32;
  774. }
  775. }
  776. for (; y < ymax; y += 4) {
  777. const float* inptr0 = inptr + y / PACK_C_SIZE * ldin + k0;
  778. prefetch_2x(inptr0);
  779. int K = (kmax - k0);
  780. for (; K > 3; K -= 4) {
  781. interleave_1x4_4_s(inptr0, outptr);
  782. outptr += PACK_SIZE_16;
  783. }
  784. }
  785. }
  786. static void sgemm_8x12_pack_B(float* out, const float* in, int ldin, int x0,
  787. int xmax, int k0, int kmax) {
  788. megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
  789. float tmpbuff[16] = {0.0f};
  790. constexpr int PACK_C_SIZE = 4;
  791. int ksize = kmax - k0;
  792. int ksize12 = ksize * 12;
  793. int ksize4 = (ksize << 2);
  794. float* outptr_base = out;
  795. float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12;
  796. int k = k0;
  797. for (; k + 3 < kmax; k += 4) {
  798. const float* inptr = in + k / PACK_C_SIZE * ldin + x0 * PACK_C_SIZE;
  799. prefetch_3x(inptr);
  800. int x = x0;
  801. auto outptr = outptr_base;
  802. for (; x + 12 <= xmax; x += 12) {
  803. auto outptr_interleave = outptr;
  804. transpose_1x12_4_s(inptr, outptr_interleave);
  805. outptr += ksize12;
  806. }
  807. outptr = outptr_base4;
  808. for (; x + 4 <= xmax; x += 4) {
  809. auto outptr_interleave = outptr;
  810. transpose_1x4_4_s(inptr, outptr_interleave);
  811. outptr += ksize4;
  812. }
  813. if (x < xmax) {
  814. std::memcpy(tmpbuff, inptr,
  815. sizeof(float) * (xmax - x) * PACK_C_SIZE);
  816. auto outptr_interleave = outptr;
  817. const float* tmp_ptr = &tmpbuff[0];
  818. transpose_1x4_4_s<float>(tmp_ptr, outptr_interleave);
  819. outptr += ksize4;
  820. }
  821. outptr_base += 12 * 4;
  822. outptr_base4 += 4 * 4;
  823. }
  824. }
  825. };
  826. } // namespace aarch64
  827. } // namespace megdnn
  828. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。