You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

sbgemm_kernel_8x4_neoversen2_impl.c 15 kB


  1. /***************************************************************************
  2. * Copyright (c) 2022,2025 The OpenBLAS Project
  3. * All rights reserved.
  4. * Redistribution and use in source and binary forms, with or without
  5. * modification, are permitted provided that the following conditions are
  6. * met:
  7. * 1. Redistributions of source code must retain the above copyright
  8. * notice, this list of conditions and the following disclaimer.
  9. * 2. Redistributions in binary form must reproduce the above copyright
  10. * notice, this list of conditions and the following disclaimer in
  11. * the documentation and/or other materials provided with the
  12. * distribution.
  13. * 3. Neither the name of the OpenBLAS project nor the names of
  14. * its contributors may be used to endorse or promote products
  15. * derived from this software without specific prior written permission.
  16. * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  17. * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  18. * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  19. * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
  20. * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  21. * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  22. * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  23. * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  24. * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  25. * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  26. * POSSIBILITY OF SUCH DAMAGE.
  27. * *****************************************************************************/
  28. #include <arm_sve.h>
  29. #include <arm_neon.h>
  30. #include "common.h"
  /*
   * Accumulator helpers. They operate on variables named by token pasting:
   * mc<M><N> is an accumulator, ma<M> an A operand and mb<N> a B operand, all
   * of which must already be declared in the scope where the macro is used.
   *
   * None of the definitions ends in a semicolon: the caller supplies it, so
   * each macro behaves like a single statement (safe inside an unbraced
   * if/else).
   */

  /* Set accumulator mc<M><N> to all zeros. */
  #define INIT_C(M, N) mc##M##N = svdup_f32(0)

  /* mc<M><N> = svbfmmla(mc<M><N>, ma<M>, mb<N>): bf16 matrix multiply-accumulate. */
  #define MATMUL(M, N) mc##M##N = svbfmmla(mc##M##N, ma##M, mb##N)

  /* Zero the eight accumulators mc00..mc31 used by the 8x4 tile. */
  #define INIT_C_8x4 \
    do { \
      INIT_C(0, 0); \
      INIT_C(0, 1); \
      INIT_C(1, 0); \
      INIT_C(1, 1); \
      INIT_C(2, 0); \
      INIT_C(2, 1); \
      INIT_C(3, 0); \
      INIT_C(3, 1); \
    } while (0)
  /*
   * OUTPUT_FLOAT is the element type of C: bfloat16_t when BGEMM is defined,
   * float otherwise.
   *
   * UPDATE_C(PG16, PG32, PTR, SRC) is a read-modify-write of one vector of C:
   *     *PTR = *PTR + SRC            when ALPHA_ONE is defined
   *     *PTR = svalpha * SRC + *PTR  otherwise
   *
   *   PG16  predicate passed to the bf16 load/store (BGEMM only; the float
   *         variants never expand this argument)
   *   PG32  predicate for the f32 arithmetic (and the f32 load/store when C
   *         is float)
   *   PTR   address inside C (OUTPUT_FLOAT *)
   *   SRC   f32 vector holding the freshly computed products
   *
   * The macro uses the caller-scope temporaries tmp32 (and tmp16 under BGEMM)
   * and, when ALPHA_ONE is not defined, the caller-scope vector svalpha.
   *
   * Every variant expands to a single statement WITHOUT a trailing semicolon,
   * so `UPDATE_C(...);` is safe inside an unbraced if/else. PTR and SRC are
   * parenthesized so expressions such as `ptr_c0 + 4` can be passed safely.
   */
  #ifdef BGEMM
  #define OUTPUT_FLOAT bfloat16_t
  /*
   * bf16 C: load halfwords widened to 32-bit lanes, do the arithmetic in f32,
   * convert back to bf16, compact the even halfwords with svuzp1 and store.
   *
   * NOTE(review): svld1uh_u32 zero-extends each halfword into the LOW 16 bits
   * of its 32-bit lane and the result is reinterpreted as f32 without a
   * 16-bit left shift - confirm this is the intended bf16 -> f32 widening of
   * the existing C values.
   * NOTE(review): PG16 is built as a 16-bit-lane predicate but governs a
   * 32-bit-lane load here - confirm the active lanes are the intended ones.
   */
  #ifdef ALPHA_ONE
  #define UPDATE_C(PG16, PG32, PTR, SRC) \
    do { \
      tmp32 = svreinterpret_f32_u32(svld1uh_u32((PG16), (uint16_t *)(PTR))); \
      tmp32 = svadd_z((PG32), (SRC), tmp32); \
      tmp16 = svcvt_bf16_f32_z((PG32), tmp32); \
      tmp16 = svuzp1_bf16(tmp16, tmp16); \
      svst1_bf16((PG16), (PTR), tmp16); \
    } while (0)
  #else
  #define UPDATE_C(PG16, PG32, PTR, SRC) \
    do { \
      tmp32 = svreinterpret_f32_u32(svld1uh_u32((PG16), (uint16_t *)(PTR))); \
      tmp32 = svmad_z((PG32), svalpha, (SRC), tmp32); \
      tmp16 = svcvt_bf16_f32_z((PG32), tmp32); \
      tmp16 = svuzp1_bf16(tmp16, tmp16); \
      svst1_bf16((PG16), (PTR), tmp16); \
    } while (0)
  #endif
  #else
  #define OUTPUT_FLOAT float
  /* float C: plain predicated f32 load, add / multiply-add, store. */
  #ifdef ALPHA_ONE
  #define UPDATE_C(PG16, PG32, PTR, SRC) \
    do { \
      tmp32 = svld1_f32((PG32), (PTR)); \
      tmp32 = svadd_z((PG32), (SRC), tmp32); \
      svst1_f32((PG32), (PTR), tmp32); \
    } while (0)
  #else
  #define UPDATE_C(PG16, PG32, PTR, SRC) \
    do { \
      tmp32 = svld1_f32((PG32), (PTR)); \
      tmp32 = svmad_z((PG32), svalpha, (SRC), tmp32); \
      svst1_f32((PG32), (PTR), tmp32); \
    } while (0)
  #endif
  #endif
  /*
   * bf16 GEMM micro-kernel: for every element of an m x n tile of C it stores
   *     C = C + alpha * (A * B)      (plain C + A * B when ALPHA_ONE is defined)
   * through the UPDATE_C macro. The same body is compiled under two names,
   * selected by ALPHA_ONE.
   *
   *   m      extent of C along its contiguous direction; handled in blocks
   *          of 8, then 4, 2 and 1
   *   n      number of ldc-strided vectors of C; handled in blocks of 4,
   *          then 2 and 1
   *   k      inner dimension; rounded up to a multiple of 4 (pad_k) because
   *          every inner-loop step consumes 4 k-values
   *          (presumably the packing routines zero-pad to that length -
   *          TODO confirm against the packing code)
   *   alpha  scale factor; not referenced when ALPHA_ONE is defined
   *   A, B   packed input panels, read as bfloat16_t
   *   C      output, accessed as OUTPUT_FLOAT
   *   ldc    distance, in elements, between consecutive vectors of C
   *
   * Always returns 0.
   */
  86. #ifdef ALPHA_ONE
  87. static int gemm_kernel_neoversen2_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc)
  88. #else
  89. static int gemm_kernel_neoversen2_alpha(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc)
  90. #endif
  91. {
  /* k rounded up to the next multiple of 4. */
  92. BLASLONG pad_k = (k + 3) & ~3;
  /* ma*: A operands, mb*: B operands, mc<i><j>: raw accumulators produced by
   * svbfmmla, vc*: accumulators re-ordered (svuzp1/svuzp2) for storing. */
  93. svbfloat16_t ma0, ma1, ma2, ma3, mb0, mb1;
  94. svfloat32_t mc00, mc01, mc10, mc11, mc20, mc21, mc30, mc31,
  95. vc0, vc1, vc2, vc3, vc4, vc5, vc6, vc7;
  96. #ifndef ALPHA_ONE
  97. #ifdef BGEMM
  /* Under BGEMM the first sizeof(bfloat16_t) bytes of alpha are copied into a
   * bfloat16_t, widened to f32 and broadcast. */
  98. bfloat16_t alpha_bf16;
  99. memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t));
  100. svfloat32_t svalpha = svdup_f32(vcvtah_f32_bf16(alpha_bf16));
  101. #else
  102. svfloat32_t svalpha = svdup_f32(alpha);
  103. #endif
  104. #endif
  /* Fixed lane patterns: first 4/2/1 32-bit lanes and first 8/4/2/1 16-bit
   * lanes of the pattern given to svdupq. */
  105. svbool_t pg32_first_4 = svdupq_b32(1, 1, 1, 1);
  106. svbool_t pg32_first_2 = svdupq_b32(1, 1, 0, 0);
  107. svbool_t pg32_first_1 = svdupq_b32(1, 0, 0, 0);
  108. svbool_t pg16_first_8 = svdupq_b16(1, 1, 1, 1, 1, 1, 1, 1);
  109. svbool_t pg16_first_4 = svdupq_b16(1, 1, 1, 1, 0, 0, 0, 0);
  /* pg16_first_2 / pg16_first_1 exist only under BGEMM. The non-BGEMM UPDATE_C
   * never expands its PG16 argument, so the names passed below do not reach
   * the compiler in that configuration. */
  110. #ifdef BGEMM
  111. svbool_t pg16_first_2 = svdupq_b16(1, 1, 0, 0, 0, 0, 0, 0);
  112. svbool_t pg16_first_1 = svdupq_b16(1, 0, 0, 0, 0, 0, 0, 0);
  113. #endif
  /* ptr_a / ptr_b / ptr_c walk the panels block by block; the *0..*3 variants
   * are the cursors used inside a block. */
  114. bfloat16_t *ptr_a = (bfloat16_t *)A;
  115. bfloat16_t *ptr_b = (bfloat16_t *)B;
  116. OUTPUT_FLOAT *ptr_c = (OUTPUT_FLOAT*)C;
  117. bfloat16_t *ptr_a0;
  118. bfloat16_t *ptr_b0;
  119. OUTPUT_FLOAT *ptr_c0, *ptr_c1, *ptr_c2, *ptr_c3;
  /* Scratch registers written by UPDATE_C. */
  120. svfloat32_t tmp32;
  121. #ifdef BGEMM
  122. svbfloat16_t tmp16;
  123. #endif
  /* ---- n handled 4 vectors of C at a time; B advances 4 * pad_k per block ---- */
  124. for (BLASLONG j = 0; j < n / 4; j++) {
  125. ptr_c0 = ptr_c;
  126. ptr_c1 = ptr_c0 + ldc;
  127. ptr_c2 = ptr_c1 + ldc;
  128. ptr_c3 = ptr_c2 + ldc;
  129. ptr_c += 4 * ldc;
  /* Restart A from its beginning for every block of B. */
  130. ptr_a = (bfloat16_t *)A;
  /* 8 x 4 tiles: each k-step reads 32 bf16 values of A and 16 of B. */
  131. for (BLASLONG i = 0; i < m / 8; i++) {
  132. ptr_a0 = ptr_a;
  133. ptr_a += 8 * pad_k;
  134. ptr_b0 = ptr_b;
  135. INIT_C_8x4;
  136. for (BLASLONG p = 0; p < pad_k; p += 4) {
  137. ma0 = svld1_bf16(pg16_first_8, ptr_a0);
  138. ma1 = svld1_bf16(pg16_first_8, ptr_a0 + 8);
  139. ma2 = svld1_bf16(pg16_first_8, ptr_a0 + 16);
  140. ma3 = svld1_bf16(pg16_first_8, ptr_a0 + 24);
  141. mb0 = svld1_bf16(pg16_first_8, ptr_b0);
  142. mb1 = svld1_bf16(pg16_first_8, ptr_b0 + 8);
  143. MATMUL(0, 0); MATMUL(0, 1);
  144. MATMUL(1, 0); MATMUL(1, 1);
  145. MATMUL(2, 0); MATMUL(2, 1);
  146. MATMUL(3, 0); MATMUL(3, 1);
  147. ptr_a0 += 32;
  148. ptr_b0 += 16;
  149. }
  /* De-interleave the accumulators: svuzp1 gathers the even lanes and svuzp2
   * the odd lanes of a pair, giving one 4-element run per C vector
   * (vc0/vc1 -> ptr_c0, vc2/vc3 -> ptr_c1, vc4/vc5 -> ptr_c2, vc6/vc7 -> ptr_c3). */
  150. vc0 = svuzp1(mc00, mc10);
  151. vc1 = svuzp1(mc20, mc30);
  152. vc2 = svuzp2(mc00, mc10);
  153. vc3 = svuzp2(mc20, mc30);
  154. vc4 = svuzp1(mc01, mc11);
  155. vc5 = svuzp1(mc21, mc31);
  156. vc6 = svuzp2(mc01, mc11);
  157. vc7 = svuzp2(mc21, mc31);
  158. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, vc0);
  159. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0+4, vc1);
  160. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1, vc2);
  161. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1+4, vc3);
  162. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c2, vc4);
  163. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c2+4, vc5);
  164. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c3, vc6);
  165. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c3+4, vc7);
  166. ptr_c0 += 8;
  167. ptr_c1 += 8;
  168. ptr_c2 += 8;
  169. ptr_c3 += 8;
  170. }
  /* 4 x 4 remainder tile. */
  171. if (m & 4) {
  172. ptr_a0 = ptr_a;
  173. ptr_a += 4 * pad_k;
  174. ptr_b0 = ptr_b;
  175. INIT_C(0, 0); INIT_C(0, 1);
  176. INIT_C(1, 0); INIT_C(1, 1);
  177. for (BLASLONG p = 0; p < pad_k; p += 4) {
  178. ma0 = svld1_bf16(pg16_first_8, ptr_a0);
  179. ma1 = svld1_bf16(pg16_first_8, ptr_a0 + 8);
  180. mb0 = svld1_bf16(pg16_first_8, ptr_b0);
  181. mb1 = svld1_bf16(pg16_first_8, ptr_b0 + 8);
  182. MATMUL(0, 0); MATMUL(0, 1);
  183. MATMUL(1, 0); MATMUL(1, 1);
  184. ptr_a0 += 16;
  185. ptr_b0 += 16;
  186. }
  187. vc0 = svuzp1(mc00, mc10);
  188. vc1 = svuzp2(mc00, mc10);
  189. vc2 = svuzp1(mc01, mc11);
  190. vc3 = svuzp2(mc01, mc11);
  191. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, vc0);
  192. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1, vc1);
  193. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c2, vc2);
  194. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c3, vc3);
  195. ptr_c0 += 4;
  196. ptr_c1 += 4;
  197. ptr_c2 += 4;
  198. ptr_c3 += 4;
  199. }
  /* 2 x 4 remainder tile: each accumulator is unzipped with itself and only
   * the first two lanes are stored. */
  200. if (m & 2) {
  201. ptr_a0 = ptr_a;
  202. ptr_a += 2 * pad_k;
  203. ptr_b0 = ptr_b;
  204. INIT_C(0, 0); INIT_C(0, 1);
  205. for (BLASLONG p = 0; p < pad_k; p += 4) {
  206. ma0 = svld1_bf16(pg16_first_8, ptr_a0);
  207. mb0 = svld1_bf16(pg16_first_8, ptr_b0);
  208. mb1 = svld1_bf16(pg16_first_8, ptr_b0 + 8);
  209. MATMUL(0, 0); MATMUL(0, 1);
  210. ptr_a0 += 8;
  211. ptr_b0 += 16;
  212. }
  213. vc0 = svuzp1(mc00, mc00);
  214. vc1 = svuzp2(mc00, mc00);
  215. vc2 = svuzp1(mc01, mc01);
  216. vc3 = svuzp2(mc01, mc01);
  217. UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, vc0);
  218. UPDATE_C(pg16_first_2, pg32_first_2, ptr_c1, vc1);
  219. UPDATE_C(pg16_first_2, pg32_first_2, ptr_c2, vc2);
  220. UPDATE_C(pg16_first_2, pg32_first_2, ptr_c3, vc3);
  221. ptr_c0 += 2;
  222. ptr_c1 += 2;
  223. ptr_c2 += 2;
  224. ptr_c3 += 2;
  225. }
  /* 1 x 4 remainder tile: A is loaded with the 4-lane predicate and advances
   * by 4 per step; lane 0 of mc00/mc01 and lane 0 of their svuzp2 results are
   * the four values stored. */
  226. if (m & 1) {
  227. ptr_a0 = ptr_a;
  228. ptr_b0 = ptr_b;
  229. INIT_C(0, 0); INIT_C(0, 1);
  230. for (BLASLONG p = 0; p < pad_k; p += 4) {
  231. ma0 = svld1_bf16(pg16_first_4, ptr_a0);
  232. mb0 = svld1_bf16(pg16_first_8, ptr_b0);
  233. mb1 = svld1_bf16(pg16_first_8, ptr_b0 + 8);
  234. MATMUL(0, 0); MATMUL(0, 1);
  235. ptr_a0 += 4;
  236. ptr_b0 += 16;
  237. }
  238. vc1 = svuzp2(mc00, mc00);
  239. vc3 = svuzp2(mc01, mc01);
  240. UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, mc00);
  241. UPDATE_C(pg16_first_1, pg32_first_1, ptr_c1, vc1);
  242. UPDATE_C(pg16_first_1, pg32_first_1, ptr_c2, mc01);
  243. UPDATE_C(pg16_first_1, pg32_first_1, ptr_c3, vc3);
  244. }
  245. ptr_b += 4 * pad_k;
  246. }
  /* ---- n remainder of 2: same m blocking, one B operand (8 values) per k-step ---- */
  247. if (n & 2) {
  248. ptr_c0 = ptr_c;
  249. ptr_c1 = ptr_c0 + ldc;
  250. ptr_c += 2 * ldc;
  251. ptr_a = (bfloat16_t *)A;
  /* 8 x 2 tiles. */
  252. for (BLASLONG i = 0; i < m / 8; i++) {
  253. ptr_a0 = ptr_a;
  254. ptr_a += 8 * pad_k;
  255. ptr_b0 = ptr_b;
  256. INIT_C(0, 0);
  257. INIT_C(1, 0);
  258. INIT_C(2, 0);
  259. INIT_C(3, 0);
  260. for (BLASLONG p = 0; p < pad_k; p += 4) {
  261. ma0 = svld1_bf16(pg16_first_8, ptr_a0);
  262. ma1 = svld1_bf16(pg16_first_8, ptr_a0 + 8);
  263. ma2 = svld1_bf16(pg16_first_8, ptr_a0 + 16);
  264. ma3 = svld1_bf16(pg16_first_8, ptr_a0 + 24);
  265. mb0 = svld1_bf16(pg16_first_8, ptr_b0);
  266. MATMUL(0, 0);
  267. MATMUL(1, 0);
  268. MATMUL(2, 0);
  269. MATMUL(3, 0);
  270. ptr_a0 += 32;
  271. ptr_b0 += 8;
  272. }
  273. vc0 = svuzp1(mc00, mc10);
  274. vc1 = svuzp1(mc20, mc30);
  275. vc2 = svuzp2(mc00, mc10);
  276. vc3 = svuzp2(mc20, mc30);
  277. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, vc0);
  278. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0 + 4, vc1);
  279. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1, vc2);
  280. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1 + 4, vc3);
  281. ptr_c0 += 8;
  282. ptr_c1 += 8;
  283. }
  /* 4 x 2 remainder tile. */
  284. if (m & 4) {
  285. ptr_a0 = ptr_a;
  286. ptr_a += 4 * pad_k;
  287. ptr_b0 = ptr_b;
  288. INIT_C(0, 0);
  289. INIT_C(1, 0);
  290. for (BLASLONG p = 0; p < pad_k; p += 4) {
  291. ma0 = svld1_bf16(pg16_first_8, ptr_a0);
  292. ma1 = svld1_bf16(pg16_first_8, ptr_a0 + 8);
  293. mb0 = svld1_bf16(pg16_first_8, ptr_b0);
  294. MATMUL(0, 0);
  295. MATMUL(1, 0);
  296. ptr_a0 += 16;
  297. ptr_b0 += 8;
  298. }
  299. vc0 = svuzp1(mc00, mc10);
  300. vc1 = svuzp2(mc00, mc10);
  301. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, vc0);
  302. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1, vc1);
  303. ptr_c0 += 4;
  304. ptr_c1 += 4;
  305. }
  /* 2 x 2 remainder tile. */
  306. if (m & 2) {
  307. ptr_a0 = ptr_a;
  308. ptr_a += 2 * pad_k;
  309. ptr_b0 = ptr_b;
  310. INIT_C(0, 0);
  311. for (BLASLONG p = 0; p < pad_k; p += 4) {
  312. ma0 = svld1_bf16(pg16_first_8, ptr_a0);
  313. mb0 = svld1_bf16(pg16_first_8, ptr_b0);
  314. MATMUL(0, 0);
  315. ptr_a0 += 8;
  316. ptr_b0 += 8;
  317. }
  318. vc0 = svuzp1(mc00, mc00);
  319. vc1 = svuzp2(mc00, mc00);
  320. UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, vc0);
  321. UPDATE_C(pg16_first_2, pg32_first_2, ptr_c1, vc1);
  322. ptr_c0 += 2;
  323. ptr_c1 += 2;
  324. }
  /* 1 x 2 remainder tile. */
  325. if (m & 1) {
  326. ptr_a0 = ptr_a;
  327. ptr_b0 = ptr_b;
  328. INIT_C(0, 0);
  329. for (BLASLONG p = 0; p < pad_k; p += 4) {
  330. ma0 = svld1_bf16(pg16_first_4, ptr_a0);
  331. mb0 = svld1_bf16(pg16_first_8, ptr_b0);
  332. MATMUL(0, 0);
  333. ptr_a0 += 4;
  334. ptr_b0 += 8;
  335. }
  336. vc1 = svuzp2(mc00, mc00);
  337. UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, mc00);
  338. UPDATE_C(pg16_first_1, pg32_first_1, ptr_c1, vc1);
  339. }
  340. ptr_b += 2 * pad_k;
  341. }
  /* ---- n remainder of 1: B is loaded with the 4-lane predicate and advances
   * by 4 per k-step; only the svuzp1 (even-lane) results are stored ---- */
  342. if (n & 1) {
  343. ptr_c0 = ptr_c;
  344. ptr_a = (bfloat16_t *)A;
  /* 8 x 1 tiles. */
  345. for (BLASLONG i = 0; i < m / 8; i++) {
  346. ptr_a0 = ptr_a;
  347. ptr_a += 8 * pad_k;
  348. ptr_b0 = ptr_b;
  349. INIT_C(0, 0);
  350. INIT_C(1, 0);
  351. INIT_C(2, 0);
  352. INIT_C(3, 0);
  353. for (BLASLONG p = 0; p < pad_k; p += 4) {
  354. ma0 = svld1_bf16(pg16_first_8, ptr_a0);
  355. ma1 = svld1_bf16(pg16_first_8, ptr_a0 + 8);
  356. ma2 = svld1_bf16(pg16_first_8, ptr_a0 + 16);
  357. ma3 = svld1_bf16(pg16_first_8, ptr_a0 + 24);
  358. mb0 = svld1_bf16(pg16_first_4, ptr_b0);
  359. MATMUL(0, 0);
  360. MATMUL(1, 0);
  361. MATMUL(2, 0);
  362. MATMUL(3, 0);
  363. ptr_a0 += 32;
  364. ptr_b0 += 4;
  365. }
  366. vc0 = svuzp1(mc00, mc10);
  367. vc1 = svuzp1(mc20, mc30);
  368. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, vc0);
  369. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0 + 4, vc1);
  370. ptr_c0 += 8;
  371. }
  /* 4 x 1 remainder tile. */
  372. if (m & 4) {
  373. ptr_a0 = ptr_a;
  374. ptr_a += 4 * pad_k;
  375. ptr_b0 = ptr_b;
  376. INIT_C(0, 0);
  377. INIT_C(1, 0);
  378. for (BLASLONG p = 0; p < pad_k; p += 4) {
  379. ma0 = svld1_bf16(pg16_first_8, ptr_a0);
  380. ma1 = svld1_bf16(pg16_first_8, ptr_a0 + 8);
  381. mb0 = svld1_bf16(pg16_first_4, ptr_b0);
  382. MATMUL(0, 0);
  383. MATMUL(1, 0);
  384. ptr_a0 += 16;
  385. ptr_b0 += 4;
  386. }
  387. vc0 = svuzp1(mc00, mc10);
  388. UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, vc0);
  389. ptr_c0 += 4;
  390. }
  /* 2 x 1 remainder tile. */
  391. if (m & 2) {
  392. ptr_a0 = ptr_a;
  393. ptr_a += 2 * pad_k;
  394. ptr_b0 = ptr_b;
  395. INIT_C(0, 0);
  396. for (BLASLONG p = 0; p < pad_k; p += 4) {
  397. ma0 = svld1_bf16(pg16_first_8, ptr_a0);
  398. mb0 = svld1_bf16(pg16_first_4, ptr_b0);
  399. MATMUL(0, 0);
  400. ptr_a0 += 8;
  401. ptr_b0 += 4;
  402. }
  403. vc0 = svuzp1(mc00, mc00);
  404. UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, vc0);
  405. ptr_c0 += 2;
  406. }
  /* 1 x 1 remainder: lane 0 of the accumulator is the single result. */
  407. if (m & 1) {
  408. ptr_a0 = ptr_a;
  409. ptr_b0 = ptr_b;
  410. INIT_C(0, 0);
  411. for (BLASLONG p = 0; p < pad_k; p += 4) {
  412. ma0 = svld1_bf16(pg16_first_4, ptr_a0);
  413. mb0 = svld1_bf16(pg16_first_4, ptr_b0);
  414. MATMUL(0, 0);
  415. ptr_a0 += 4;
  416. ptr_b0 += 4;
  417. }
  418. UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, mc00);
  419. }
  420. }
  421. return 0;
  422. }