You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_gemm_3.cpp 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "testutil.h"
  15. #if NCNN_INT8
  16. static int test_gemm_int8(int M, int N, int K, float alpha, int transA, int transB, int output_elemtype, int output_transpose, int constantA, int constantB, int output_N1M)
  17. {
  18. ncnn::ParamDict pd;
  19. pd.set(0, alpha);
  20. pd.set(1, 1.f); // beta
  21. pd.set(2, transA);
  22. pd.set(3, transB);
  23. pd.set(4, constantA);
  24. pd.set(5, constantB);
  25. pd.set(6, 1);
  26. pd.set(7, M);
  27. pd.set(8, N);
  28. pd.set(9, K);
  29. pd.set(10, -1);
  30. pd.set(11, output_N1M);
  31. pd.set(13, output_elemtype);
  32. pd.set(14, output_transpose);
  33. pd.set(18, 2); // int8_scale_term
  34. std::vector<ncnn::Mat> weights;
  35. if (constantA) weights.push_back(transA ? (output_N1M ? RandomS8Mat(M, 1, K) : RandomS8Mat(M, K)) : (output_N1M ? RandomS8Mat(K, 1, M) : RandomS8Mat(K, M)));
  36. if (constantB) weights.push_back(transB ? (output_N1M ? RandomS8Mat(K, 1, N) : RandomS8Mat(K, N)) : (output_N1M ? RandomS8Mat(N, 1, K) : RandomS8Mat(N, K)));
  37. if (constantA) weights.push_back(RandomMat(M, 10.f, 20.f));
  38. if (constantB) weights.push_back(RandomMat(1, 10.f, 20.f));
  39. std::vector<ncnn::Mat> a;
  40. if (!constantA) a.push_back(transA ? (output_N1M ? ncnn::Mat(M, 1, K) : ncnn::Mat(M, K)) : (output_N1M ? ncnn::Mat(K, 1, M) : ncnn::Mat(K, M)));
  41. if (!constantB) a.push_back(transB ? (output_N1M ? ncnn::Mat(K, 1, N) : ncnn::Mat(K, N)) : (output_N1M ? ncnn::Mat(N, 1, K) : ncnn::Mat(N, K)));
  42. for (size_t i = 0; i < a.size(); i++)
  43. {
  44. Randomize(a[i], -10.f, 10.f);
  45. }
  46. int ret = test_layer("Gemm", pd, weights, a);
  47. if (ret != 0)
  48. {
  49. fprintf(stderr, "test_gemm_int8 failed M=%d N=%d K=%d alpha=%f transA=%d transB=%d output_elemtype=%d output_transpose=%d constantA=%d constantB=%d output_N1M=%d\n", M, N, K, alpha, transA, transB, output_elemtype, output_transpose, constantA, constantB, output_N1M);
  50. }
  51. return ret;
  52. }
  53. static int test_gemm_int8_bias(int M, int N, int K, const ncnn::Mat& C, float alpha, float beta, int transA, int transB, int output_elemtype, int output_transpose, int constantA, int constantB, int constantC)
  54. {
  55. int broadcast_type_C = 0;
  56. if (C.dims == 1 && C.w == 1)
  57. {
  58. // scalar
  59. broadcast_type_C = 0;
  60. }
  61. if (C.dims == 1 && C.w == M)
  62. {
  63. // M
  64. // auto broadcast from h to w is the ncnn-style convention
  65. broadcast_type_C = 1;
  66. }
  67. if (C.dims == 1 && C.w == N)
  68. {
  69. // N
  70. broadcast_type_C = 4;
  71. }
  72. if (C.dims == 2 && C.w == 1 && C.h == M)
  73. {
  74. // Mx1
  75. broadcast_type_C = 2;
  76. }
  77. if (C.dims == 2 && C.w == N && C.h == M)
  78. {
  79. // MxN
  80. broadcast_type_C = 3;
  81. }
  82. if (C.dims == 2 && C.w == N && C.h == 1)
  83. {
  84. // 1xN
  85. broadcast_type_C = 4;
  86. }
  87. ncnn::ParamDict pd;
  88. pd.set(0, alpha);
  89. pd.set(1, beta);
  90. pd.set(2, transA);
  91. pd.set(3, transB);
  92. pd.set(4, constantA);
  93. pd.set(5, constantB);
  94. pd.set(6, constantC);
  95. pd.set(7, M);
  96. pd.set(8, N);
  97. pd.set(9, K);
  98. pd.set(10, broadcast_type_C);
  99. // pd.set(12, 1); // output_elempack
  100. pd.set(13, output_elemtype);
  101. pd.set(14, output_transpose);
  102. pd.set(18, 2); // int8_scale_term
  103. std::vector<ncnn::Mat> weights;
  104. if (constantA) weights.push_back(transA ? RandomS8Mat(M, K) : RandomS8Mat(K, M));
  105. if (constantB) weights.push_back(transB ? RandomS8Mat(K, N) : RandomS8Mat(N, K));
  106. if (constantC) weights.push_back(C);
  107. if (constantA) weights.push_back(RandomMat(M, 10.f, 20.f));
  108. if (constantB) weights.push_back(RandomMat(1, 10.f, 20.f));
  109. std::vector<ncnn::Mat> a;
  110. if (!constantA) a.push_back(transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M));
  111. if (!constantB) a.push_back(transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K));
  112. if (!constantC) a.push_back(C);
  113. for (size_t i = 0; i < a.size(); i++)
  114. {
  115. Randomize(a[i], -10.f, 10.f);
  116. }
  117. int ret = test_layer("Gemm", pd, weights, a);
  118. if (ret != 0)
  119. {
  120. fprintf(stderr, "test_gemm_int8_bias failed M=%d N=%d K=%d C.dims=%d C=(%d %d %d) alpha=%f beta=%f transA=%d transB=%d output_elemtype=%d output_transpose=%d constantA=%d constantB=%d constantC=%d\n", M, N, K, C.dims, C.w, C.h, C.c, alpha, beta, transA, transB, output_elemtype, output_transpose, constantA, constantB, constantC);
  121. }
  122. return ret;
  123. }
  124. static int test_gemm_int8_fp16s(int M, int N, int K, float alpha, int transA, int transB, int output_elemtype, int output_transpose, int constantA, int constantB, int output_N1M)
  125. {
  126. ncnn::ParamDict pd;
  127. pd.set(0, alpha);
  128. pd.set(1, 1.f); // beta
  129. pd.set(2, transA);
  130. pd.set(3, transB);
  131. pd.set(4, constantA);
  132. pd.set(5, constantB);
  133. pd.set(6, 1);
  134. pd.set(7, M);
  135. pd.set(8, N);
  136. pd.set(9, K);
  137. pd.set(10, -1);
  138. pd.set(11, output_N1M);
  139. pd.set(13, output_elemtype);
  140. pd.set(14, output_transpose);
  141. pd.set(18, 2); // int8_scale_term
  142. std::vector<ncnn::Mat> weights;
  143. if (constantA) weights.push_back(transA ? (output_N1M ? RandomS8Mat(M, 1, K) : RandomS8Mat(M, K)) : (output_N1M ? RandomS8Mat(K, 1, M) : RandomS8Mat(K, M)));
  144. if (constantB) weights.push_back(transB ? (output_N1M ? RandomS8Mat(K, 1, N) : RandomS8Mat(K, N)) : (output_N1M ? RandomS8Mat(N, 1, K) : RandomS8Mat(N, K)));
  145. if (constantA) weights.push_back(RandomMat(M, 10.f, 20.f));
  146. if (constantB) weights.push_back(RandomMat(1, 10.f, 20.f));
  147. std::vector<ncnn::Mat> a;
  148. if (!constantA) a.push_back(transA ? (output_N1M ? ncnn::Mat(M, 1, K) : ncnn::Mat(M, K)) : (output_N1M ? ncnn::Mat(K, 1, M) : ncnn::Mat(K, M)));
  149. if (!constantB) a.push_back(transB ? (output_N1M ? ncnn::Mat(K, 1, N) : ncnn::Mat(K, N)) : (output_N1M ? ncnn::Mat(N, 1, K) : ncnn::Mat(N, K)));
  150. for (size_t i = 0; i < a.size(); i++)
  151. {
  152. Randomize(a[i], -10.f, 10.f);
  153. }
  154. ncnn::Option opt;
  155. opt.num_threads = 1;
  156. opt.use_packing_layout = true;
  157. opt.use_fp16_packed = false;
  158. opt.use_fp16_storage = true;
  159. opt.use_fp16_arithmetic = false;
  160. opt.use_bf16_storage = false;
  161. float epsilon = 0.001;
  162. int ret = test_layer_opt("Gemm", pd, weights, opt, a, 1, epsilon);
  163. if (ret != 0)
  164. {
  165. fprintf(stderr, "test_gemm_int8_fp16s failed M=%d N=%d K=%d alpha=%f transA=%d transB=%d output_elemtype=%d output_transpose=%d constantA=%d constantB=%d output_N1M=%d\n", M, N, K, alpha, transA, transB, output_elemtype, output_transpose, constantA, constantB, output_N1M);
  166. return ret;
  167. }
  168. return 0;
  169. }
  170. static int test_gemm_0(int M, int N, int K)
  171. {
  172. return 0
  173. || test_gemm_int8(M, N, K, 2.1f, 0, 1, 0, 0, 0, 0, 0)
  174. || test_gemm_int8(M, N, K, 3.1f, 1, 1, 0, 0, 0, 0, 0)
  175. || test_gemm_int8(M, N, K, 4.1f, 0, 0, 0, 0, 0, 0, 1)
  176. || test_gemm_int8(M, N, K, 5.1f, 1, 0, 0, 0, 0, 0, 1)
  177. || test_gemm_int8(M, N, K, 0.2f, 0, 1, 0, 0, 1, 0, 1)
  178. || test_gemm_int8(M, N, K, 0.3f, 1, 1, 0, 0, 1, 0, 1)
  179. || test_gemm_int8(M, N, K, 0.4f, 0, 0, 0, 0, 0, 1, 0)
  180. || test_gemm_int8(M, N, K, 0.5f, 0, 1, 0, 0, 0, 1, 0)
  181. || test_gemm_int8(M, N, K, 1.2f, 0, 1, 0, 0, 1, 1, 0)
  182. || test_gemm_int8(M, N, K, 1.3f, 1, 1, 0, 0, 1, 1, 1)
  183. || test_gemm_int8(M, N, K, 1.4f, 0, 0, 0, 0, 1, 1, 0)
  184. || test_gemm_int8(M, N, K, 1.5f, 1, 0, 0, 0, 1, 1, 1)
  185. || test_gemm_int8(M, N, K, -1.2f, 0, 1, 0, 1, 0, 0, 0)
  186. || test_gemm_int8(M, N, K, -1.3f, 1, 1, 0, 1, 0, 0, 0)
  187. || test_gemm_int8(M, N, K, -1.4f, 0, 0, 0, 1, 0, 0, 1)
  188. || test_gemm_int8(M, N, K, -1.5f, 1, 0, 0, 1, 0, 0, 1)
  189. || test_gemm_int8(M, N, K, -2.0f, 0, 1, 0, 1, 1, 0, 1)
  190. || test_gemm_int8(M, N, K, -3.0f, 1, 1, 0, 1, 1, 0, 1)
  191. || test_gemm_int8(M, N, K, -4.0f, 0, 0, 0, 1, 0, 1, 0)
  192. || test_gemm_int8(M, N, K, -5.0f, 0, 1, 0, 1, 0, 1, 0)
  193. || test_gemm_int8(M, N, K, -2.1f, 0, 1, 0, 1, 1, 1, 0)
  194. || test_gemm_int8(M, N, K, -3.1f, 1, 1, 0, 1, 1, 1, 1)
  195. || test_gemm_int8(M, N, K, -4.1f, 0, 0, 0, 1, 1, 1, 0)
  196. || test_gemm_int8(M, N, K, -5.1f, 1, 0, 0, 1, 1, 1, 1)
  197. || test_gemm_int8_fp16s(M, N, K, 1.f, 0, 1, 0, 0, 0, 0, 0)
  198. || test_gemm_int8_fp16s(M, N, K, 1.f, 1, 0, 0, 1, 0, 0, 0);
  199. }
  200. static int test_gemm_1(int M, int N, int K)
  201. {
  202. return 0
  203. || test_gemm_int8_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 0, 0, 0, 0, 0)
  204. || test_gemm_int8_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 1, 1, 0, 0, 0)
  205. || test_gemm_int8_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 2, 0, 0, 0, 0)
  206. || test_gemm_int8_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 3, 1, 0, 0, 0)
  207. || test_gemm_int8_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 0, 0, 0, 0, 0)
  208. || test_gemm_int8_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 1, 1, 0, 0, 0)
  209. || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 5.1f, -0.8f, 1, 1, 2, 0, 0, 0, 0)
  210. || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 5.1f, -0.8f, 1, 1, 3, 1, 0, 0, 0)
  211. || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 0, 0, 0, 0, 0)
  212. || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 1, 1, 0, 0, 0)
  213. || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, -0.5f, 0, 0, 2, 0, 0, 0, 0)
  214. || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, -0.5f, 0, 0, 3, 1, 0, 0, 0)
  215. || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 0.8f, 1.f, 0, 0, 0, 0, 0, 0, 0)
  216. || test_gemm_int8_bias(M, N, K, RandomMat(N), 0.8f, 1.f, 0, 0, 1, 1, 0, 0, 0)
  217. || test_gemm_int8_bias(M, N, K, RandomMat(N), 3.1f, -0.6f, 0, 1, 2, 0, 0, 0, 0)
  218. || test_gemm_int8_bias(M, N, K, RandomMat(N), 3.1f, -0.6f, 0, 1, 3, 1, 0, 0, 0)
  219. || test_gemm_int8_bias(M, N, K, RandomMat(1), -2.1f, 0.5f, 0, 0, 0, 0, 1, 1, 1)
  220. || test_gemm_int8_bias(M, N, K, RandomMat(1), -2.1f, 0.5f, 0, 0, 1, 1, 1, 1, 1)
  221. || test_gemm_int8_bias(M, N, K, RandomMat(M), -3.1f, 0.6f, 0, 1, 2, 0, 1, 1, 1)
  222. || test_gemm_int8_bias(M, N, K, RandomMat(M), -3.1f, 0.6f, 0, 1, 3, 1, 1, 1, 1)
  223. || test_gemm_int8_bias(M, N, K, RandomMat(1, M), -4.1f, 0.7f, 1, 0, 0, 0, 1, 1, 1)
  224. || test_gemm_int8_bias(M, N, K, RandomMat(1, M), -4.1f, 0.7f, 1, 0, 1, 1, 1, 1, 1)
  225. || test_gemm_int8_bias(M, N, K, RandomMat(N, M), -5.1f, -0.8f, 1, 1, 2, 0, 1, 1, 1)
  226. || test_gemm_int8_bias(M, N, K, RandomMat(N, M), -5.1f, -0.8f, 1, 1, 3, 1, 1, 1, 1)
  227. || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 0, 0, 1, 1, 1)
  228. || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 1, 1, 1, 1, 1)
  229. || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 2, 0, 1, 1, 1)
  230. || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 3, 1, 1, 1, 1)
  231. || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 0.8f, 1.f, 0, 0, 0, 0, 1, 1, 1)
  232. || test_gemm_int8_bias(M, N, K, RandomMat(N), 0.8f, 1.f, 0, 0, 1, 1, 1, 1, 1)
  233. || test_gemm_int8_bias(M, N, K, RandomMat(N), -3.1f, -0.6f, 0, 1, 2, 0, 1, 1, 1)
  234. || test_gemm_int8_bias(M, N, K, RandomMat(N), -3.1f, -0.6f, 0, 1, 3, 1, 1, 1, 1);
  235. }
  236. #endif // NCNN_INT8
  237. int main()
  238. {
  239. SRAND(7767517);
  240. #if NCNN_INT8
  241. int mnk[][3] = {
  242. {1, 1, 1},
  243. {1, 1, 23},
  244. {1, 1, 47},
  245. {1, 23, 1},
  246. {1, 23, 23},
  247. {1, 31, 1},
  248. {1, 35, 1},
  249. {1, 35, 47},
  250. {1, 47, 1},
  251. {2, 2, 2},
  252. {3, 3, 3},
  253. {4, 4, 4},
  254. {5, 5, 5},
  255. {6, 6, 6},
  256. {7, 7, 7},
  257. {7, 31, 3},
  258. {8, 8, 8},
  259. {12, 12, 23},
  260. {12, 23, 12},
  261. {12, 31, 12},
  262. {15, 15, 15},
  263. {16, 16, 16},
  264. {19, 44, 7},
  265. {20, 28, 7},
  266. {23, 31, 1},
  267. {23, 31, 23},
  268. {24, 24, 47},
  269. {24, 35, 24},
  270. {24, 47, 24},
  271. {31, 31, 31},
  272. {32, 32, 9},
  273. {35, 47, 48},
  274. {35, 48, 47},
  275. {40, 40, 40},
  276. {47, 48, 47}
  277. };
  278. int mnk_count = sizeof(mnk) / sizeof(int) / 3;
  279. for (int i = 0; i < mnk_count; i++)
  280. {
  281. int M = mnk[i][0];
  282. int N = mnk[i][1];
  283. int K = mnk[i][2];
  284. int ret = test_gemm_0(M, N, K) || test_gemm_1(M, N, K);
  285. if (ret != 0)
  286. return ret;
  287. if (M != N)
  288. {
  289. int ret = test_gemm_0(N, M, K) || test_gemm_1(N, M, K);
  290. if (ret != 0)
  291. return ret;
  292. }
  293. }
  294. #else
  295. // test nothing for non-int8 build
  296. #endif
  297. return 0;
  298. }