You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gemmkernel_2x2.c 8.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. /***************************************************************************
  2. * Copyright (c) 2025, The OpenBLAS Project
  3. * All rights reserved.
  4. * Redistribution and use in source and binary forms, with or without
  5. * modification, are permitted provided that the following conditions are
  6. * met:
  7. * 1. Redistributions of source code must retain the above copyright
  8. * notice, this list of conditions and the following disclaimer.
  9. * 2. Redistributions in binary form must reproduce the above copyright
  10. * notice, this list of conditions and the following disclaimer in
  11. * the documentation and/or other materials provided with the
  12. * distribution.
  13. * 3. Neither the name of the OpenBLAS project nor the names of
  14. * its contributors may be used to endorse or promote products
  15. * derived from this software without specific prior written permission.
  16. * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  17. * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  18. * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  19. * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
  20. * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  21. * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  22. * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  23. * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  24. * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  25. * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  26. * POSSIBILITY OF SUCH DAMAGE.
  27. * *****************************************************************************/
  28. #include "common.h"
  29. #if defined(BFLOAT16) && defined(BFLOAT16CONVERSION)
  30. static float
  31. bfloat16tof32 (bfloat16 value)
  32. {
  33. blasint one = 1;
  34. float result;
  35. sbf16tos_(&one, &value, &one, &result, &one);
  36. return result;
  37. }
  38. #ifdef BGEMM
  39. static bfloat16 f32tobfloat16(float value) {
  40. blasint one = 1;
  41. bfloat16 result;
  42. sbstobf16_(&one, &value, &one, &result, &one);
  43. return result;
  44. }
  45. #endif
  46. #ifdef BGEMM
  47. #define ALPHA bfloat16tof32(alpha)
  48. #define BF16TOF32(x) (bfloat16tof32(x))
  49. #define F32TOBF16(x) (f32tobfloat16(x))
  50. #else
  51. #define ALPHA alpha
  52. #define BF16TOF32(x) (bfloat16tof32(x))
  53. #define F32TOBF16(x) x
  54. #endif
  55. #else
  56. #define ALPHA alpha
  57. #define BF16TOF32(x) x
  58. #define F32TOBF16(x) x
  59. #endif
  60. int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc
  61. #ifdef TRMMKERNEL
  62. ,BLASLONG offset
  63. #endif
  64. )
  65. {
  66. BLASLONG i,j,k;
  67. FLOAT *C0,*C1;
  68. IFLOAT *ptrba,*ptrbb;
  69. #ifdef BGEMM
  70. float res0,res1,res2,res3;
  71. #else
  72. FLOAT res0,res1,res2,res3;
  73. #endif
  74. IFLOAT load0,load1,load2,load3,load4,load5,load6,load7;
  75. for (j=0; j<bn/2; j+=1)
  76. {
  77. C0 = C;
  78. C1 = C0+ldc;
  79. ptrba = ba;
  80. for (i=0; i<bm/2; i+=1)
  81. {
  82. ptrbb = bb;
  83. res0 = 0;
  84. res1 = 0;
  85. res2 = 0;
  86. res3 = 0;
  87. for (k=0; k<bk/4; k+=1)
  88. {
  89. load0 = ptrba[2*0+0];
  90. load1 = ptrbb[2*0+0];
  91. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  92. load2 = ptrba[2*0+1];
  93. res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
  94. load3 = ptrbb[2*0+1];
  95. res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
  96. res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
  97. load4 = ptrba[2*1+0];
  98. load5 = ptrbb[2*1+0];
  99. res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
  100. load6 = ptrba[2*1+1];
  101. res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
  102. load7 = ptrbb[2*1+1];
  103. res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
  104. res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
  105. load0 = ptrba[2*2+0];
  106. load1 = ptrbb[2*2+0];
  107. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  108. load2 = ptrba[2*2+1];
  109. res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
  110. load3 = ptrbb[2*2+1];
  111. res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
  112. res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
  113. load4 = ptrba[2*3+0];
  114. load5 = ptrbb[2*3+0];
  115. res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
  116. load6 = ptrba[2*3+1];
  117. res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
  118. load7 = ptrbb[2*3+1];
  119. res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
  120. res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
  121. ptrba = ptrba+8;
  122. ptrbb = ptrbb+8;
  123. }
  124. for (k=0; k<(bk&3); k+=1)
  125. {
  126. load0 = ptrba[2*0+0];
  127. load1 = ptrbb[2*0+0];
  128. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  129. load2 = ptrba[2*0+1];
  130. res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
  131. load3 = ptrbb[2*0+1];
  132. res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
  133. res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
  134. ptrba = ptrba+2;
  135. ptrbb = ptrbb+2;
  136. }
  137. res0 = res0*ALPHA;
  138. C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
  139. res1 = res1*ALPHA;
  140. C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1);
  141. res2 = res2*ALPHA;
  142. C1[0] = F32TOBF16(BF16TOF32(C1[0])+res2);
  143. res3 = res3*ALPHA;
  144. C1[1] = F32TOBF16(BF16TOF32(C1[1])+res3);
  145. C0 = C0+2;
  146. C1 = C1+2;
  147. }
  148. for (i=0; i<(bm&1); i+=1)
  149. {
  150. ptrbb = bb;
  151. res0 = 0;
  152. res1 = 0;
  153. for (k=0; k<bk; k+=1)
  154. {
  155. load0 = ptrba[0+0];
  156. load1 = ptrbb[2*0+0];
  157. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  158. load2 = ptrbb[2*0+1];
  159. res1 = res1+BF16TOF32(load0)*BF16TOF32(load2);
  160. ptrba = ptrba+1;
  161. ptrbb = ptrbb+2;
  162. }
  163. res0 = res0*ALPHA;
  164. C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
  165. res1 = res1*ALPHA;
  166. C1[0] = F32TOBF16(BF16TOF32(C1[0])+res1);
  167. C0 = C0+1;
  168. C1 = C1+1;
  169. }
  170. k = (bk<<1);
  171. bb = bb+k;
  172. i = (ldc<<1);
  173. C = C+i;
  174. }
  175. for (j=0; j<(bn&1); j+=1)
  176. {
  177. C0 = C;
  178. ptrba = ba;
  179. for (i=0; i<bm/2; i+=1)
  180. {
  181. ptrbb = bb;
  182. res0 = 0;
  183. res1 = 0;
  184. for (k=0; k<bk; k+=1)
  185. {
  186. load0 = ptrba[2*0+0];
  187. load1 = ptrbb[0+0];
  188. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  189. load2 = ptrba[2*0+1];
  190. res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
  191. ptrba = ptrba+2;
  192. ptrbb = ptrbb+1;
  193. }
  194. res0 = res0*ALPHA;
  195. C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
  196. res1 = res1*ALPHA;
  197. C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1);
  198. C0 = C0+2;
  199. }
  200. for (i=0; i<(bm&1); i+=1)
  201. {
  202. ptrbb = bb;
  203. res0 = 0;
  204. for (k=0; k<bk; k+=1)
  205. {
  206. load0 = ptrba[0+0];
  207. load1 = ptrbb[0+0];
  208. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  209. ptrba = ptrba+1;
  210. ptrbb = ptrbb+1;
  211. }
  212. res0 = res0*ALPHA;
  213. C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
  214. C0 = C0+1;
  215. }
  216. k = (bk<<0);
  217. bb = bb+k;
  218. C = C+ldc;
  219. }
  220. return 0;
  221. }