You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_reduction.cpp 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. // Copyright 2021 Tencent
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. #include "testutil.h"
  // Number of Reduction operation codes swept by main(): 0 .. OP_TYPE_MAX - 1.
  // A typed constant instead of a macro so it is scoped and type-checked.
  static const int OP_TYPE_MAX = 11;

  // Reduction operation code under test. main() iterates it and both
  // test_reduction overloads read it (as ParamDict id 0). Codes 9 (logsum)
  // and 10 (logsumexp) need strictly positive input.
  static int op_type = 0;
  // Returns a std::vector<int> holding exactly {a0}; used to spell axis lists.
  static std::vector<int> IntArray(int a0)
  {
      const int values[] = {a0};
      return std::vector<int>(values, values + 1);
  }
  // Returns a std::vector<int> holding exactly {a0, a1}, in that order.
  static std::vector<int> IntArray(int a0, int a1)
  {
      const int values[] = {a0, a1};
      return std::vector<int>(values, values + 2);
  }
  // Returns a std::vector<int> holding exactly {a0, a1, a2}, in that order.
  static std::vector<int> IntArray(int a0, int a1, int a2)
  {
      const int values[] = {a0, a1, a2};
      return std::vector<int>(values, values + 3);
  }
  // Returns a std::vector<int> holding exactly {a0, a1, a2, a3}, in that order.
  static std::vector<int> IntArray(int a0, int a1, int a2, int a3)
  {
      const int values[] = {a0, a1, a2, a3};
      return std::vector<int>(values, values + 4);
  }
  // Writes `a` to stderr as "[ v0 v1 ... ]" (each value preceded by a space),
  // with no trailing newline so callers can continue the same line.
  static void print_int_array(const std::vector<int>& a)
  {
      fputs("[", stderr);
      for (std::vector<int>::const_iterator it = a.begin(); it != a.end(); ++it)
          fprintf(stderr, " %d", *it);
      fputs(" ]", stderr);
  }
  45. static int test_reduction(const ncnn::Mat& _a, float coeff, int keepdims)
  46. {
  47. ncnn::Mat a = _a;
  48. if (op_type == 9 || op_type == 10)
  49. {
  50. // value must be positive for logsum and logsumexp
  51. Randomize(a, 0.001f, 2.f);
  52. }
  53. ncnn::ParamDict pd;
  54. pd.set(0, op_type);
  55. pd.set(1, 1); // reduce_all
  56. pd.set(2, coeff);
  57. pd.set(4, keepdims);
  58. std::vector<ncnn::Mat> weights(0);
  59. int ret = test_layer("Reduction", pd, weights, a);
  60. if (ret != 0)
  61. {
  62. fprintf(stderr, "test_reduction failed a.dims=%d a=(%d %d %d %d) op_type=%d coeff=%f keepdims=%d reduce_all=1\n", a.dims, a.w, a.h, a.d, a.c, op_type, coeff, keepdims);
  63. }
  64. return ret;
  65. }
  66. static int test_reduction(const ncnn::Mat& _a, float coeff, int keepdims, const std::vector<int>& axes_array)
  67. {
  68. ncnn::Mat a = _a;
  69. if (op_type == 9 || op_type == 10)
  70. {
  71. // value must be positive for logsum and logsumexp
  72. Randomize(a, 0.001f, 2.f);
  73. }
  74. ncnn::Mat axes(axes_array.size());
  75. {
  76. int* p = axes;
  77. for (size_t i = 0; i < axes_array.size(); i++)
  78. {
  79. p[i] = axes_array[i];
  80. }
  81. }
  82. ncnn::ParamDict pd;
  83. pd.set(0, op_type);
  84. pd.set(1, 0); // reduce_all
  85. pd.set(2, coeff);
  86. pd.set(3, axes);
  87. pd.set(4, keepdims);
  88. pd.set(5, 1); // fixbug0
  89. std::vector<ncnn::Mat> weights(0);
  90. int ret = test_layer("Reduction", pd, weights, a);
  91. if (ret != 0)
  92. {
  93. fprintf(stderr, "test_reduction failed a.dims=%d a=(%d %d %d %d) op_type=%d coeff=%f keepdims=%d", a.dims, a.w, a.h, a.d, a.c, op_type, coeff, keepdims);
  94. fprintf(stderr, " axes=");
  95. print_int_array(axes_array);
  96. fprintf(stderr, "\n");
  97. }
  98. return ret;
  99. }
  100. static int test_reduction_nd(const ncnn::Mat& a)
  101. {
  102. int ret1 = 0
  103. || test_reduction(a, 1.f, 0)
  104. || test_reduction(a, 2.f, 0)
  105. || test_reduction(a, 1.f, 1)
  106. || test_reduction(a, 2.f, 1)
  107. || test_reduction(a, 1.f, 0, IntArray(0))
  108. || test_reduction(a, 1.f, 1, IntArray(0));
  109. if (a.dims == 1 || ret1 != 0)
  110. return ret1;
  111. int ret2 = 0
  112. || test_reduction(a, 2.f, 0, IntArray(1))
  113. || test_reduction(a, 2.f, 1, IntArray(1))
  114. || test_reduction(a, 1.f, 0, IntArray(0, 1))
  115. || test_reduction(a, 1.f, 1, IntArray(0, 1));
  116. if (a.dims == 2 || ret2 != 0)
  117. return ret2;
  118. int ret3 = 0
  119. || test_reduction(a, 1.f, 0, IntArray(2))
  120. || test_reduction(a, 1.f, 1, IntArray(2))
  121. || test_reduction(a, 2.f, 0, IntArray(0, 2))
  122. || test_reduction(a, 2.f, 0, IntArray(1, 2))
  123. || test_reduction(a, 2.f, 1, IntArray(0, 2))
  124. || test_reduction(a, 2.f, 1, IntArray(1, 2))
  125. || test_reduction(a, 1.f, 0, IntArray(0, 1, 2))
  126. || test_reduction(a, 1.f, 1, IntArray(0, 1, 2));
  127. if (a.dims == 3 || ret3 != 0)
  128. return ret3;
  129. int ret4 = 0
  130. || test_reduction(a, 2.f, 0, IntArray(3))
  131. || test_reduction(a, 2.f, 1, IntArray(3))
  132. || test_reduction(a, 1.f, 0, IntArray(0, 3))
  133. || test_reduction(a, 1.f, 0, IntArray(1, 3))
  134. || test_reduction(a, 2.f, 0, IntArray(2, 3))
  135. || test_reduction(a, 1.f, 1, IntArray(0, 3))
  136. || test_reduction(a, 1.f, 1, IntArray(1, 3))
  137. || test_reduction(a, 2.f, 1, IntArray(2, 3))
  138. || test_reduction(a, 2.f, 0, IntArray(0, 1, 3))
  139. || test_reduction(a, 1.f, 0, IntArray(0, 2, 3))
  140. || test_reduction(a, 2.f, 0, IntArray(1, 2, 3))
  141. || test_reduction(a, 2.f, 1, IntArray(0, 1, 3))
  142. || test_reduction(a, 1.f, 1, IntArray(0, 2, 3))
  143. || test_reduction(a, 2.f, 1, IntArray(1, 2, 3))
  144. || test_reduction(a, 1.f, 0, IntArray(0, 1, 2, 3))
  145. || test_reduction(a, 1.f, 1, IntArray(0, 1, 2, 3));
  146. return ret4;
  147. }
  148. static int test_reduction_0()
  149. {
  150. ncnn::Mat a = RandomMat(5, 6, 7, 24);
  151. ncnn::Mat b = RandomMat(7, 8, 9, 12);
  152. ncnn::Mat c = RandomMat(3, 4, 5, 13);
  153. return 0
  154. || test_reduction_nd(a)
  155. || test_reduction_nd(b)
  156. || test_reduction_nd(c);
  157. }
  158. static int test_reduction_1()
  159. {
  160. ncnn::Mat a = RandomMat(5, 7, 24);
  161. ncnn::Mat b = RandomMat(7, 9, 12);
  162. ncnn::Mat c = RandomMat(3, 5, 13);
  163. return 0
  164. || test_reduction_nd(a)
  165. || test_reduction_nd(b)
  166. || test_reduction_nd(c);
  167. }
  168. static int test_reduction_2()
  169. {
  170. ncnn::Mat a = RandomMat(15, 24);
  171. ncnn::Mat b = RandomMat(17, 12);
  172. ncnn::Mat c = RandomMat(19, 15);
  173. return 0
  174. || test_reduction_nd(a)
  175. || test_reduction_nd(b)
  176. || test_reduction_nd(c);
  177. }
  178. static int test_reduction_3()
  179. {
  180. ncnn::Mat a = RandomMat(128);
  181. ncnn::Mat b = RandomMat(124);
  182. ncnn::Mat c = RandomMat(127);
  183. return 0
  184. || test_reduction_nd(a)
  185. || test_reduction_nd(b)
  186. || test_reduction_nd(c);
  187. }
  188. int main()
  189. {
  190. SRAND(7767517);
  191. for (op_type = 0; op_type < OP_TYPE_MAX; op_type++)
  192. {
  193. int ret = 0
  194. || test_reduction_0()
  195. || test_reduction_1()
  196. || test_reduction_2()
  197. || test_reduction_3();
  198. if (ret != 0)
  199. return ret;
  200. }
  201. return 0;
  202. }