You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_gemm_3.cpp 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. // Copyright 2024 Tencent
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. #include "testutil.h"
  4. #if NCNN_INT8
  5. static void RandomizeA(ncnn::Mat& m, int transA, float absmax)
  6. {
  7. if (transA == 0)
  8. {
  9. const int h = m.dims == 3 ? m.c : m.h;
  10. for (int i = 0; i < h; i++)
  11. {
  12. float* p = m.dims == 3 ? m.channel(i) : m.row(i);
  13. float randabsmax = RandomFloat(absmax * 0.5f, absmax);
  14. randabsmax = ncnn::float16_to_float32(ncnn::float32_to_float16(randabsmax));
  15. randabsmax = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(randabsmax));
  16. for (int j = 0; j < m.w; j++)
  17. {
  18. p[j] = RandomFloat(-randabsmax, randabsmax);
  19. }
  20. // set random a and b
  21. p[RandomInt(0, m.w - 1)] = -randabsmax;
  22. p[RandomInt(0, m.w - 1)] = randabsmax;
  23. // drop 0.45 ~ 0.55
  24. for (int j = 0; j < m.w; j++)
  25. {
  26. float v = p[j] * (127.f / randabsmax);
  27. float vv = fabs(v - (int)v);
  28. float hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  29. float hv = hp * (127.f / randabsmax);
  30. float hvv = fabs(hv - (int)hv);
  31. float bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  32. float bv = bp * (127.f / randabsmax);
  33. float bvv = fabs(bv - (int)bv);
  34. while ((vv > 0.45f && vv < 0.55f) || (hvv > 0.45f && hvv < 0.55f) || (bvv > 0.45f && bvv < 0.55f))
  35. {
  36. p[j] = RandomFloat(-randabsmax, randabsmax);
  37. v = p[j] * (127.f / randabsmax);
  38. vv = fabs(v - (int)v);
  39. hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  40. hv = hp * (127.f / randabsmax);
  41. hvv = fabs(hv - (int)hv);
  42. bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  43. bv = bp * (127.f / randabsmax);
  44. bvv = fabs(bv - (int)bv);
  45. }
  46. }
  47. }
  48. }
  49. else // if (transA == 1)
  50. {
  51. std::vector<float> randabsmaxes(m.w);
  52. for (int j = 0; j < m.w; j++)
  53. {
  54. float randabsmax = RandomFloat(absmax * 0.5f, absmax);
  55. randabsmax = ncnn::float16_to_float32(ncnn::float32_to_float16(randabsmax));
  56. randabsmax = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(randabsmax));
  57. randabsmaxes[j] = randabsmax;
  58. }
  59. const int h = m.dims == 3 ? m.c : m.h;
  60. for (int i = 0; i < h; i++)
  61. {
  62. float* p = m.dims == 3 ? m.channel(i) : m.row(i);
  63. for (int j = 0; j < m.w; j++)
  64. {
  65. const float randabsmax = randabsmaxes[j];
  66. p[j] = RandomFloat(-randabsmax, randabsmax);
  67. }
  68. // drop 0.45 ~ 0.55
  69. for (int j = 0; j < m.w; j++)
  70. {
  71. const float randabsmax = randabsmaxes[j];
  72. float v = p[j] * (127.f / randabsmax);
  73. float vv = fabs(v - (int)v);
  74. float hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  75. float hv = hp * (127.f / randabsmax);
  76. float hvv = fabs(hv - (int)hv);
  77. float bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  78. float bv = bp * (127.f / randabsmax);
  79. float bvv = fabs(bv - (int)bv);
  80. while ((vv > 0.45f && vv < 0.55f) || (hvv > 0.45f && hvv < 0.55f) || (bvv > 0.45f && bvv < 0.55f))
  81. {
  82. p[j] = RandomFloat(-randabsmax, randabsmax);
  83. v = p[j] * (127.f / randabsmax);
  84. vv = fabs(v - (int)v);
  85. hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  86. hv = hp * (127.f / randabsmax);
  87. hvv = fabs(hv - (int)hv);
  88. bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  89. bv = bp * (127.f / randabsmax);
  90. bvv = fabs(bv - (int)bv);
  91. }
  92. }
  93. }
  94. for (int j = 0; j < m.w; j++)
  95. {
  96. const int randi0 = RandomInt(0, h - 1);
  97. const int randi1 = RandomInt(0, h - 1);
  98. float* p0 = m.dims == 3 ? m.channel(randi0) : m.row(randi0);
  99. float* p1 = m.dims == 3 ? m.channel(randi1) : m.row(randi1);
  100. const float randabsmax = randabsmaxes[j];
  101. // set random a and b
  102. p0[j] = -randabsmax;
  103. p1[j] = randabsmax;
  104. }
  105. }
  106. }
  107. static void RandomizeB(ncnn::Mat& m, float absmax)
  108. {
  109. absmax = ncnn::float16_to_float32(ncnn::float32_to_float16(absmax));
  110. absmax = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(absmax));
  111. const int h = m.dims == 3 ? m.c : m.h;
  112. float* p = m;
  113. for (int i = 0; i < h; i++)
  114. {
  115. float* p = m.dims == 3 ? m.channel(i) : m.row(i);
  116. for (int j = 0; j < m.w; j++)
  117. {
  118. p[j] = RandomFloat(-absmax, absmax);
  119. // drop 0.45 ~ 0.55
  120. float v = p[j] * (127.f / absmax);
  121. float vv = fabs(v - (int)v);
  122. float hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  123. float hv = hp * (127.f / absmax);
  124. float hvv = fabs(hv - (int)hv);
  125. float bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  126. float bv = bp * (127.f / absmax);
  127. float bvv = fabs(bv - (int)bv);
  128. while ((vv > 0.45f && vv < 0.55f) || (hvv > 0.45f && hvv < 0.55f) || (bvv > 0.45f && bvv < 0.55f))
  129. {
  130. p[j] = RandomFloat(-absmax, absmax);
  131. v = p[j] * (127.f / absmax);
  132. vv = fabs(v - (int)v);
  133. hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  134. hv = hp * (127.f / absmax);
  135. hvv = fabs(hv - (int)hv);
  136. bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  137. bv = bp * (127.f / absmax);
  138. bvv = fabs(bv - (int)bv);
  139. }
  140. }
  141. }
  142. // set random a and b
  143. if (m.dims == 3)
  144. {
  145. m.channel(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = -absmax;
  146. m.channel(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = absmax;
  147. }
  148. else
  149. {
  150. m.row(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = -absmax;
  151. m.row(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = absmax;
  152. }
  153. }
  154. static int test_gemm_int8(int M, int N, int K, float alpha, int transA, int transB, int output_elemtype, int output_transpose, int constantA, int constantB, int output_N1M)
  155. {
  156. ncnn::ParamDict pd;
  157. pd.set(0, alpha);
  158. pd.set(1, 1.f); // beta
  159. pd.set(2, transA);
  160. pd.set(3, transB);
  161. pd.set(4, constantA);
  162. pd.set(5, constantB);
  163. pd.set(6, 1);
  164. pd.set(7, M);
  165. pd.set(8, N);
  166. pd.set(9, K);
  167. pd.set(10, -1);
  168. pd.set(11, output_N1M);
  169. pd.set(13, output_elemtype);
  170. pd.set(14, output_transpose);
  171. pd.set(18, 2); // int8_scale_term
  172. std::vector<ncnn::Mat> weights;
  173. if (constantA) weights.push_back(transA ? RandomS8Mat(M, K) : RandomS8Mat(K, M));
  174. if (constantB) weights.push_back(transB ? RandomS8Mat(K, N) : RandomS8Mat(N, K));
  175. if (constantA) weights.push_back(RandomMat(M, 10.f, 20.f));
  176. if (constantB) weights.push_back(RandomMat(1, 10.f, 20.f));
  177. std::vector<ncnn::Mat> a;
  178. if (!constantA)
  179. {
  180. a.push_back(transA ? (output_N1M ? ncnn::Mat(M, 1, K) : ncnn::Mat(M, K)) : (output_N1M ? ncnn::Mat(K, 1, M) : ncnn::Mat(K, M)));
  181. RandomizeA(a[a.size() - 1], transA, 10.f);
  182. }
  183. if (!constantB)
  184. {
  185. a.push_back(transB ? (output_N1M ? ncnn::Mat(K, 1, N) : ncnn::Mat(K, N)) : (output_N1M ? ncnn::Mat(N, 1, K) : ncnn::Mat(N, K)));
  186. RandomizeB(a[a.size() - 1], 10.f);
  187. }
  188. int ret = test_layer("Gemm", pd, weights, a);
  189. if (ret != 0)
  190. {
  191. fprintf(stderr, "test_gemm_int8 failed M=%d N=%d K=%d alpha=%f transA=%d transB=%d output_elemtype=%d output_transpose=%d constantA=%d constantB=%d output_N1M=%d\n", M, N, K, alpha, transA, transB, output_elemtype, output_transpose, constantA, constantB, output_N1M);
  192. }
  193. return ret;
  194. }
  195. static int test_gemm_int8_bias(int M, int N, int K, const ncnn::Mat& C, float alpha, float beta, int transA, int transB, int output_elemtype, int output_transpose, int constantA, int constantB, int constantC)
  196. {
  197. int broadcast_type_C = 0;
  198. if (C.dims == 1 && C.w == 1)
  199. {
  200. // scalar
  201. broadcast_type_C = 0;
  202. }
  203. if (C.dims == 1 && C.w == M)
  204. {
  205. // M
  206. // auto broadcast from h to w is the ncnn-style convention
  207. broadcast_type_C = 1;
  208. }
  209. if (C.dims == 1 && C.w == N)
  210. {
  211. // N
  212. broadcast_type_C = 4;
  213. }
  214. if (C.dims == 2 && C.w == 1 && C.h == M)
  215. {
  216. // Mx1
  217. broadcast_type_C = 2;
  218. }
  219. if (C.dims == 2 && C.w == N && C.h == M)
  220. {
  221. // MxN
  222. broadcast_type_C = 3;
  223. }
  224. if (C.dims == 2 && C.w == N && C.h == 1)
  225. {
  226. // 1xN
  227. broadcast_type_C = 4;
  228. }
  229. ncnn::ParamDict pd;
  230. pd.set(0, alpha);
  231. pd.set(1, beta);
  232. pd.set(2, transA);
  233. pd.set(3, transB);
  234. pd.set(4, constantA);
  235. pd.set(5, constantB);
  236. pd.set(6, constantC);
  237. pd.set(7, M);
  238. pd.set(8, N);
  239. pd.set(9, K);
  240. pd.set(10, broadcast_type_C);
  241. // pd.set(12, 1); // output_elempack
  242. pd.set(13, output_elemtype);
  243. pd.set(14, output_transpose);
  244. pd.set(18, 2); // int8_scale_term
  245. std::vector<ncnn::Mat> weights;
  246. if (constantA) weights.push_back(transA ? RandomS8Mat(M, K) : RandomS8Mat(K, M));
  247. if (constantB) weights.push_back(transB ? RandomS8Mat(K, N) : RandomS8Mat(N, K));
  248. if (constantC) weights.push_back(C);
  249. if (constantA) weights.push_back(RandomMat(M, 10.f, 20.f));
  250. if (constantB) weights.push_back(RandomMat(1, 10.f, 20.f));
  251. std::vector<ncnn::Mat> a;
  252. if (!constantA)
  253. {
  254. a.push_back(transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M));
  255. RandomizeA(a[a.size() - 1], transA, 10.f);
  256. }
  257. if (!constantB)
  258. {
  259. a.push_back(transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K));
  260. RandomizeB(a[a.size() - 1], 10.f);
  261. }
  262. if (!constantC) a.push_back(C);
  263. int ret = test_layer("Gemm", pd, weights, a);
  264. if (ret != 0)
  265. {
  266. fprintf(stderr, "test_gemm_int8_bias failed M=%d N=%d K=%d C.dims=%d C=(%d %d %d) alpha=%f beta=%f transA=%d transB=%d output_elemtype=%d output_transpose=%d constantA=%d constantB=%d constantC=%d\n", M, N, K, C.dims, C.w, C.h, C.c, alpha, beta, transA, transB, output_elemtype, output_transpose, constantA, constantB, constantC);
  267. }
  268. return ret;
  269. }
  270. static int test_gemm_int8_fp16s(int M, int N, int K, float alpha, int transA, int transB, int output_elemtype, int output_transpose, int constantA, int constantB, int output_N1M)
  271. {
  272. ncnn::ParamDict pd;
  273. pd.set(0, alpha);
  274. pd.set(1, 1.f); // beta
  275. pd.set(2, transA);
  276. pd.set(3, transB);
  277. pd.set(4, constantA);
  278. pd.set(5, constantB);
  279. pd.set(6, 1);
  280. pd.set(7, M);
  281. pd.set(8, N);
  282. pd.set(9, K);
  283. pd.set(10, -1);
  284. pd.set(11, output_N1M);
  285. pd.set(13, output_elemtype);
  286. pd.set(14, output_transpose);
  287. pd.set(18, 2); // int8_scale_term
  288. std::vector<ncnn::Mat> weights;
  289. if (constantA) weights.push_back(transA ? RandomS8Mat(M, K) : RandomS8Mat(K, M));
  290. if (constantB) weights.push_back(transB ? RandomS8Mat(K, N) : RandomS8Mat(N, K));
  291. if (constantA) weights.push_back(RandomMat(M, 10.f, 20.f));
  292. if (constantB) weights.push_back(RandomMat(1, 10.f, 20.f));
  293. std::vector<ncnn::Mat> a;
  294. if (!constantA)
  295. {
  296. a.push_back(transA ? (output_N1M ? ncnn::Mat(M, 1, K) : ncnn::Mat(M, K)) : (output_N1M ? ncnn::Mat(K, 1, M) : ncnn::Mat(K, M)));
  297. RandomizeA(a[a.size() - 1], transA, 10.f);
  298. }
  299. if (!constantB)
  300. {
  301. a.push_back(transB ? (output_N1M ? ncnn::Mat(K, 1, N) : ncnn::Mat(K, N)) : (output_N1M ? ncnn::Mat(N, 1, K) : ncnn::Mat(N, K)));
  302. RandomizeB(a[a.size() - 1], 10.f);
  303. }
  304. ncnn::Option opt;
  305. opt.num_threads = 1;
  306. opt.use_packing_layout = true;
  307. opt.use_fp16_packed = false;
  308. opt.use_fp16_storage = true;
  309. opt.use_fp16_arithmetic = false;
  310. opt.use_bf16_storage = false;
  311. float epsilon = 0.001;
  312. int ret = test_layer_opt("Gemm", pd, weights, opt, a, 1, epsilon);
  313. if (ret != 0)
  314. {
  315. fprintf(stderr, "test_gemm_int8_fp16s failed M=%d N=%d K=%d alpha=%f transA=%d transB=%d output_elemtype=%d output_transpose=%d constantA=%d constantB=%d output_N1M=%d\n", M, N, K, alpha, transA, transB, output_elemtype, output_transpose, constantA, constantB, output_N1M);
  316. return ret;
  317. }
  318. return 0;
  319. }
  320. static int test_gemm_0(int M, int N, int K)
  321. {
  322. return 0
  323. || test_gemm_int8(M, N, K, 2.1f, 0, 1, 0, 0, 0, 0, 0)
  324. || test_gemm_int8(M, N, K, 3.1f, 1, 1, 0, 0, 0, 0, 0)
  325. || test_gemm_int8(M, N, K, 4.1f, 0, 0, 0, 0, 0, 0, 1)
  326. || test_gemm_int8(M, N, K, 5.1f, 1, 0, 0, 0, 0, 0, 1)
  327. || test_gemm_int8(M, N, K, 0.2f, 0, 1, 0, 0, 1, 0, 1)
  328. || test_gemm_int8(M, N, K, 0.3f, 1, 1, 0, 0, 1, 0, 1)
  329. || test_gemm_int8(M, N, K, 0.4f, 0, 0, 0, 0, 0, 1, 0)
  330. || test_gemm_int8(M, N, K, 0.5f, 0, 1, 0, 0, 0, 1, 0)
  331. || test_gemm_int8(M, N, K, 1.2f, 0, 1, 0, 0, 1, 1, 0)
  332. || test_gemm_int8(M, N, K, 1.3f, 1, 1, 0, 0, 1, 1, 1)
  333. || test_gemm_int8(M, N, K, 1.4f, 0, 0, 0, 0, 1, 1, 0)
  334. || test_gemm_int8(M, N, K, 1.5f, 1, 0, 0, 0, 1, 1, 1)
  335. || test_gemm_int8(M, N, K, -1.2f, 0, 1, 0, 1, 0, 0, 0)
  336. || test_gemm_int8(M, N, K, -1.3f, 1, 1, 0, 1, 0, 0, 0)
  337. || test_gemm_int8(M, N, K, -1.4f, 0, 0, 0, 1, 0, 0, 1)
  338. || test_gemm_int8(M, N, K, -1.5f, 1, 0, 0, 1, 0, 0, 1)
  339. || test_gemm_int8(M, N, K, -2.0f, 0, 1, 0, 1, 1, 0, 1)
  340. || test_gemm_int8(M, N, K, -3.0f, 1, 1, 0, 1, 1, 0, 1)
  341. || test_gemm_int8(M, N, K, -4.0f, 0, 0, 0, 1, 0, 1, 0)
  342. || test_gemm_int8(M, N, K, -5.0f, 0, 1, 0, 1, 0, 1, 0)
  343. || test_gemm_int8(M, N, K, -2.1f, 0, 1, 0, 1, 1, 1, 0)
  344. || test_gemm_int8(M, N, K, -3.1f, 1, 1, 0, 1, 1, 1, 1)
  345. || test_gemm_int8(M, N, K, -4.1f, 0, 0, 0, 1, 1, 1, 0)
  346. || test_gemm_int8(M, N, K, -5.1f, 1, 0, 0, 1, 1, 1, 1)
  347. || test_gemm_int8_fp16s(M, N, K, 1.f, 0, 1, 0, 0, 0, 0, 0)
  348. || test_gemm_int8_fp16s(M, N, K, 1.f, 1, 0, 0, 1, 0, 0, 0);
  349. }
  350. static int test_gemm_1(int M, int N, int K)
  351. {
  352. return 0
  353. || test_gemm_int8_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 0, 0, 0, 0, 0)
  354. || test_gemm_int8_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 1, 1, 0, 0, 0)
  355. || test_gemm_int8_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 2, 0, 0, 0, 0)
  356. || test_gemm_int8_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 3, 1, 0, 0, 0)
  357. || test_gemm_int8_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 0, 0, 0, 0, 0)
  358. || test_gemm_int8_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 1, 1, 0, 0, 0)
  359. || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 5.1f, -0.8f, 1, 1, 2, 0, 0, 0, 0)
  360. || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 5.1f, -0.8f, 1, 1, 3, 1, 0, 0, 0)
  361. || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 0, 0, 0, 0, 0)
  362. || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 1, 1, 0, 0, 0)
  363. || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, -0.5f, 0, 0, 2, 0, 0, 0, 0)
  364. || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, -0.5f, 0, 0, 3, 1, 0, 0, 0)
  365. || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 0.8f, 1.f, 0, 0, 0, 0, 0, 0, 0)
  366. || test_gemm_int8_bias(M, N, K, RandomMat(N), 0.8f, 1.f, 0, 0, 1, 1, 0, 0, 0)
  367. || test_gemm_int8_bias(M, N, K, RandomMat(N), 3.1f, -0.6f, 0, 1, 2, 0, 0, 0, 0)
  368. || test_gemm_int8_bias(M, N, K, RandomMat(N), 3.1f, -0.6f, 0, 1, 3, 1, 0, 0, 0)
  369. || test_gemm_int8_bias(M, N, K, RandomMat(1), -2.1f, 0.5f, 0, 0, 0, 0, 1, 1, 1)
  370. || test_gemm_int8_bias(M, N, K, RandomMat(1), -2.1f, 0.5f, 0, 0, 1, 1, 1, 1, 1)
  371. || test_gemm_int8_bias(M, N, K, RandomMat(M), -3.1f, 0.6f, 0, 1, 2, 0, 1, 1, 1)
  372. || test_gemm_int8_bias(M, N, K, RandomMat(M), -3.1f, 0.6f, 0, 1, 3, 1, 1, 1, 1)
  373. || test_gemm_int8_bias(M, N, K, RandomMat(1, M), -4.1f, 0.7f, 1, 0, 0, 0, 1, 1, 1)
  374. || test_gemm_int8_bias(M, N, K, RandomMat(1, M), -4.1f, 0.7f, 1, 0, 1, 1, 1, 1, 1)
  375. || test_gemm_int8_bias(M, N, K, RandomMat(N, M), -5.1f, -0.8f, 1, 1, 2, 0, 1, 1, 1)
  376. || test_gemm_int8_bias(M, N, K, RandomMat(N, M), -5.1f, -0.8f, 1, 1, 3, 1, 1, 1, 1)
  377. || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 0, 0, 1, 1, 1)
  378. || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 1, 1, 1, 1, 1)
  379. || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 2, 0, 1, 1, 1)
  380. || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 3, 1, 1, 1, 1)
  381. || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 0.8f, 1.f, 0, 0, 0, 0, 1, 1, 1)
  382. || test_gemm_int8_bias(M, N, K, RandomMat(N), 0.8f, 1.f, 0, 0, 1, 1, 1, 1, 1)
  383. || test_gemm_int8_bias(M, N, K, RandomMat(N), -3.1f, -0.6f, 0, 1, 2, 0, 1, 1, 1)
  384. || test_gemm_int8_bias(M, N, K, RandomMat(N), -3.1f, -0.6f, 0, 1, 3, 1, 1, 1, 1);
  385. }
  386. #endif // NCNN_INT8
  387. int main()
  388. {
  389. SRAND(7767517);
  390. #if NCNN_INT8
  391. int mnk[][3] = {
  392. {1, 1, 1},
  393. {1, 1, 23},
  394. {1, 1, 47},
  395. {1, 23, 1},
  396. {1, 23, 23},
  397. {1, 31, 1},
  398. {1, 35, 1},
  399. {1, 35, 47},
  400. {1, 47, 1},
  401. {2, 2, 2},
  402. {3, 3, 3},
  403. {4, 4, 4},
  404. {5, 5, 5},
  405. {6, 6, 6},
  406. {7, 7, 7},
  407. {7, 31, 3},
  408. {8, 8, 8},
  409. {12, 12, 23},
  410. {12, 23, 12},
  411. {12, 31, 12},
  412. {15, 15, 15},
  413. {16, 16, 16},
  414. {19, 44, 7},
  415. {20, 28, 7},
  416. {23, 31, 1},
  417. {23, 31, 23},
  418. {24, 24, 47},
  419. {24, 35, 24},
  420. {24, 47, 24},
  421. {31, 31, 31},
  422. {32, 32, 9},
  423. {35, 47, 48},
  424. {35, 48, 47},
  425. {40, 40, 40},
  426. {47, 48, 47}
  427. };
  428. int mnk_count = sizeof(mnk) / sizeof(int) / 3;
  429. for (int i = 0; i < mnk_count; i++)
  430. {
  431. int M = mnk[i][0];
  432. int N = mnk[i][1];
  433. int K = mnk[i][2];
  434. int ret = test_gemm_0(M, N, K) || test_gemm_1(M, N, K);
  435. if (ret != 0)
  436. return ret;
  437. if (M != N)
  438. {
  439. int ret = test_gemm_0(N, M, K) || test_gemm_1(N, M, K);
  440. if (ret != 0)
  441. return ret;
  442. }
  443. }
  444. #else
  445. // test nothing for non-int8 build
  446. #endif
  447. return 0;
  448. }