You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_gemm.cpp 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "layer/gemm.h"
  15. #include "testutil.h"
  16. static int test_gemm(int M, int N, int K, float alpha, int transA, int transB, int output_N1M = 0)
  17. {
  18. ncnn::ParamDict pd;
  19. pd.set(0, alpha);
  20. pd.set(1, 1.f); // beta
  21. pd.set(2, transA);
  22. pd.set(3, transB);
  23. pd.set(11, output_N1M);
  24. std::vector<ncnn::Mat> weights(0);
  25. std::vector<ncnn::Mat> a(2);
  26. if (output_N1M)
  27. {
  28. a[0] = transA ? ncnn::Mat(M, 1, K) : ncnn::Mat(K, 1, M);
  29. a[1] = transB ? ncnn::Mat(K, 1, N) : ncnn::Mat(N, 1, K);
  30. }
  31. else
  32. {
  33. a[0] = transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M);
  34. a[1] = transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K);
  35. }
  36. Randomize(a[0]);
  37. Randomize(a[1]);
  38. int ret = test_layer<ncnn::Gemm>("Gemm", pd, weights, a);
  39. if (ret != 0)
  40. {
  41. fprintf(stderr, "test_gemm failed M=%d N=%d K=%d alpha=%f transA=%d transB=%d output_N1M=%d\n", M, N, K, alpha, transA, transB, output_N1M);
  42. }
  43. return ret;
  44. }
  45. static int test_gemm_constantA(int M, int N, int K, float alpha, int transA, int transB)
  46. {
  47. ncnn::ParamDict pd;
  48. pd.set(0, alpha);
  49. pd.set(1, 1.f); // beta
  50. pd.set(2, transA);
  51. pd.set(3, transB);
  52. pd.set(4, 1);
  53. pd.set(5, 0);
  54. pd.set(6, 1);
  55. pd.set(7, M);
  56. pd.set(8, N);
  57. pd.set(9, K);
  58. pd.set(10, -1);
  59. std::vector<ncnn::Mat> weights(1);
  60. weights[0] = transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M);
  61. ncnn::Mat B = transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K);
  62. Randomize(weights[0]);
  63. Randomize(B);
  64. int ret = test_layer<ncnn::Gemm>("Gemm", pd, weights, B);
  65. if (ret != 0)
  66. {
  67. fprintf(stderr, "test_gemm_constantA failed M=%d N=%d K=%d alpha=%f transA=%d transB=%d\n", M, N, K, alpha, transA, transB);
  68. }
  69. return ret;
  70. }
  71. static int test_gemm_constantB(int M, int N, int K, float alpha, int transA, int transB)
  72. {
  73. ncnn::ParamDict pd;
  74. pd.set(0, alpha);
  75. pd.set(1, 1.f); // beta
  76. pd.set(2, transA);
  77. pd.set(3, transB);
  78. pd.set(4, 0);
  79. pd.set(5, 1);
  80. pd.set(6, 1);
  81. pd.set(7, M);
  82. pd.set(8, N);
  83. pd.set(9, K);
  84. pd.set(10, -1);
  85. std::vector<ncnn::Mat> weights(1);
  86. weights[0] = transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K);
  87. ncnn::Mat A = transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M);
  88. Randomize(weights[0]);
  89. Randomize(A);
  90. int ret = test_layer<ncnn::Gemm>("Gemm", pd, weights, A);
  91. if (ret != 0)
  92. {
  93. fprintf(stderr, "test_gemm_constantB failed M=%d N=%d K=%d alpha=%f transA=%d transB=%d\n", M, N, K, alpha, transA, transB);
  94. }
  95. return ret;
  96. }
  97. static int test_gemm_constantAB(int M, int N, int K, float alpha, int transA, int transB)
  98. {
  99. ncnn::ParamDict pd;
  100. pd.set(0, alpha);
  101. pd.set(1, 1.f); // beta
  102. pd.set(2, transA);
  103. pd.set(3, transB);
  104. pd.set(4, 1);
  105. pd.set(5, 1);
  106. pd.set(6, 1);
  107. pd.set(7, M);
  108. pd.set(8, N);
  109. pd.set(9, K);
  110. pd.set(10, -1);
  111. std::vector<ncnn::Mat> weights(2);
  112. weights[0] = transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M);
  113. weights[1] = transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K);
  114. std::vector<ncnn::Mat> a(0);
  115. Randomize(weights[0]);
  116. Randomize(weights[1]);
  117. int ret = test_layer<ncnn::Gemm>("Gemm", pd, weights, a);
  118. if (ret != 0)
  119. {
  120. fprintf(stderr, "test_gemm_constantAB failed M=%d N=%d K=%d alpha=%f transA=%d transB=%d\n", M, N, K, alpha, transA, transB);
  121. }
  122. return ret;
  123. }
  124. static int test_gemm_bias(int M, int N, int K, const ncnn::Mat& C, float alpha, float beta, int transA, int transB)
  125. {
  126. ncnn::ParamDict pd;
  127. pd.set(0, alpha);
  128. pd.set(1, beta);
  129. pd.set(2, transA);
  130. pd.set(3, transB);
  131. std::vector<ncnn::Mat> weights(0);
  132. std::vector<ncnn::Mat> a(3);
  133. a[0] = transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M);
  134. a[1] = transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K);
  135. a[2] = C;
  136. Randomize(a[0]);
  137. Randomize(a[1]);
  138. int ret = test_layer<ncnn::Gemm>("Gemm", pd, weights, a);
  139. if (ret != 0)
  140. {
  141. fprintf(stderr, "test_gemm_bias failed M=%d N=%d K=%d C.dims=%d C=(%d %d %d) alpha=%f beta=%f transA=%d transB=%d\n", M, N, K, C.dims, C.w, C.h, C.c, alpha, beta, transA, transB);
  142. }
  143. return ret;
  144. }
  145. static int test_gemm_constantABC_bias(int M, int N, int K, const ncnn::Mat& C, float alpha, float beta, int transA, int transB)
  146. {
  147. int broadcast_type_C = 0;
  148. if (C.dims == 1 && C.w == 1)
  149. {
  150. // scalar
  151. broadcast_type_C = 0;
  152. }
  153. if (C.dims == 1 && C.w == M)
  154. {
  155. // M
  156. // auto broadcast from h to w is the ncnn-style convention
  157. broadcast_type_C = 1;
  158. }
  159. if (C.dims == 1 && C.w == N)
  160. {
  161. // N
  162. broadcast_type_C = 4;
  163. }
  164. if (C.dims == 2 && C.w == 1 && C.h == M)
  165. {
  166. // Mx1
  167. broadcast_type_C = 2;
  168. }
  169. if (C.dims == 2 && C.w == N && C.h == M)
  170. {
  171. // MxN
  172. broadcast_type_C = 3;
  173. }
  174. if (C.dims == 2 && C.w == N && C.h == 1)
  175. {
  176. // 1xN
  177. broadcast_type_C = 4;
  178. }
  179. ncnn::ParamDict pd;
  180. pd.set(0, alpha);
  181. pd.set(1, 1.f); // beta
  182. pd.set(2, transA);
  183. pd.set(3, transB);
  184. pd.set(4, 1);
  185. pd.set(5, 1);
  186. pd.set(6, 1);
  187. pd.set(7, M);
  188. pd.set(8, N);
  189. pd.set(9, K);
  190. pd.set(10, broadcast_type_C);
  191. std::vector<ncnn::Mat> weights(3);
  192. weights[0] = transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M);
  193. weights[1] = transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K);
  194. weights[2] = C;
  195. std::vector<ncnn::Mat> a(0);
  196. Randomize(weights[0]);
  197. Randomize(weights[1]);
  198. Randomize(weights[2]);
  199. int ret = test_layer<ncnn::Gemm>("Gemm", pd, weights, a);
  200. if (ret != 0)
  201. {
  202. fprintf(stderr, "test_gemm_constantABC_bias failed M=%d N=%d K=%d C.dims=%d C=(%d %d %d) broadcast_type_C=%d alpha=%f beta=%f transA=%d transB=%d\n", M, N, K, C.dims, C.w, C.h, C.c, broadcast_type_C, alpha, beta, transA, transB);
  203. }
  204. return ret;
  205. }
  206. static int test_gemm_constantAB_bias(int M, int N, int K, const ncnn::Mat& C, float alpha, float beta, int transA, int transB)
  207. {
  208. ncnn::ParamDict pd;
  209. pd.set(0, alpha);
  210. pd.set(1, 1.f); // beta
  211. pd.set(2, transA);
  212. pd.set(3, transB);
  213. pd.set(4, 1);
  214. pd.set(5, 1);
  215. pd.set(6, 0);
  216. pd.set(7, M);
  217. pd.set(8, N);
  218. pd.set(9, K);
  219. std::vector<ncnn::Mat> weights(2);
  220. weights[0] = transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M);
  221. weights[1] = transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K);
  222. std::vector<ncnn::Mat> a(1);
  223. a[0] = C;
  224. Randomize(weights[0]);
  225. Randomize(weights[1]);
  226. Randomize(a[0]);
  227. int ret = test_layer<ncnn::Gemm>("Gemm", pd, weights, a);
  228. if (ret != 0)
  229. {
  230. fprintf(stderr, "test_gemm_constantAB_bias failed M=%d N=%d K=%d C.dims=%d C=(%d %d %d) alpha=%f beta=%f transA=%d transB=%d\n", M, N, K, C.dims, C.w, C.h, C.c, alpha, beta, transA, transB);
  231. }
  232. return ret;
  233. }
  234. static int test_gemm_0(int M, int N, int K)
  235. {
  236. return 0
  237. || test_gemm(M, N, K, 2.1f, 0, 0)
  238. || test_gemm(M, N, K, 3.1f, 0, 1)
  239. || test_gemm(M, N, K, 4.1f, 1, 0)
  240. || test_gemm(M, N, K, 5.1f, 1, 1)
  241. || test_gemm(M, N, K, 1.7f, 0, 1, 1)
  242. || test_gemm(M, N, K, 1.7f, 1, 1, 1)
  243. || test_gemm(M, N, K, 1.9f, 0, 0, 1)
  244. || test_gemm(M, N, K, 1.9f, 1, 0, 1)
  245. || test_gemm_constantA(M, N, K, 2.1f, 0, 0)
  246. || test_gemm_constantA(M, N, K, 3.1f, 0, 1)
  247. || test_gemm_constantA(M, N, K, 4.1f, 1, 0)
  248. || test_gemm_constantA(M, N, K, 5.1f, 1, 1)
  249. || test_gemm_constantB(M, N, K, 2.1f, 0, 0)
  250. || test_gemm_constantB(M, N, K, 3.1f, 0, 1)
  251. || test_gemm_constantB(M, N, K, 4.1f, 1, 0)
  252. || test_gemm_constantB(M, N, K, 5.1f, 1, 1)
  253. || test_gemm_constantAB(M, N, K, 2.1f, 0, 0)
  254. || test_gemm_constantAB(M, N, K, 3.1f, 0, 1)
  255. || test_gemm_constantAB(M, N, K, 4.1f, 1, 0)
  256. || test_gemm_constantAB(M, N, K, 5.1f, 1, 1);
  257. }
  258. static int test_gemm_1(int M, int N, int K)
  259. {
  260. return 0
  261. || test_gemm_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0)
  262. || test_gemm_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1)
  263. || test_gemm_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0)
  264. || test_gemm_bias(M, N, K, RandomMat(N, M), 5.1f, 0.8f, 1, 1)
  265. || test_gemm_bias(M, N, K, RandomMat(N, 1), 2.1f, 0.5f, 0, 0)
  266. || test_gemm_bias(M, N, K, RandomMat(N), 3.1f, 0.6f, 0, 1);
  267. }
  268. static int test_gemm_2(int M, int N, int K)
  269. {
  270. return 0
  271. || test_gemm_constantABC_bias(M, N, K, RandomMat(1), 4.1f, 0.7f, 1, 0)
  272. || test_gemm_constantABC_bias(M, N, K, RandomMat(M), 5.1f, 0.8f, 1, 1)
  273. || test_gemm_constantABC_bias(M, N, K, RandomMat(1, M), 2.1f, 0.5f, 0, 0)
  274. || test_gemm_constantABC_bias(M, N, K, RandomMat(N, M), 3.1f, 0.6f, 0, 1)
  275. || test_gemm_constantABC_bias(M, N, K, RandomMat(N, 1), 4.1f, 0.7f, 1, 0)
  276. || test_gemm_constantABC_bias(M, N, K, RandomMat(N), 5.1f, 0.8f, 1, 1);
  277. }
  278. static int test_gemm_3(int M, int N, int K)
  279. {
  280. return 0
  281. || test_gemm_constantAB_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0)
  282. || test_gemm_constantAB_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1)
  283. || test_gemm_constantAB_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0)
  284. || test_gemm_constantAB_bias(M, N, K, RandomMat(N, M), 5.1f, 0.8f, 1, 1)
  285. || test_gemm_constantAB_bias(M, N, K, RandomMat(N, 1), 2.1f, 0.5f, 0, 0)
  286. || test_gemm_constantAB_bias(M, N, K, RandomMat(N), 3.1f, 0.6f, 0, 1);
  287. }
  288. int main()
  289. {
  290. SRAND(7767517);
  291. int mnk[][3] = {
  292. {1, 1, 1},
  293. {2, 2, 2},
  294. {3, 3, 3},
  295. {4, 4, 4},
  296. {5, 5, 5},
  297. {6, 6, 6},
  298. {7, 7, 7},
  299. {8, 8, 8},
  300. {15, 15, 15},
  301. {16, 16, 16},
  302. {31, 31, 31},
  303. {40, 40, 40},
  304. {1, 1, 23},
  305. {1, 31, 1},
  306. {23, 1, 1},
  307. {12, 12, 23},
  308. {12, 31, 12},
  309. {23, 12, 12},
  310. {1, 1, 47},
  311. {1, 35, 1},
  312. {47, 1, 1},
  313. {24, 24, 47},
  314. {24, 35, 24},
  315. {47, 24, 24},
  316. {1, 35, 47},
  317. {23, 31, 1},
  318. {23, 1, 23},
  319. {23, 31, 23},
  320. {31, 7, 3},
  321. {28, 20, 7},
  322. {32, 32, 9},
  323. {44, 19, 7},
  324. {47, 35, 48},
  325. {47, 48, 47},
  326. {48, 35, 47},
  327. {25, 25, 527},
  328. {30, 30, 527},
  329. {28, 28, 527},
  330. {40, 40, 527},
  331. {64, 64, 527}
  332. };
  333. int mnk_count = sizeof(mnk) / sizeof(int) / 3;
  334. for (int i = 0; i < mnk_count; i++)
  335. {
  336. int M = mnk[i][0];
  337. int N = mnk[i][1];
  338. int K = mnk[i][2];
  339. int ret = 0
  340. || test_gemm_0(M, N, K)
  341. || test_gemm_1(M, N, K)
  342. || test_gemm_2(M, N, K)
  343. || test_gemm_3(M, N, K);
  344. if (ret != 0)
  345. return 0;
  346. }
  347. return 0;
  348. }