You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_flip.cpp 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. // Copyright 2025 Tencent
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. #include "testutil.h"
  // Builds a one-element axes list {a0}.
  static std::vector<int> IntArray(int a0)
  {
      // Fill constructor: one copy of a0.
      return std::vector<int>(1, a0);
  }
  // Builds a two-element axes list {a0, a1}, preserving argument order.
  static std::vector<int> IntArray(int a0, int a1)
  {
      std::vector<int> values;
      values.reserve(2);
      values.push_back(a0);
      values.push_back(a1);
      return values;
  }
  // Builds a three-element axes list {a0, a1, a2}, preserving argument order.
  static std::vector<int> IntArray(int a0, int a1, int a2)
  {
      const int values[3] = {a0, a1, a2};
      // Range constructor copies the three values in order.
      return std::vector<int>(values, values + 3);
  }
  // Builds a four-element axes list {a0, a1, a2, a3}, preserving argument order.
  static std::vector<int> IntArray(int a0, int a1, int a2, int a3)
  {
      const int values[4] = {a0, a1, a2, a3};
      // Range constructor copies the four values in order.
      return std::vector<int>(values, values + 4);
  }
  // Writes the values of `a` to stderr as "[ v0 v1 ... ]" with no trailing
  // newline; an empty vector prints "[ ]".
  static void print_int_array(const std::vector<int>& a)
  {
      fprintf(stderr, "[");
      for (std::vector<int>::const_iterator it = a.begin(); it != a.end(); ++it)
          fprintf(stderr, " %d", *it);
      fprintf(stderr, " ]");
  }
  43. static int test_flip(const ncnn::Mat& a, const std::vector<int>& axes_array)
  44. {
  45. ncnn::Mat axes(axes_array.size());
  46. {
  47. int* p = axes;
  48. for (size_t i = 0; i < axes_array.size(); i++)
  49. {
  50. p[i] = axes_array[i];
  51. }
  52. }
  53. ncnn::ParamDict pd;
  54. pd.set(0, axes);
  55. std::vector<ncnn::Mat> weights(0);
  56. int ret = test_layer("Flip", pd, weights, a);
  57. if (ret != 0)
  58. {
  59. fprintf(stderr, "test_flip failed a.dims=%d a=(%d %d %d %d)", a.dims, a.w, a.h, a.d, a.c);
  60. fprintf(stderr, " axes=");
  61. print_int_array(axes_array);
  62. fprintf(stderr, "\n");
  63. }
  64. return ret;
  65. }
  66. static int test_flip_nd(const ncnn::Mat& a)
  67. {
  68. int ret1 = test_flip(a, IntArray(0));
  69. if (a.dims == 1 || ret1 != 0)
  70. return ret1;
  71. int ret2 = 0
  72. || test_flip(a, IntArray(0))
  73. || test_flip(a, IntArray(1))
  74. || test_flip(a, IntArray(0, 1));
  75. if (a.dims == 2 || ret2 != 0)
  76. return ret2;
  77. int ret3 = 0
  78. || test_flip(a, IntArray(0))
  79. || test_flip(a, IntArray(1))
  80. || test_flip(a, IntArray(2))
  81. || test_flip(a, IntArray(0, 1))
  82. || test_flip(a, IntArray(0, 2))
  83. || test_flip(a, IntArray(1, 2))
  84. || test_flip(a, IntArray(0, 1, 2));
  85. if (a.dims == 3 || ret3 != 0)
  86. return ret3;
  87. int ret4 = 0
  88. || test_flip(a, IntArray(0))
  89. || test_flip(a, IntArray(1))
  90. || test_flip(a, IntArray(2))
  91. || test_flip(a, IntArray(3))
  92. || test_flip(a, IntArray(0, 1))
  93. || test_flip(a, IntArray(0, 2))
  94. || test_flip(a, IntArray(0, 3))
  95. || test_flip(a, IntArray(1, 2))
  96. || test_flip(a, IntArray(1, 3))
  97. || test_flip(a, IntArray(2, 3))
  98. || test_flip(a, IntArray(0, 1, 2))
  99. || test_flip(a, IntArray(0, 1, 3))
  100. || test_flip(a, IntArray(0, 2, 3))
  101. || test_flip(a, IntArray(1, 2, 3))
  102. || test_flip(a, IntArray(0, 1, 2, 3));
  103. return ret4;
  104. }
  105. static int test_flip_0()
  106. {
  107. ncnn::Mat a = RandomMat(5, 6, 7, 24);
  108. ncnn::Mat b = RandomMat(7, 8, 9, 12);
  109. ncnn::Mat c = RandomMat(3, 4, 5, 13);
  110. return 0
  111. || test_flip_nd(a)
  112. || test_flip_nd(b)
  113. || test_flip_nd(c);
  114. }
  115. static int test_flip_1()
  116. {
  117. ncnn::Mat a = RandomMat(5, 7, 24);
  118. ncnn::Mat b = RandomMat(7, 9, 12);
  119. ncnn::Mat c = RandomMat(3, 5, 13);
  120. return 0
  121. || test_flip_nd(a)
  122. || test_flip_nd(b)
  123. || test_flip_nd(c);
  124. }
  125. static int test_flip_2()
  126. {
  127. ncnn::Mat a = RandomMat(15, 24);
  128. ncnn::Mat b = RandomMat(17, 12);
  129. ncnn::Mat c = RandomMat(19, 15);
  130. return 0
  131. || test_flip_nd(a)
  132. || test_flip_nd(b)
  133. || test_flip_nd(c);
  134. }
  135. static int test_flip_3()
  136. {
  137. ncnn::Mat a = RandomMat(128);
  138. ncnn::Mat b = RandomMat(124);
  139. ncnn::Mat c = RandomMat(127);
  140. return 0
  141. || test_flip_nd(a)
  142. || test_flip_nd(b)
  143. || test_flip_nd(c);
  144. }
  145. int main()
  146. {
  147. SRAND(7767517);
  148. return 0
  149. || test_flip_0()
  150. || test_flip_1()
  151. || test_flip_2()
  152. || test_flip_3();
  153. }