You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_expression.cpp 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. // Copyright 2025 Tencent
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. #include <stdio.h>
  4. #include <stdarg.h>
  5. #include "expression.h"
  6. static int test_count_expression_blobs(const std::string& expr, int true_count)
  7. {
  8. int count = ncnn::count_expression_blobs(expr);
  9. if (count != true_count)
  10. {
  11. fprintf(stderr, "test_count_expression_blobs failed expr=%s got %d\n", expr.c_str(), count);
  12. return -1;
  13. }
  14. return 0;
  15. }
  16. static int test_expression_0()
  17. {
  18. return 0
  19. || test_count_expression_blobs("1,2,3,4,5,6", 0)
  20. || test_count_expression_blobs("-1,1h,2w", 3)
  21. || test_count_expression_blobs("2,9d,2c,-1", 10);
  22. }
  23. static int test_eval_list_expression(const std::string& expr, std::vector<ncnn::Mat>& blobs, int ndim, ...)
  24. {
  25. // construct true list
  26. std::vector<int> true_list(ndim);
  27. va_list ap;
  28. va_start(ap, ndim);
  29. for (int i = 0; i < ndim; i++)
  30. {
  31. true_list[i] = va_arg(ap, int);
  32. }
  33. va_end(ap);
  34. std::vector<int> list;
  35. int er = ncnn::eval_list_expression(expr, blobs, list);
  36. if (er != 0)
  37. return -1;
  38. bool failed = false;
  39. if (list.size() != true_list.size())
  40. {
  41. failed = true;
  42. }
  43. else
  44. {
  45. for (size_t i = 0; i < list.size(); i++)
  46. {
  47. if (list[i] != true_list[i])
  48. {
  49. failed = true;
  50. break;
  51. }
  52. }
  53. }
  54. if (failed)
  55. {
  56. fprintf(stderr, "test_eval_list_expression failed expr=%s got [", expr.c_str());
  57. for (size_t i = 0; i < list.size(); i++)
  58. {
  59. fprintf(stderr, "%d", list[i]);
  60. if (i + 1 != list.size())
  61. fprintf(stderr, ",");
  62. }
  63. fprintf(stderr, "]\n");
  64. return -1;
  65. }
  66. return 0;
  67. }
  68. static int test_expression_1()
  69. {
  70. std::vector<ncnn::Mat> blobs(2);
  71. blobs[0] = ncnn::Mat(100, 200, 44);
  72. blobs[1] = ncnn::Mat(10, 20, 2, 4);
  73. return 0
  74. || test_eval_list_expression("+(trunc(*(0w,0.5)),-(0c,10)),floor(/(1h,0.5)),+(0c,1c),round(2.0)", blobs, 4, 84, 40, 48, 2)
  75. || test_eval_list_expression("//(0w,3),+(0w,1w),-(0h,1h),*(0c,1c)", blobs, 4, 33, 110, 180, 176)
  76. || test_eval_list_expression("floor(//(0w,2.99)),round(+(0w,1.01)),trunc(-(0h,1.9)),ceil(*(1d,2.99))", blobs, 4, 33, 101, 198, 6)
  77. || test_eval_list_expression("round(*(abs(asin(sin(0w))),10.11)),ceil(*(abs(acos(cos(+(0w,3)))),10.11))", blobs, 2, 5, 25)
  78. || test_eval_list_expression("floor(*(abs(asinh(sinh(/(0w,100)))),10.11)),trunc(*(abs(acosh(cosh(*(0w,0.004)))),10.11))", blobs, 2, 10, 4)
  79. || test_eval_list_expression("round(*(abs(atan(tan(0w))),10.11)),ceil(*(abs(atanh(tanh(-(0w,99)))),10.11))", blobs, 2, 5, 11)
  80. || test_eval_list_expression("floor(min(max(*(square(sqrt(0w)),1.2121),100),120))", blobs, 1, 120)
  81. || test_eval_list_expression("min(max(trunc(*(log(exp(*(neg(0w),0.001))),-144)),15),20)", blobs, 1, 15)
  82. || test_eval_list_expression("round(*(erf(reciprocal(log10(1h))),999))", blobs, 1, 722)
  83. || test_eval_list_expression("ceil(pow(fmod(atan2(0w,1d),1c),14.14)),floor(logaddexp(remainder(0c,10),6))", blobs, 2, 495, 6)
  84. || test_eval_list_expression("floor(*(square(sqrt(0w)),1.2121))", blobs, 1, 121)
  85. || test_eval_list_expression("rshift(lshift(xor(or(and(1d,18),9),4),4),2)", blobs, 1, 60)
  86. || test_eval_list_expression("ceil(*(rsqrt(+(+(sign(1w),10),*(sign(-(neg(1d)),0.5),3))),100))", blobs, 1, 36);
  87. }
  88. static int test_expression_2()
  89. {
  90. std::vector<ncnn::Mat> blobs(2);
  91. blobs[0] = ncnn::Mat(10, 20, 4);
  92. blobs[1] = ncnn::Mat(1, 2, 3, 4);
  93. // expect error blob index out of bound
  94. if (test_eval_list_expression("0w,1h,2c,1d", blobs, 0) != -1)
  95. return -1;
  96. // expect error divide by zero
  97. if (test_eval_list_expression("//(0w,-(0c,1c))", blobs, 0) != -1)
  98. return -1;
  99. // expect error malformed token
  100. if (test_eval_list_expression("1c,#(0w,1)", blobs, 0) != -1)
  101. return -1;
  102. if (test_eval_list_expression("1c,+(qwq,1w)", blobs, 0) != -1)
  103. return -1;
  104. return 0;
  105. }
  106. int main()
  107. {
  108. return 0
  109. || test_expression_0()
  110. || test_expression_1()
  111. || test_expression_2();
  112. }