You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_gemm_4.cpp 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. // Copyright 2024 Tencent
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. #include "testutil.h"
  4. #if NCNN_INT8
  5. static void RandomizeA(ncnn::Mat& m, int transA, float absmax)
  6. {
  7. if (transA == 0)
  8. {
  9. const int h = m.dims == 3 ? m.c : m.h;
  10. for (int i = 0; i < h; i++)
  11. {
  12. float* p = m.dims == 3 ? m.channel(i) : m.row(i);
  13. float randabsmax = RandomFloat(absmax * 0.5f, absmax);
  14. randabsmax = ncnn::float16_to_float32(ncnn::float32_to_float16(randabsmax));
  15. randabsmax = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(randabsmax));
  16. for (int j = 0; j < m.w; j++)
  17. {
  18. p[j] = RandomFloat(-randabsmax, randabsmax);
  19. }
  20. // set random a and b
  21. p[RandomInt(0, m.w - 1)] = -randabsmax;
  22. p[RandomInt(0, m.w - 1)] = randabsmax;
  23. // drop 0.45 ~ 0.55
  24. for (int j = 0; j < m.w; j++)
  25. {
  26. float v = p[j] * (127.f / randabsmax);
  27. float vv = fabs(v - (int)v);
  28. float hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  29. float hv = hp * (127.f / randabsmax);
  30. float hvv = fabs(hv - (int)hv);
  31. float bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  32. float bv = bp * (127.f / randabsmax);
  33. float bvv = fabs(bv - (int)bv);
  34. while ((vv > 0.45f && vv < 0.55f) || (hvv > 0.45f && hvv < 0.55f) || (bvv > 0.45f && bvv < 0.55f))
  35. {
  36. p[j] = RandomFloat(-randabsmax, randabsmax);
  37. v = p[j] * (127.f / randabsmax);
  38. vv = fabs(v - (int)v);
  39. hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  40. hv = hp * (127.f / randabsmax);
  41. hvv = fabs(hv - (int)hv);
  42. bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  43. bv = bp * (127.f / randabsmax);
  44. bvv = fabs(bv - (int)bv);
  45. }
  46. }
  47. }
  48. }
  49. else // if (transA == 1)
  50. {
  51. std::vector<float> randabsmaxes(m.w);
  52. for (int j = 0; j < m.w; j++)
  53. {
  54. float randabsmax = RandomFloat(absmax * 0.5f, absmax);
  55. randabsmax = ncnn::float16_to_float32(ncnn::float32_to_float16(randabsmax));
  56. randabsmax = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(randabsmax));
  57. randabsmaxes[j] = randabsmax;
  58. }
  59. const int h = m.dims == 3 ? m.c : m.h;
  60. for (int i = 0; i < h; i++)
  61. {
  62. float* p = m.dims == 3 ? m.channel(i) : m.row(i);
  63. for (int j = 0; j < m.w; j++)
  64. {
  65. const float randabsmax = randabsmaxes[j];
  66. p[j] = RandomFloat(-randabsmax, randabsmax);
  67. }
  68. // drop 0.45 ~ 0.55
  69. for (int j = 0; j < m.w; j++)
  70. {
  71. const float randabsmax = randabsmaxes[j];
  72. float v = p[j] * (127.f / randabsmax);
  73. float vv = fabs(v - (int)v);
  74. float hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  75. float hv = hp * (127.f / randabsmax);
  76. float hvv = fabs(hv - (int)hv);
  77. float bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  78. float bv = bp * (127.f / randabsmax);
  79. float bvv = fabs(bv - (int)bv);
  80. while ((vv > 0.45f && vv < 0.55f) || (hvv > 0.45f && hvv < 0.55f) || (bvv > 0.45f && bvv < 0.55f))
  81. {
  82. p[j] = RandomFloat(-randabsmax, randabsmax);
  83. v = p[j] * (127.f / randabsmax);
  84. vv = fabs(v - (int)v);
  85. hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  86. hv = hp * (127.f / randabsmax);
  87. hvv = fabs(hv - (int)hv);
  88. bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  89. bv = bp * (127.f / randabsmax);
  90. bvv = fabs(bv - (int)bv);
  91. }
  92. }
  93. }
  94. for (int j = 0; j < m.w; j++)
  95. {
  96. const int randi0 = RandomInt(0, h - 1);
  97. const int randi1 = RandomInt(0, h - 1);
  98. float* p0 = m.dims == 3 ? m.channel(randi0) : m.row(randi0);
  99. float* p1 = m.dims == 3 ? m.channel(randi1) : m.row(randi1);
  100. const float randabsmax = randabsmaxes[j];
  101. // set random a and b
  102. p0[j] = -randabsmax;
  103. p1[j] = randabsmax;
  104. }
  105. }
  106. }
  107. static void RandomizeB(ncnn::Mat& m, float absmax)
  108. {
  109. absmax = ncnn::float16_to_float32(ncnn::float32_to_float16(absmax));
  110. absmax = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(absmax));
  111. const int h = m.dims == 3 ? m.c : m.h;
  112. float* p = m;
  113. for (int i = 0; i < h; i++)
  114. {
  115. float* p = m.dims == 3 ? m.channel(i) : m.row(i);
  116. for (int j = 0; j < m.w; j++)
  117. {
  118. p[j] = RandomFloat(-absmax, absmax);
  119. // drop 0.45 ~ 0.55
  120. float v = p[j] * (127.f / absmax);
  121. float vv = fabs(v - (int)v);
  122. float hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  123. float hv = hp * (127.f / absmax);
  124. float hvv = fabs(hv - (int)hv);
  125. float bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  126. float bv = bp * (127.f / absmax);
  127. float bvv = fabs(bv - (int)bv);
  128. while ((vv > 0.45f && vv < 0.55f) || (hvv > 0.45f && hvv < 0.55f) || (bvv > 0.45f && bvv < 0.55f))
  129. {
  130. p[j] = RandomFloat(-absmax, absmax);
  131. v = p[j] * (127.f / absmax);
  132. vv = fabs(v - (int)v);
  133. hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  134. hv = hp * (127.f / absmax);
  135. hvv = fabs(hv - (int)hv);
  136. bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  137. bv = bp * (127.f / absmax);
  138. bvv = fabs(bv - (int)bv);
  139. }
  140. }
  141. }
  142. // set random a and b
  143. if (m.dims == 3)
  144. {
  145. m.channel(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = -absmax;
  146. m.channel(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = absmax;
  147. }
  148. else
  149. {
  150. m.row(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = -absmax;
  151. m.row(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = absmax;
  152. }
  153. }
  154. static int test_gemm_int8(int M, int N, int K, int TILE_M, int TILE_N, int TILE_K, float alpha, int transA, int transB, int output_transpose)
  155. {
  156. ncnn::ParamDict pd;
  157. pd.set(0, alpha);
  158. pd.set(1, 1.f); // beta
  159. pd.set(2, transA);
  160. pd.set(3, transB);
  161. pd.set(14, output_transpose);
  162. pd.set(18, 2); // int8_scale_term
  163. pd.set(20, TILE_M);
  164. pd.set(21, TILE_N);
  165. pd.set(22, TILE_K);
  166. std::vector<ncnn::Mat> weights(0);
  167. std::vector<ncnn::Mat> a(2);
  168. a[0] = transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M);
  169. a[1] = transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K);
  170. RandomizeA(a[0], transA, 10.f);
  171. RandomizeB(a[1], 10.f);
  172. int ret = test_layer("Gemm", pd, weights, a);
  173. if (ret != 0)
  174. {
  175. fprintf(stderr, "test_gemm_int8 failed M=%d N=%d K=%d TILE_M=%d TILE_N=%d TILE_K=%d alpha=%f transA=%d transB=%d output_transpose=%d\n", M, N, K, TILE_M, TILE_N, TILE_K, alpha, transA, transB, output_transpose);
  176. }
  177. return ret;
  178. }
  179. static int test_gemm_0(int M, int N, int K, int TILE_M, int TILE_N, int TILE_K)
  180. {
  181. return 0
  182. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 2.1f, 0, 0, 0)
  183. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 3.1f, 0, 1, 0)
  184. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 4.1f, 1, 0, 0)
  185. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 5.1f, 1, 1, 0)
  186. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 2.1f, 0, 0, 1)
  187. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 3.1f, 0, 1, 1)
  188. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 4.1f, 1, 0, 1)
  189. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 5.1f, 1, 1, 1);
  190. }
  191. #endif // NCNN_INT8
  192. int main()
  193. {
  194. SRAND(7767517);
  195. #if NCNN_INT8
  196. int mnk[][3] = {
  197. {1, 1, 1},
  198. {2, 2, 2},
  199. {3, 3, 3},
  200. {4, 4, 4},
  201. {5, 5, 5},
  202. {6, 6, 6},
  203. {7, 7, 7},
  204. {8, 8, 8},
  205. {15, 15, 15},
  206. {16, 16, 16},
  207. {24, 24, 24},
  208. {31, 31, 31},
  209. {31, 32, 31},
  210. {32, 31, 32},
  211. {32, 32, 32},
  212. {20, 32, 20},
  213. {40, 40, 40},
  214. {47, 47, 47},
  215. {48, 48, 48},
  216. {52, 52, 52},
  217. {63, 64, 63},
  218. {64, 63, 64},
  219. {64, 64, 64}
  220. };
  221. int tile_mnk[][3] = {
  222. {1, 1, 1},
  223. {2, 2, 2},
  224. {4, 4, 4},
  225. {8, 8, 8},
  226. {12, 12, 12},
  227. {16, 16, 16},
  228. {20, 20, 20},
  229. {24, 24, 24},
  230. {28, 28, 28}
  231. };
  232. int mnk_count = sizeof(mnk) / sizeof(int) / 3;
  233. int tile_mnk_count = sizeof(tile_mnk) / sizeof(int) / 3;
  234. for (int i = 0; i < mnk_count; i++)
  235. {
  236. int M = mnk[i][0];
  237. int N = mnk[i][1];
  238. int K = mnk[i][2];
  239. for (int j = 0; j < tile_mnk_count; j++)
  240. {
  241. int TILE_M = tile_mnk[j][0];
  242. int TILE_N = tile_mnk[j][1];
  243. int TILE_K = tile_mnk[j][2];
  244. if (TILE_M >= M && TILE_N >= N && TILE_K >= K)
  245. continue;
  246. int ret = test_gemm_0(M, N, K, TILE_M, TILE_N, TILE_K);
  247. if (ret != 0)
  248. return ret;
  249. }
  250. // test no tiling
  251. int ret = test_gemm_0(M, N, K, 100, 100, 100);
  252. if (ret != 0)
  253. return ret;
  254. }
  255. #else
  256. // test nothing for non-int8 build
  257. #endif
  258. return 0;
  259. }