You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_gemm_4.cpp 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "testutil.h"
  15. #if NCNN_INT8
  16. static void RandomizeA(ncnn::Mat& m, int transA, float absmax)
  17. {
  18. if (transA == 0)
  19. {
  20. const int h = m.dims == 3 ? m.c : m.h;
  21. for (int i = 0; i < h; i++)
  22. {
  23. float* p = m.dims == 3 ? m.channel(i) : m.row(i);
  24. float randabsmax = RandomFloat(absmax * 0.5f, absmax);
  25. randabsmax = ncnn::float16_to_float32(ncnn::float32_to_float16(randabsmax));
  26. randabsmax = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(randabsmax));
  27. for (int j = 0; j < m.w; j++)
  28. {
  29. p[j] = RandomFloat(-randabsmax, randabsmax);
  30. }
  31. // set random a and b
  32. p[RandomInt(0, m.w - 1)] = -randabsmax;
  33. p[RandomInt(0, m.w - 1)] = randabsmax;
  34. // drop 0.45 ~ 0.55
  35. for (int j = 0; j < m.w; j++)
  36. {
  37. float v = p[j] * (127.f / randabsmax);
  38. float vv = fabs(v - (int)v);
  39. float hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  40. float hv = hp * (127.f / randabsmax);
  41. float hvv = fabs(hv - (int)hv);
  42. float bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  43. float bv = bp * (127.f / randabsmax);
  44. float bvv = fabs(bv - (int)bv);
  45. while ((vv > 0.45f && vv < 0.55f) || (hvv > 0.45f && hvv < 0.55f) || (bvv > 0.45f && bvv < 0.55f))
  46. {
  47. p[j] = RandomFloat(-randabsmax, randabsmax);
  48. v = p[j] * (127.f / randabsmax);
  49. vv = fabs(v - (int)v);
  50. hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  51. hv = hp * (127.f / randabsmax);
  52. hvv = fabs(hv - (int)hv);
  53. bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  54. bv = bp * (127.f / randabsmax);
  55. bvv = fabs(bv - (int)bv);
  56. }
  57. }
  58. }
  59. }
  60. else // if (transA == 1)
  61. {
  62. std::vector<float> randabsmaxes(m.w);
  63. for (int j = 0; j < m.w; j++)
  64. {
  65. float randabsmax = RandomFloat(absmax * 0.5f, absmax);
  66. randabsmax = ncnn::float16_to_float32(ncnn::float32_to_float16(randabsmax));
  67. randabsmax = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(randabsmax));
  68. randabsmaxes[j] = randabsmax;
  69. }
  70. const int h = m.dims == 3 ? m.c : m.h;
  71. for (int i = 0; i < h; i++)
  72. {
  73. float* p = m.dims == 3 ? m.channel(i) : m.row(i);
  74. for (int j = 0; j < m.w; j++)
  75. {
  76. const float randabsmax = randabsmaxes[j];
  77. p[j] = RandomFloat(-randabsmax, randabsmax);
  78. }
  79. // drop 0.45 ~ 0.55
  80. for (int j = 0; j < m.w; j++)
  81. {
  82. const float randabsmax = randabsmaxes[j];
  83. float v = p[j] * (127.f / randabsmax);
  84. float vv = fabs(v - (int)v);
  85. float hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  86. float hv = hp * (127.f / randabsmax);
  87. float hvv = fabs(hv - (int)hv);
  88. float bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  89. float bv = bp * (127.f / randabsmax);
  90. float bvv = fabs(bv - (int)bv);
  91. while ((vv > 0.45f && vv < 0.55f) || (hvv > 0.45f && hvv < 0.55f) || (bvv > 0.45f && bvv < 0.55f))
  92. {
  93. p[j] = RandomFloat(-randabsmax, randabsmax);
  94. v = p[j] * (127.f / randabsmax);
  95. vv = fabs(v - (int)v);
  96. hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  97. hv = hp * (127.f / randabsmax);
  98. hvv = fabs(hv - (int)hv);
  99. bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  100. bv = bp * (127.f / randabsmax);
  101. bvv = fabs(bv - (int)bv);
  102. }
  103. }
  104. }
  105. for (int j = 0; j < m.w; j++)
  106. {
  107. const int randi0 = RandomInt(0, h - 1);
  108. const int randi1 = RandomInt(0, h - 1);
  109. float* p0 = m.dims == 3 ? m.channel(randi0) : m.row(randi0);
  110. float* p1 = m.dims == 3 ? m.channel(randi1) : m.row(randi1);
  111. const float randabsmax = randabsmaxes[j];
  112. // set random a and b
  113. p0[j] = -randabsmax;
  114. p1[j] = randabsmax;
  115. }
  116. }
  117. }
  118. static void RandomizeB(ncnn::Mat& m, float absmax)
  119. {
  120. absmax = ncnn::float16_to_float32(ncnn::float32_to_float16(absmax));
  121. absmax = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(absmax));
  122. const int h = m.dims == 3 ? m.c : m.h;
  123. float* p = m;
  124. for (int i = 0; i < h; i++)
  125. {
  126. float* p = m.dims == 3 ? m.channel(i) : m.row(i);
  127. for (int j = 0; j < m.w; j++)
  128. {
  129. p[j] = RandomFloat(-absmax, absmax);
  130. // drop 0.45 ~ 0.55
  131. float v = p[j] * (127.f / absmax);
  132. float vv = fabs(v - (int)v);
  133. float hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  134. float hv = hp * (127.f / absmax);
  135. float hvv = fabs(hv - (int)hv);
  136. float bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  137. float bv = bp * (127.f / absmax);
  138. float bvv = fabs(bv - (int)bv);
  139. while ((vv > 0.45f && vv < 0.55f) || (hvv > 0.45f && hvv < 0.55f) || (bvv > 0.45f && bvv < 0.55f))
  140. {
  141. p[j] = RandomFloat(-absmax, absmax);
  142. v = p[j] * (127.f / absmax);
  143. vv = fabs(v - (int)v);
  144. hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  145. hv = hp * (127.f / absmax);
  146. hvv = fabs(hv - (int)hv);
  147. bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  148. bv = bp * (127.f / absmax);
  149. bvv = fabs(bv - (int)bv);
  150. }
  151. }
  152. }
  153. // set random a and b
  154. if (m.dims == 3)
  155. {
  156. m.channel(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = -absmax;
  157. m.channel(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = absmax;
  158. }
  159. else
  160. {
  161. m.row(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = -absmax;
  162. m.row(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = absmax;
  163. }
  164. }
  165. static int test_gemm_int8(int M, int N, int K, int TILE_M, int TILE_N, int TILE_K, float alpha, int transA, int transB, int output_transpose)
  166. {
  167. ncnn::ParamDict pd;
  168. pd.set(0, alpha);
  169. pd.set(1, 1.f); // beta
  170. pd.set(2, transA);
  171. pd.set(3, transB);
  172. pd.set(14, output_transpose);
  173. pd.set(18, 2); // int8_scale_term
  174. pd.set(20, TILE_M);
  175. pd.set(21, TILE_N);
  176. pd.set(22, TILE_K);
  177. std::vector<ncnn::Mat> weights(0);
  178. std::vector<ncnn::Mat> a(2);
  179. a[0] = transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M);
  180. a[1] = transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K);
  181. RandomizeA(a[0], transA, 10.f);
  182. RandomizeB(a[1], 10.f);
  183. int ret = test_layer("Gemm", pd, weights, a);
  184. if (ret != 0)
  185. {
  186. fprintf(stderr, "test_gemm_int8 failed M=%d N=%d K=%d TILE_M=%d TILE_N=%d TILE_K=%d alpha=%f transA=%d transB=%d output_transpose=%d\n", M, N, K, TILE_M, TILE_N, TILE_K, alpha, transA, transB, output_transpose);
  187. }
  188. return ret;
  189. }
  190. static int test_gemm_0(int M, int N, int K, int TILE_M, int TILE_N, int TILE_K)
  191. {
  192. return 0
  193. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 2.1f, 0, 0, 0)
  194. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 3.1f, 0, 1, 0)
  195. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 4.1f, 1, 0, 0)
  196. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 5.1f, 1, 1, 0)
  197. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 2.1f, 0, 0, 1)
  198. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 3.1f, 0, 1, 1)
  199. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 4.1f, 1, 0, 1)
  200. || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 5.1f, 1, 1, 1);
  201. }
  202. #endif // NCNN_INT8
  203. int main()
  204. {
  205. SRAND(7767517);
  206. #if NCNN_INT8
  207. int mnk[][3] = {
  208. {1, 1, 1},
  209. {2, 2, 2},
  210. {3, 3, 3},
  211. {4, 4, 4},
  212. {5, 5, 5},
  213. {6, 6, 6},
  214. {7, 7, 7},
  215. {8, 8, 8},
  216. {15, 15, 15},
  217. {16, 16, 16},
  218. {24, 24, 24},
  219. {31, 31, 31},
  220. {31, 32, 31},
  221. {32, 31, 32},
  222. {32, 32, 32},
  223. {20, 32, 20},
  224. {40, 40, 40},
  225. {47, 47, 47},
  226. {48, 48, 48},
  227. {52, 52, 52},
  228. {63, 64, 63},
  229. {64, 63, 64},
  230. {64, 64, 64}
  231. };
  232. int tile_mnk[][3] = {
  233. {1, 1, 1},
  234. {2, 2, 2},
  235. {4, 4, 4},
  236. {8, 8, 8},
  237. {12, 12, 12},
  238. {16, 16, 16},
  239. {20, 20, 20},
  240. {24, 24, 24},
  241. {28, 28, 28}
  242. };
  243. int mnk_count = sizeof(mnk) / sizeof(int) / 3;
  244. int tile_mnk_count = sizeof(tile_mnk) / sizeof(int) / 3;
  245. for (int i = 0; i < mnk_count; i++)
  246. {
  247. int M = mnk[i][0];
  248. int N = mnk[i][1];
  249. int K = mnk[i][2];
  250. for (int j = 0; j < tile_mnk_count; j++)
  251. {
  252. int TILE_M = tile_mnk[j][0];
  253. int TILE_N = tile_mnk[j][1];
  254. int TILE_K = tile_mnk[j][2];
  255. if (TILE_M >= M && TILE_N >= N && TILE_K >= K)
  256. continue;
  257. int ret = test_gemm_0(M, N, K, TILE_M, TILE_N, TILE_K);
  258. if (ret != 0)
  259. return ret;
  260. }
  261. // test no tiling
  262. int ret = test_gemm_0(M, N, K, 100, 100, 100);
  263. if (ret != 0)
  264. return ret;
  265. }
  266. #else
  267. // test nothing for non-int8 build
  268. #endif
  269. return 0;
  270. }