You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_gemm.cpp 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "testutil.h"
  15. static int test_gemm(int M, int N, int K, float alpha, int transA, int transB, int output_transpose, int constantA, int constantB, int output_N1M = 0)
  16. {
  17. ncnn::ParamDict pd;
  18. pd.set(0, alpha);
  19. pd.set(1, 1.f); // beta
  20. pd.set(2, transA);
  21. pd.set(3, transB);
  22. pd.set(4, constantA);
  23. pd.set(5, constantB);
  24. pd.set(6, 1);
  25. pd.set(7, M);
  26. pd.set(8, N);
  27. pd.set(9, K);
  28. pd.set(10, -1);
  29. pd.set(11, output_N1M);
  30. pd.set(14, output_transpose);
  31. std::vector<ncnn::Mat> weights;
  32. if (constantA) weights.push_back(transA ? (output_N1M ? ncnn::Mat(M, 1, K) : ncnn::Mat(M, K)) : (output_N1M ? ncnn::Mat(K, 1, M) : ncnn::Mat(K, M)));
  33. if (constantB) weights.push_back(transB ? (output_N1M ? ncnn::Mat(K, 1, N) : ncnn::Mat(K, N)) : (output_N1M ? ncnn::Mat(N, 1, K) : ncnn::Mat(N, K)));
  34. std::vector<ncnn::Mat> a;
  35. if (!constantA) a.push_back(transA ? (output_N1M ? ncnn::Mat(M, 1, K) : ncnn::Mat(M, K)) : (output_N1M ? ncnn::Mat(K, 1, M) : ncnn::Mat(K, M)));
  36. if (!constantB) a.push_back(transB ? (output_N1M ? ncnn::Mat(K, 1, N) : ncnn::Mat(K, N)) : (output_N1M ? ncnn::Mat(N, 1, K) : ncnn::Mat(N, K)));
  37. for (size_t i = 0; i < weights.size(); i++)
  38. {
  39. Randomize(weights[i]);
  40. }
  41. for (size_t i = 0; i < a.size(); i++)
  42. {
  43. Randomize(a[i]);
  44. }
  45. int ret = test_layer("Gemm", pd, weights, a);
  46. if (ret != 0)
  47. {
  48. fprintf(stderr, "test_gemm failed M=%d N=%d K=%d alpha=%f transA=%d transB=%d output_transpose=%d constantA=%d constantB=%d output_N1M=%d\n", M, N, K, alpha, transA, transB, output_transpose, constantA, constantB, output_N1M);
  49. }
  50. return ret;
  51. }
  52. static int test_gemm_bias(int M, int N, int K, const ncnn::Mat& C, float alpha, float beta, int transA, int transB, int output_transpose, int constantA, int constantB, int constantC)
  53. {
  54. int broadcast_type_C = 0;
  55. if (C.dims == 1 && C.w == 1)
  56. {
  57. // scalar
  58. broadcast_type_C = 0;
  59. }
  60. if (C.dims == 1 && C.w == M)
  61. {
  62. // M
  63. // auto broadcast from h to w is the ncnn-style convention
  64. broadcast_type_C = 1;
  65. }
  66. if (C.dims == 1 && C.w == N)
  67. {
  68. // N
  69. broadcast_type_C = 4;
  70. }
  71. if (C.dims == 2 && C.w == 1 && C.h == M)
  72. {
  73. // Mx1
  74. broadcast_type_C = 2;
  75. }
  76. if (C.dims == 2 && C.w == N && C.h == M)
  77. {
  78. // MxN
  79. broadcast_type_C = 3;
  80. }
  81. if (C.dims == 2 && C.w == N && C.h == 1)
  82. {
  83. // 1xN
  84. broadcast_type_C = 4;
  85. }
  86. ncnn::ParamDict pd;
  87. pd.set(0, alpha);
  88. pd.set(1, beta);
  89. pd.set(2, transA);
  90. pd.set(3, transB);
  91. pd.set(4, constantA);
  92. pd.set(5, constantB);
  93. pd.set(6, constantC);
  94. pd.set(7, M);
  95. pd.set(8, N);
  96. pd.set(9, K);
  97. pd.set(10, broadcast_type_C);
  98. pd.set(14, output_transpose);
  99. std::vector<ncnn::Mat> weights;
  100. if (constantA) weights.push_back(transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M));
  101. if (constantB) weights.push_back(transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K));
  102. if (constantC) weights.push_back(C);
  103. std::vector<ncnn::Mat> a;
  104. if (!constantA) a.push_back(transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M));
  105. if (!constantB) a.push_back(transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K));
  106. if (!constantC) a.push_back(C);
  107. for (size_t i = 0; i < weights.size(); i++)
  108. {
  109. Randomize(weights[i]);
  110. }
  111. for (size_t i = 0; i < a.size(); i++)
  112. {
  113. Randomize(a[i]);
  114. }
  115. int ret = test_layer("Gemm", pd, weights, a);
  116. if (ret != 0)
  117. {
  118. fprintf(stderr, "test_gemm_bias failed M=%d N=%d K=%d C.dims=%d C=(%d %d %d) alpha=%f beta=%f transA=%d transB=%d output_transpose=%d constantA=%d constantB=%d constantC=%d\n", M, N, K, C.dims, C.w, C.h, C.c, alpha, beta, transA, transB, output_transpose, constantA, constantB, constantC);
  119. }
  120. return ret;
  121. }
  122. static int test_gemm_0(int M, int N, int K)
  123. {
  124. return 0
  125. || test_gemm(M, N, K, 2.1f, 0, 0, 0, 0, 0)
  126. || test_gemm(M, N, K, 3.1f, 0, 1, 0, 0, 0)
  127. || test_gemm(M, N, K, 4.1f, 1, 0, 0, 0, 0)
  128. || test_gemm(M, N, K, 5.1f, 1, 1, 0, 0, 0)
  129. || test_gemm(M, N, K, 2.1f, 0, 0, 1, 0, 0)
  130. || test_gemm(M, N, K, 3.1f, 0, 1, 1, 0, 0)
  131. || test_gemm(M, N, K, 4.1f, 1, 0, 1, 0, 0)
  132. || test_gemm(M, N, K, 5.1f, 1, 1, 1, 0, 0)
  133. || test_gemm(M, N, K, 1.7f, 0, 1, 0, 0, 0, 1)
  134. || test_gemm(M, N, K, 1.7f, 1, 1, 0, 1, 0, 1)
  135. || test_gemm(M, N, K, 1.9f, 0, 0, 0, 0, 1, 1)
  136. || test_gemm(M, N, K, 1.9f, 1, 0, 0, 1, 1, 1)
  137. || test_gemm(M, N, K, 1.7f, 0, 1, 1, 1, 0, 1)
  138. || test_gemm(M, N, K, 1.7f, 1, 1, 1, 0, 1, 1)
  139. || test_gemm(M, N, K, 1.9f, 0, 0, 1, 1, 1, 1)
  140. || test_gemm(M, N, K, 1.9f, 1, 0, 1, 0, 0, 1)
  141. || test_gemm(M, N, K, 2.1f, 0, 0, 0, 1, 0)
  142. || test_gemm(M, N, K, 3.1f, 0, 1, 0, 1, 0)
  143. || test_gemm(M, N, K, 4.1f, 1, 0, 0, 1, 0)
  144. || test_gemm(M, N, K, 5.1f, 1, 1, 0, 1, 0)
  145. || test_gemm(M, N, K, 2.1f, 0, 0, 1, 1, 0)
  146. || test_gemm(M, N, K, 3.1f, 0, 1, 1, 1, 0)
  147. || test_gemm(M, N, K, 4.1f, 1, 0, 1, 1, 0)
  148. || test_gemm(M, N, K, 5.1f, 1, 1, 1, 1, 0)
  149. || test_gemm(M, N, K, 2.1f, 0, 0, 0, 0, 1)
  150. || test_gemm(M, N, K, 3.1f, 0, 1, 0, 0, 1)
  151. || test_gemm(M, N, K, 4.1f, 1, 0, 0, 0, 1)
  152. || test_gemm(M, N, K, 5.1f, 1, 1, 0, 0, 1)
  153. || test_gemm(M, N, K, 2.1f, 0, 0, 1, 0, 1)
  154. || test_gemm(M, N, K, 3.1f, 0, 1, 1, 0, 1)
  155. || test_gemm(M, N, K, 4.1f, 1, 0, 1, 0, 1)
  156. || test_gemm(M, N, K, 5.1f, 1, 1, 1, 0, 1)
  157. || test_gemm(M, N, K, 2.1f, 0, 0, 0, 1, 1)
  158. || test_gemm(M, N, K, 3.1f, 0, 1, 0, 1, 1)
  159. || test_gemm(M, N, K, 4.1f, 1, 0, 0, 1, 1)
  160. || test_gemm(M, N, K, 5.1f, 1, 1, 0, 1, 1)
  161. || test_gemm(M, N, K, 2.1f, 0, 0, 1, 1, 1)
  162. || test_gemm(M, N, K, 3.1f, 0, 1, 1, 1, 1)
  163. || test_gemm(M, N, K, 4.1f, 1, 0, 1, 1, 1)
  164. || test_gemm(M, N, K, 5.1f, 1, 1, 1, 1, 1);
  165. }
  166. static int test_gemm_1(int M, int N, int K)
  167. {
  168. return 0
  169. || test_gemm_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 0, 0, 0, 0)
  170. || test_gemm_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 0, 0, 0, 0)
  171. || test_gemm_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 1, 0, 0, 0)
  172. || test_gemm_bias(M, N, K, RandomMat(N, M), 5.1f, 0.8f, 1, 1, 1, 0, 0, 0)
  173. || test_gemm_bias(M, N, K, RandomMat(N, 1), 2.1f, 0.5f, 0, 0, 0, 0, 0, 0)
  174. || test_gemm_bias(M, N, K, RandomMat(N), 3.1f, 0.6f, 0, 1, 0, 0, 0, 0)
  175. || test_gemm_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 0, 1, 0, 0)
  176. || test_gemm_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 0, 1, 0, 0)
  177. || test_gemm_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 1, 1, 0, 0)
  178. || test_gemm_bias(M, N, K, RandomMat(N, M), 5.1f, 0.8f, 1, 1, 1, 1, 0, 0)
  179. || test_gemm_bias(M, N, K, RandomMat(N, 1), 2.1f, 0.5f, 0, 0, 0, 1, 0, 0)
  180. || test_gemm_bias(M, N, K, RandomMat(N), 3.1f, 0.6f, 0, 1, 0, 1, 0, 0)
  181. || test_gemm_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 0, 0, 1, 0)
  182. || test_gemm_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 0, 0, 1, 0)
  183. || test_gemm_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 1, 0, 1, 0)
  184. || test_gemm_bias(M, N, K, RandomMat(N, M), 5.1f, 0.8f, 1, 1, 1, 0, 1, 0)
  185. || test_gemm_bias(M, N, K, RandomMat(N, 1), 2.1f, 0.5f, 0, 0, 0, 0, 1, 0)
  186. || test_gemm_bias(M, N, K, RandomMat(N), 3.1f, 0.6f, 0, 1, 0, 0, 1, 0)
  187. || test_gemm_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 0, 1, 1, 0)
  188. || test_gemm_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 0, 1, 1, 0)
  189. || test_gemm_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 1, 1, 1, 0)
  190. || test_gemm_bias(M, N, K, RandomMat(N, M), 5.1f, 0.8f, 1, 1, 1, 1, 1, 0)
  191. || test_gemm_bias(M, N, K, RandomMat(N, 1), 2.1f, 0.5f, 0, 0, 0, 1, 1, 0)
  192. || test_gemm_bias(M, N, K, RandomMat(N), 3.1f, 0.6f, 0, 1, 0, 1, 1, 0)
  193. || test_gemm_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 0, 0, 0, 1)
  194. || test_gemm_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 0, 0, 0, 1)
  195. || test_gemm_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 1, 0, 0, 1)
  196. || test_gemm_bias(M, N, K, RandomMat(N, M), 5.1f, 0.8f, 1, 1, 1, 0, 0, 1)
  197. || test_gemm_bias(M, N, K, RandomMat(N, 1), 2.1f, 0.5f, 0, 0, 0, 0, 0, 1)
  198. || test_gemm_bias(M, N, K, RandomMat(N), 3.1f, 0.6f, 0, 1, 0, 0, 0, 1)
  199. || test_gemm_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 0, 1, 0, 1)
  200. || test_gemm_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 0, 1, 0, 1)
  201. || test_gemm_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 1, 1, 0, 1)
  202. || test_gemm_bias(M, N, K, RandomMat(N, M), 5.1f, 0.8f, 1, 1, 1, 1, 0, 1)
  203. || test_gemm_bias(M, N, K, RandomMat(N, 1), 2.1f, 0.5f, 0, 0, 0, 1, 0, 1)
  204. || test_gemm_bias(M, N, K, RandomMat(N), 3.1f, 0.6f, 0, 1, 0, 1, 0, 1)
  205. || test_gemm_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 0, 0, 1, 1)
  206. || test_gemm_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 0, 0, 1, 1)
  207. || test_gemm_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 1, 0, 1, 1)
  208. || test_gemm_bias(M, N, K, RandomMat(N, M), 5.1f, 0.8f, 1, 1, 1, 0, 1, 1)
  209. || test_gemm_bias(M, N, K, RandomMat(N, 1), 2.1f, 0.5f, 0, 0, 0, 0, 1, 1)
  210. || test_gemm_bias(M, N, K, RandomMat(N), 3.1f, 0.6f, 0, 1, 0, 0, 1, 1)
  211. || test_gemm_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 0, 1, 1, 1)
  212. || test_gemm_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 0, 1, 1, 1)
  213. || test_gemm_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 1, 1, 1, 1)
  214. || test_gemm_bias(M, N, K, RandomMat(N, M), 5.1f, 0.8f, 1, 1, 1, 1, 1, 1)
  215. || test_gemm_bias(M, N, K, RandomMat(N, 1), 2.1f, 0.5f, 0, 0, 0, 1, 1, 1)
  216. || test_gemm_bias(M, N, K, RandomMat(N), 3.1f, 0.6f, 0, 1, 0, 1, 1, 1);
  217. }
  218. int main()
  219. {
  220. SRAND(7767517);
  221. int mnk[][3] = {
  222. {1, 1, 1},
  223. {2, 2, 2},
  224. {3, 3, 3},
  225. {4, 4, 4},
  226. {5, 5, 5},
  227. {6, 6, 6},
  228. {7, 7, 7},
  229. {8, 8, 8},
  230. {15, 15, 15},
  231. {16, 16, 16},
  232. {31, 31, 31},
  233. {40, 40, 40},
  234. {1, 1, 23},
  235. {1, 31, 1},
  236. {23, 1, 1},
  237. {12, 12, 23},
  238. {12, 31, 12},
  239. {23, 12, 12},
  240. {1, 1, 47},
  241. {1, 35, 1},
  242. {47, 1, 1},
  243. {24, 24, 47},
  244. {24, 35, 24},
  245. {47, 24, 24},
  246. {1, 35, 47},
  247. {23, 31, 1},
  248. {23, 1, 23},
  249. {23, 31, 23},
  250. {31, 7, 3},
  251. {28, 20, 7},
  252. {32, 32, 9},
  253. {44, 19, 7},
  254. {47, 35, 48},
  255. {47, 48, 47},
  256. {48, 35, 47}
  257. };
  258. int mnk_count = sizeof(mnk) / sizeof(int) / 3;
  259. for (int i = 0; i < mnk_count; i++)
  260. {
  261. int M = mnk[i][0];
  262. int N = mnk[i][1];
  263. int K = mnk[i][2];
  264. int ret = 0
  265. || test_gemm_0(M, N, K)
  266. || test_gemm_1(M, N, K);
  267. if (ret != 0)
  268. return ret;
  269. }
  270. return 0;
  271. }