You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_gru.cpp 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569
  1. // Copyright 2021 Tencent
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. #include "testutil.h"
  4. static int test_gru(int size, int T, int outch, int direction)
  5. {
  6. ncnn::Mat a = RandomMat(size, T);
  7. int num_directions = direction == 2 ? 2 : 1;
  8. ncnn::ParamDict pd;
  9. pd.set(0, outch);
  10. pd.set(1, outch * size * 3 * num_directions);
  11. pd.set(2, direction);
  12. std::vector<ncnn::Mat> weights(3);
  13. weights[0] = RandomMat(outch * size * 3 * num_directions);
  14. weights[1] = RandomMat(outch * 4 * num_directions);
  15. weights[2] = RandomMat(outch * outch * 3 * num_directions);
  16. int ret = test_layer("GRU", pd, weights, a);
  17. if (ret != 0)
  18. {
  19. fprintf(stderr, "test_gru failed size=%d T=%d outch=%d direction=%d\n", size, T, outch, direction);
  20. }
  21. return ret;
  22. }
  23. static int test_gru_with_hidden(int size, int T, int outch, int direction)
  24. {
  25. ncnn::Mat a = RandomMat(size, T);
  26. int num_directions = direction == 2 ? 2 : 1;
  27. ncnn::ParamDict pd;
  28. pd.set(0, outch);
  29. pd.set(1, outch * size * 3 * num_directions);
  30. pd.set(2, direction);
  31. std::vector<ncnn::Mat> weights(3);
  32. weights[0] = RandomMat(outch * size * 3 * num_directions);
  33. weights[1] = RandomMat(outch * 4 * num_directions);
  34. weights[2] = RandomMat(outch * outch * 3 * num_directions);
  35. // initial hidden state
  36. ncnn::Mat hidden = RandomMat(outch, num_directions);
  37. std::vector<ncnn::Mat> as(2);
  38. as[0] = a;
  39. as[1] = hidden;
  40. int ret = test_layer("GRU", pd, weights, as, 2);
  41. if (ret != 0)
  42. {
  43. fprintf(stderr, "test_gru_with_hidden failed size=%d T=%d outch=%d direction=%d\n", size, T, outch, direction);
  44. }
  45. return ret;
  46. }
  47. static int test_gru_with_hidden_input(int size, int T, int outch, int direction)
  48. {
  49. ncnn::Mat a = RandomMat(size, T);
  50. int num_directions = direction == 2 ? 2 : 1;
  51. ncnn::ParamDict pd;
  52. pd.set(0, outch);
  53. pd.set(1, outch * size * 3 * num_directions);
  54. pd.set(2, direction);
  55. std::vector<ncnn::Mat> weights(3);
  56. weights[0] = RandomMat(outch * size * 3 * num_directions);
  57. weights[1] = RandomMat(outch * 4 * num_directions);
  58. weights[2] = RandomMat(outch * outch * 3 * num_directions);
  59. // initial hidden state
  60. ncnn::Mat hidden = RandomMat(outch, num_directions);
  61. std::vector<ncnn::Mat> as(2);
  62. as[0] = a;
  63. as[1] = hidden;
  64. int ret = test_layer("GRU", pd, weights, as, 1);
  65. if (ret != 0)
  66. {
  67. fprintf(stderr, "test_gru_with_hidden_input failed size=%d T=%d outch=%d direction=%d\n", size, T, outch, direction);
  68. }
  69. return ret;
  70. }
  71. static int test_gru_with_hidden_output(int size, int T, int outch, int direction)
  72. {
  73. ncnn::Mat a = RandomMat(size, T);
  74. int num_directions = direction == 2 ? 2 : 1;
  75. ncnn::ParamDict pd;
  76. pd.set(0, outch);
  77. pd.set(1, outch * size * 3 * num_directions);
  78. pd.set(2, direction);
  79. std::vector<ncnn::Mat> weights(3);
  80. weights[0] = RandomMat(outch * size * 3 * num_directions);
  81. weights[1] = RandomMat(outch * 4 * num_directions);
  82. weights[2] = RandomMat(outch * outch * 3 * num_directions);
  83. std::vector<ncnn::Mat> as(1);
  84. as[0] = a;
  85. int ret = test_layer("GRU", pd, weights, as, 2);
  86. if (ret != 0)
  87. {
  88. fprintf(stderr, "test_gru_with_hidden_output failed size=%d T=%d outch=%d direction=%d\n", size, T, outch, direction);
  89. }
  90. return ret;
  91. }
  92. static int test_gru_0()
  93. {
  94. return 0
  95. || test_gru(4, 1, 2, 2)
  96. || test_gru(8, 2, 2, 2)
  97. || test_gru(16, 8, 7, 2)
  98. || test_gru(17, 8, 8, 2)
  99. || test_gru(19, 15, 8, 2)
  100. || test_gru(5, 16, 16, 2)
  101. || test_gru(3, 16, 8, 2)
  102. || test_gru(8, 16, 16, 2)
  103. || test_gru(31, 3, 31, 2)
  104. || test_gru(2, 5, 17, 2);
  105. }
  106. static int test_gru_1()
  107. {
  108. return 0
  109. || test_gru_with_hidden(4, 4, 1, 2)
  110. || test_gru_with_hidden(8, 2, 2, 2)
  111. || test_gru_with_hidden(16, 8, 7, 2)
  112. || test_gru_with_hidden(17, 8, 8, 2)
  113. || test_gru_with_hidden(19, 15, 8, 2)
  114. || test_gru_with_hidden(5, 16, 16, 2)
  115. || test_gru_with_hidden(3, 16, 8, 2)
  116. || test_gru_with_hidden(2, 5, 79, 2)
  117. || test_gru_with_hidden(4, 4, 1, 1)
  118. || test_gru_with_hidden(8, 2, 2, 1)
  119. || test_gru_with_hidden(16, 8, 7, 1)
  120. || test_gru_with_hidden(17, 8, 8, 1)
  121. || test_gru_with_hidden(19, 15, 8, 1)
  122. || test_gru_with_hidden(5, 16, 16, 1)
  123. || test_gru_with_hidden(3, 16, 8, 1)
  124. || test_gru_with_hidden(2, 5, 79, 1)
  125. || test_gru_with_hidden(4, 2, 1, 0)
  126. || test_gru_with_hidden(8, 2, 2, 0)
  127. || test_gru_with_hidden(16, 8, 7, 0)
  128. || test_gru_with_hidden(17, 8, 8, 0)
  129. || test_gru_with_hidden(19, 15, 8, 0)
  130. || test_gru_with_hidden(5, 16, 16, 0)
  131. || test_gru_with_hidden(3, 16, 8, 0)
  132. || test_gru_with_hidden(2, 5, 17, 0)
  133. || test_gru_with_hidden_input(4, 4, 1, 2)
  134. || test_gru_with_hidden_input(8, 2, 2, 2)
  135. || test_gru_with_hidden_input(16, 8, 7, 2)
  136. || test_gru_with_hidden_input(17, 8, 8, 2)
  137. || test_gru_with_hidden_input(19, 15, 8, 2)
  138. || test_gru_with_hidden_input(5, 16, 16, 2)
  139. || test_gru_with_hidden_input(3, 16, 8, 2)
  140. || test_gru_with_hidden_input(2, 5, 79, 2)
  141. || test_gru_with_hidden_input(4, 4, 1, 1)
  142. || test_gru_with_hidden_input(8, 2, 2, 1)
  143. || test_gru_with_hidden_input(16, 8, 7, 1)
  144. || test_gru_with_hidden_input(17, 8, 8, 1)
  145. || test_gru_with_hidden_input(19, 15, 8, 1)
  146. || test_gru_with_hidden_input(5, 16, 16, 1)
  147. || test_gru_with_hidden_input(3, 16, 8, 1)
  148. || test_gru_with_hidden_input(2, 5, 79, 1)
  149. || test_gru_with_hidden_input(4, 2, 1, 0)
  150. || test_gru_with_hidden_input(8, 2, 2, 0)
  151. || test_gru_with_hidden_input(16, 8, 7, 0)
  152. || test_gru_with_hidden_input(17, 8, 8, 0)
  153. || test_gru_with_hidden_input(19, 15, 8, 0)
  154. || test_gru_with_hidden_input(5, 16, 16, 0)
  155. || test_gru_with_hidden_input(3, 16, 8, 0)
  156. || test_gru_with_hidden_input(2, 5, 17, 0)
  157. || test_gru_with_hidden_output(4, 4, 1, 2)
  158. || test_gru_with_hidden_output(8, 2, 2, 2)
  159. || test_gru_with_hidden_output(16, 8, 7, 2)
  160. || test_gru_with_hidden_output(17, 8, 8, 2)
  161. || test_gru_with_hidden_output(19, 15, 8, 2)
  162. || test_gru_with_hidden_output(5, 16, 16, 2)
  163. || test_gru_with_hidden_output(3, 16, 8, 2)
  164. || test_gru_with_hidden_output(2, 5, 79, 2)
  165. || test_gru_with_hidden_output(4, 4, 1, 1)
  166. || test_gru_with_hidden_output(8, 2, 2, 1)
  167. || test_gru_with_hidden_output(16, 8, 7, 1)
  168. || test_gru_with_hidden_output(17, 8, 8, 1)
  169. || test_gru_with_hidden_output(19, 15, 8, 1)
  170. || test_gru_with_hidden_output(5, 16, 16, 1)
  171. || test_gru_with_hidden_output(3, 16, 8, 1)
  172. || test_gru_with_hidden_output(2, 5, 79, 1)
  173. || test_gru_with_hidden_output(4, 2, 1, 0)
  174. || test_gru_with_hidden_output(8, 2, 2, 0)
  175. || test_gru_with_hidden_output(16, 8, 7, 0)
  176. || test_gru_with_hidden_output(17, 8, 8, 0)
  177. || test_gru_with_hidden_output(19, 15, 8, 0)
  178. || test_gru_with_hidden_output(5, 16, 16, 0)
  179. || test_gru_with_hidden_output(3, 16, 8, 0)
  180. || test_gru_with_hidden_output(2, 5, 17, 0);
  181. }
  182. static int test_gru_2()
  183. {
  184. return 0
  185. || test_gru(4, 1, 1, 0)
  186. || test_gru(8, 2, 2, 0)
  187. || test_gru(16, 8, 7, 0)
  188. || test_gru(17, 8, 8, 0)
  189. || test_gru(19, 15, 8, 0)
  190. || test_gru(5, 16, 16, 0)
  191. || test_gru(3, 16, 8, 0)
  192. || test_gru(8, 16, 16, 0)
  193. || test_gru(2, 5, 17, 0);
  194. }
  195. static int test_gru_3()
  196. {
  197. return 0
  198. || test_gru(4, 1, 1, 1)
  199. || test_gru(8, 2, 2, 1)
  200. || test_gru(16, 8, 7, 1)
  201. || test_gru(17, 8, 8, 1)
  202. || test_gru(19, 15, 8, 1)
  203. || test_gru(5, 16, 16, 1)
  204. || test_gru(3, 16, 8, 1)
  205. || test_gru(8, 16, 16, 1)
  206. || test_gru(2, 5, 17, 1);
  207. }
  208. #if NCNN_INT8
  209. static void RandomizeA(ncnn::Mat& m, float absmax)
  210. {
  211. absmax = ncnn::float16_to_float32(ncnn::float32_to_float16(absmax));
  212. absmax = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(absmax));
  213. const int h = m.h;
  214. float* p = m;
  215. for (int i = 0; i < h; i++)
  216. {
  217. float* p = m.row(i);
  218. for (int j = 0; j < m.w; j++)
  219. {
  220. p[j] = RandomFloat(-absmax, absmax);
  221. // drop 0.45 ~ 0.55
  222. float v = p[j] * (127.f / absmax);
  223. float vv = fabs(v - (int)v);
  224. float hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  225. float hv = hp * (127.f / absmax);
  226. float hvv = fabs(hv - (int)hv);
  227. float bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  228. float bv = bp * (127.f / absmax);
  229. float bvv = fabs(bv - (int)bv);
  230. while ((vv > 0.45f && vv < 0.55f) || (hvv > 0.45f && hvv < 0.55f) || (bvv > 0.45f && bvv < 0.55f))
  231. {
  232. p[j] = RandomFloat(-absmax, absmax);
  233. v = p[j] * (127.f / absmax);
  234. vv = fabs(v - (int)v);
  235. hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j]));
  236. hv = hp * (127.f / absmax);
  237. hvv = fabs(hv - (int)hv);
  238. bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j]));
  239. bv = bp * (127.f / absmax);
  240. bvv = fabs(bv - (int)bv);
  241. }
  242. }
  243. }
  244. // set random a and b
  245. m.row(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = -absmax;
  246. m.row(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = absmax;
  247. }
  248. static int test_gru_int8(int size, int T, int outch, int direction)
  249. {
  250. int num_directions = direction == 2 ? 2 : 1;
  251. ncnn::ParamDict pd;
  252. pd.set(0, outch);
  253. pd.set(1, outch * size * 3 * num_directions);
  254. pd.set(2, direction);
  255. pd.set(8, 2); // int8_scale_term
  256. std::vector<ncnn::Mat> weights(5);
  257. weights[0] = RandomS8Mat(outch * size * 3 * num_directions);
  258. weights[1] = RandomMat(outch * 4 * num_directions);
  259. weights[2] = RandomS8Mat(outch * outch * 3 * num_directions);
  260. weights[3] = RandomMat(outch * 3 * num_directions, 100.f, 200.f);
  261. weights[4] = RandomMat(outch * 3 * num_directions, 100.f, 200.f);
  262. ncnn::Mat a(size, T);
  263. RandomizeA(a, 10.f);
  264. int ret = test_layer("GRU", pd, weights, a);
  265. if (ret != 0)
  266. {
  267. fprintf(stderr, "test_gru_int8 failed size=%d T=%d outch=%d direction=%d\n", size, T, outch, direction);
  268. }
  269. return ret;
  270. }
  271. static int test_gru_int8_with_hidden(int size, int T, int outch, int direction)
  272. {
  273. int num_directions = direction == 2 ? 2 : 1;
  274. ncnn::ParamDict pd;
  275. pd.set(0, outch);
  276. pd.set(1, outch * size * 3 * num_directions);
  277. pd.set(2, direction);
  278. pd.set(8, 2); // int8_scale_term
  279. std::vector<ncnn::Mat> weights(5);
  280. weights[0] = RandomS8Mat(outch * size * 3 * num_directions);
  281. weights[1] = RandomMat(outch * 4 * num_directions);
  282. weights[2] = RandomS8Mat(outch * outch * 3 * num_directions);
  283. weights[3] = RandomMat(outch * 3 * num_directions, 100.f, 200.f);
  284. weights[4] = RandomMat(outch * 3 * num_directions, 100.f, 200.f);
  285. ncnn::Mat a(size, T);
  286. RandomizeA(a, 10.f);
  287. // initial hidden state
  288. ncnn::Mat hidden(outch, num_directions);
  289. RandomizeA(hidden, 10.f);
  290. std::vector<ncnn::Mat> as(2);
  291. as[0] = a;
  292. as[1] = hidden;
  293. int ret = test_layer("GRU", pd, weights, as, 2);
  294. if (ret != 0)
  295. {
  296. fprintf(stderr, "test_gru_int8_with_hidden failed size=%d T=%d outch=%d direction=%d\n", size, T, outch, direction);
  297. }
  298. return ret;
  299. }
  300. static int test_gru_int8_with_hidden_input(int size, int T, int outch, int direction)
  301. {
  302. int num_directions = direction == 2 ? 2 : 1;
  303. ncnn::ParamDict pd;
  304. pd.set(0, outch);
  305. pd.set(1, outch * size * 3 * num_directions);
  306. pd.set(2, direction);
  307. pd.set(8, 2); // int8_scale_term
  308. std::vector<ncnn::Mat> weights(5);
  309. weights[0] = RandomS8Mat(outch * size * 3 * num_directions);
  310. weights[1] = RandomMat(outch * 4 * num_directions);
  311. weights[2] = RandomS8Mat(outch * outch * 3 * num_directions);
  312. weights[3] = RandomMat(outch * 3 * num_directions, 100.f, 200.f);
  313. weights[4] = RandomMat(outch * 3 * num_directions, 100.f, 200.f);
  314. ncnn::Mat a(size, T);
  315. RandomizeA(a, 10.f);
  316. // initial hidden state
  317. ncnn::Mat hidden(outch, num_directions);
  318. RandomizeA(hidden, 10.f);
  319. std::vector<ncnn::Mat> as(2);
  320. as[0] = a;
  321. as[1] = hidden;
  322. int ret = test_layer("GRU", pd, weights, as, 1);
  323. if (ret != 0)
  324. {
  325. fprintf(stderr, "test_gru_int8_with_hidden_input failed size=%d T=%d outch=%d direction=%d\n", size, T, outch, direction);
  326. }
  327. return ret;
  328. }
  329. static int test_gru_int8_with_hidden_output(int size, int T, int outch, int direction)
  330. {
  331. int num_directions = direction == 2 ? 2 : 1;
  332. ncnn::ParamDict pd;
  333. pd.set(0, outch);
  334. pd.set(1, outch * size * 3 * num_directions);
  335. pd.set(2, direction);
  336. pd.set(8, 2); // int8_scale_term
  337. std::vector<ncnn::Mat> weights(5);
  338. weights[0] = RandomS8Mat(outch * size * 3 * num_directions);
  339. weights[1] = RandomMat(outch * 4 * num_directions);
  340. weights[2] = RandomS8Mat(outch * outch * 3 * num_directions);
  341. weights[3] = RandomMat(outch * 3 * num_directions, 100.f, 200.f);
  342. weights[4] = RandomMat(outch * 3 * num_directions, 100.f, 200.f);
  343. ncnn::Mat a(size, T);
  344. RandomizeA(a, 10.f);
  345. std::vector<ncnn::Mat> as(1);
  346. as[0] = a;
  347. int ret = test_layer("GRU", pd, weights, as, 2);
  348. if (ret != 0)
  349. {
  350. fprintf(stderr, "test_gru_int8_with_hidden_output failed size=%d T=%d outch=%d direction=%d\n", size, T, outch, direction);
  351. }
  352. return ret;
  353. }
  354. static int test_gru_4()
  355. {
  356. return 0
  357. || test_gru_int8(4, 1, 2, 2)
  358. || test_gru_int8(8, 2, 2, 2)
  359. || test_gru_int8(16, 8, 7, 2)
  360. || test_gru_int8(17, 8, 8, 2)
  361. || test_gru_int8(19, 15, 8, 2)
  362. || test_gru_int8(5, 16, 16, 2)
  363. || test_gru_int8(3, 16, 8, 2)
  364. || test_gru_int8(8, 16, 16, 2)
  365. || test_gru_int8(31, 3, 31, 2)
  366. || test_gru_int8(2, 5, 17, 2);
  367. }
  368. static int test_gru_5()
  369. {
  370. return 0
  371. || test_gru_int8_with_hidden(4, 4, 1, 2)
  372. || test_gru_int8_with_hidden(8, 2, 2, 2)
  373. || test_gru_int8_with_hidden(16, 8, 7, 2)
  374. || test_gru_int8_with_hidden(17, 8, 8, 2)
  375. || test_gru_int8_with_hidden(19, 15, 8, 2)
  376. || test_gru_int8_with_hidden(5, 16, 16, 2)
  377. || test_gru_int8_with_hidden(3, 16, 8, 2)
  378. || test_gru_int8_with_hidden(2, 5, 79, 2)
  379. || test_gru_int8_with_hidden(4, 4, 1, 1)
  380. || test_gru_int8_with_hidden(8, 2, 2, 1)
  381. || test_gru_int8_with_hidden(16, 8, 7, 1)
  382. || test_gru_int8_with_hidden(17, 8, 8, 1)
  383. || test_gru_int8_with_hidden(19, 15, 8, 1)
  384. || test_gru_int8_with_hidden(5, 16, 16, 1)
  385. || test_gru_int8_with_hidden(3, 16, 8, 1)
  386. || test_gru_int8_with_hidden(2, 5, 79, 1)
  387. || test_gru_int8_with_hidden(4, 2, 1, 0)
  388. || test_gru_int8_with_hidden(8, 2, 2, 0)
  389. || test_gru_int8_with_hidden(16, 8, 7, 0)
  390. || test_gru_int8_with_hidden(17, 8, 8, 0)
  391. || test_gru_int8_with_hidden(19, 15, 8, 0)
  392. || test_gru_int8_with_hidden(5, 16, 16, 0)
  393. || test_gru_int8_with_hidden(3, 16, 8, 0)
  394. || test_gru_int8_with_hidden(2, 5, 17, 0)
  395. || test_gru_int8_with_hidden_input(4, 4, 1, 2)
  396. || test_gru_int8_with_hidden_input(8, 2, 2, 2)
  397. || test_gru_int8_with_hidden_input(16, 8, 7, 2)
  398. || test_gru_int8_with_hidden_input(17, 8, 8, 2)
  399. || test_gru_int8_with_hidden_input(19, 15, 8, 2)
  400. || test_gru_int8_with_hidden_input(5, 16, 16, 2)
  401. || test_gru_int8_with_hidden_input(3, 16, 8, 2)
  402. || test_gru_int8_with_hidden_input(2, 5, 79, 2)
  403. || test_gru_int8_with_hidden_input(4, 4, 1, 1)
  404. || test_gru_int8_with_hidden_input(8, 2, 2, 1)
  405. || test_gru_int8_with_hidden_input(16, 8, 7, 1)
  406. || test_gru_int8_with_hidden_input(17, 8, 8, 1)
  407. || test_gru_int8_with_hidden_input(19, 15, 8, 1)
  408. || test_gru_int8_with_hidden_input(5, 16, 16, 1)
  409. || test_gru_int8_with_hidden_input(3, 16, 8, 1)
  410. || test_gru_int8_with_hidden_input(2, 5, 79, 1)
  411. || test_gru_int8_with_hidden_input(4, 2, 1, 0)
  412. || test_gru_int8_with_hidden_input(8, 2, 2, 0)
  413. || test_gru_int8_with_hidden_input(16, 8, 7, 0)
  414. || test_gru_int8_with_hidden_input(17, 8, 8, 0)
  415. || test_gru_int8_with_hidden_input(19, 15, 8, 0)
  416. || test_gru_int8_with_hidden_input(5, 16, 16, 0)
  417. || test_gru_int8_with_hidden_input(3, 16, 8, 0)
  418. || test_gru_int8_with_hidden_input(2, 5, 17, 0)
  419. || test_gru_int8_with_hidden_output(4, 4, 1, 2)
  420. || test_gru_int8_with_hidden_output(8, 2, 2, 2)
  421. || test_gru_int8_with_hidden_output(16, 8, 7, 2)
  422. || test_gru_int8_with_hidden_output(17, 8, 8, 2)
  423. || test_gru_int8_with_hidden_output(19, 15, 8, 2)
  424. || test_gru_int8_with_hidden_output(5, 16, 16, 2)
  425. || test_gru_int8_with_hidden_output(3, 16, 8, 2)
  426. || test_gru_int8_with_hidden_output(2, 5, 79, 2)
  427. || test_gru_int8_with_hidden_output(4, 4, 1, 1)
  428. || test_gru_int8_with_hidden_output(8, 2, 2, 1)
  429. || test_gru_int8_with_hidden_output(16, 8, 7, 1)
  430. || test_gru_int8_with_hidden_output(17, 8, 8, 1)
  431. || test_gru_int8_with_hidden_output(19, 15, 8, 1)
  432. || test_gru_int8_with_hidden_output(5, 16, 16, 1)
  433. || test_gru_int8_with_hidden_output(3, 16, 8, 1)
  434. || test_gru_int8_with_hidden_output(2, 5, 79, 1)
  435. || test_gru_int8_with_hidden_output(4, 2, 1, 0)
  436. || test_gru_int8_with_hidden_output(8, 2, 2, 0)
  437. || test_gru_int8_with_hidden_output(16, 8, 7, 0)
  438. || test_gru_int8_with_hidden_output(17, 8, 8, 0)
  439. || test_gru_int8_with_hidden_output(19, 15, 8, 0)
  440. || test_gru_int8_with_hidden_output(5, 16, 16, 0)
  441. || test_gru_int8_with_hidden_output(3, 16, 8, 0)
  442. || test_gru_int8_with_hidden_output(2, 5, 17, 0);
  443. }
  444. static int test_gru_6()
  445. {
  446. return 0
  447. || test_gru_int8(4, 1, 1, 0)
  448. || test_gru_int8(8, 2, 2, 0)
  449. || test_gru_int8(16, 8, 7, 0)
  450. || test_gru_int8(17, 8, 8, 0)
  451. || test_gru_int8(19, 15, 8, 0)
  452. || test_gru_int8(5, 16, 16, 0)
  453. || test_gru_int8(3, 16, 8, 0)
  454. || test_gru_int8(8, 16, 16, 0)
  455. || test_gru_int8(2, 5, 17, 0);
  456. }
  457. static int test_gru_7()
  458. {
  459. return 0
  460. || test_gru_int8(4, 1, 1, 1)
  461. || test_gru_int8(8, 2, 2, 1)
  462. || test_gru_int8(16, 8, 7, 1)
  463. || test_gru_int8(17, 8, 8, 1)
  464. || test_gru_int8(19, 15, 8, 1)
  465. || test_gru_int8(5, 16, 16, 1)
  466. || test_gru_int8(3, 16, 8, 1)
  467. || test_gru_int8(8, 16, 16, 1)
  468. || test_gru_int8(2, 5, 17, 1);
  469. }
  470. #endif
  471. int main()
  472. {
  473. SRAND(7767517);
  474. #if NCNN_INT8
  475. return 0
  476. || test_gru_0()
  477. || test_gru_1()
  478. || test_gru_2()
  479. || test_gru_3()
  480. || test_gru_4()
  481. || test_gru_5()
  482. || test_gru_6()
  483. || test_gru_7();
  484. #else
  485. return 0
  486. || test_gru_0()
  487. || test_gru_1()
  488. || test_gru_2()
  489. || test_gru_3();
  490. #endif
  491. }