You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_expression.cpp 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include <stdio.h>
  15. #include <stdarg.h>
  16. #include "expression.h"
  17. static int test_count_expression_blobs(const std::string& expr, int true_count)
  18. {
  19. int count = ncnn::count_expression_blobs(expr);
  20. if (count != true_count)
  21. {
  22. fprintf(stderr, "test_count_expression_blobs failed expr=%s got %d\n", expr.c_str(), count);
  23. return -1;
  24. }
  25. return 0;
  26. }
  27. static int test_expression_0()
  28. {
  29. return 0
  30. || test_count_expression_blobs("1,2,3,4,5,6", 0)
  31. || test_count_expression_blobs("-1,1h,2w", 3)
  32. || test_count_expression_blobs("2,9d,2c,-1", 10);
  33. }
  34. static int test_eval_list_expression(const std::string& expr, std::vector<ncnn::Mat>& blobs, int ndim, ...)
  35. {
  36. // construct true list
  37. std::vector<int> true_list(ndim);
  38. va_list ap;
  39. va_start(ap, ndim);
  40. for (int i = 0; i < ndim; i++)
  41. {
  42. true_list[i] = va_arg(ap, int);
  43. }
  44. va_end(ap);
  45. std::vector<int> list;
  46. int er = ncnn::eval_list_expression(expr, blobs, list);
  47. if (er != 0)
  48. return -1;
  49. bool failed = false;
  50. if (list.size() != true_list.size())
  51. {
  52. failed = true;
  53. }
  54. else
  55. {
  56. for (size_t i = 0; i < list.size(); i++)
  57. {
  58. if (list[i] != true_list[i])
  59. {
  60. failed = true;
  61. break;
  62. }
  63. }
  64. }
  65. if (failed)
  66. {
  67. fprintf(stderr, "test_eval_list_expression failed expr=%s got [", expr.c_str());
  68. for (size_t i = 0; i < list.size(); i++)
  69. {
  70. fprintf(stderr, "%d", list[i]);
  71. if (i + 1 != list.size())
  72. fprintf(stderr, ",");
  73. }
  74. fprintf(stderr, "]\n");
  75. return -1;
  76. }
  77. return 0;
  78. }
  79. static int test_expression_1()
  80. {
  81. std::vector<ncnn::Mat> blobs(2);
  82. blobs[0] = ncnn::Mat(100, 200, 44);
  83. blobs[1] = ncnn::Mat(10, 20, 2, 4);
  84. return 0
  85. || test_eval_list_expression("+(trunc(*(0w,0.5)),-(0c,10)),floor(/(1h,0.5)),+(0c,1c),round(2.0)", blobs, 4, 84, 40, 48, 2)
  86. || test_eval_list_expression("//(0w,3),+(0w,1w),-(0h,1h),*(0c,1c)", blobs, 4, 33, 110, 180, 176)
  87. || test_eval_list_expression("floor(//(0w,2.99)),round(+(0w,1.01)),trunc(-(0h,1.9)),ceil(*(1d,2.99))", blobs, 4, 33, 101, 198, 6)
  88. || test_eval_list_expression("round(*(abs(asin(sin(0w))),10.11)),ceil(*(abs(acos(cos(+(0w,3)))),10.11))", blobs, 2, 5, 25)
  89. || test_eval_list_expression("floor(*(abs(asinh(sinh(/(0w,100)))),10.11)),trunc(*(abs(acosh(cosh(*(0w,0.004)))),10.11))", blobs, 2, 10, 4)
  90. || test_eval_list_expression("round(*(abs(atan(tan(0w))),10.11)),ceil(*(abs(atanh(tanh(-(0w,99)))),10.11))", blobs, 2, 5, 11)
  91. || test_eval_list_expression("floor(min(max(*(square(sqrt(0w)),1.2121),100),120))", blobs, 1, 120)
  92. || test_eval_list_expression("min(max(trunc(*(log(exp(*(neg(0w),0.001))),-144)),15),20)", blobs, 1, 15)
  93. || test_eval_list_expression("round(*(erf(reciprocal(log10(1h))),999))", blobs, 1, 722)
  94. || test_eval_list_expression("ceil(pow(fmod(atan2(0w,1d),1c),14.14)),floor(logaddexp(remainder(0c,10),6))", blobs, 2, 495, 6)
  95. || test_eval_list_expression("floor(*(square(sqrt(0w)),1.2121))", blobs, 1, 121)
  96. || test_eval_list_expression("rshift(lshift(xor(or(and(1d,18),9),4),4),2)", blobs, 1, 60)
  97. || test_eval_list_expression("ceil(*(rsqrt(+(+(sign(1w),10),*(sign(-(neg(1d)),0.5),3))),100))", blobs, 1, 36);
  98. }
  99. static int test_expression_2()
  100. {
  101. std::vector<ncnn::Mat> blobs(2);
  102. blobs[0] = ncnn::Mat(10, 20, 4);
  103. blobs[1] = ncnn::Mat(1, 2, 3, 4);
  104. // expect error blob index out of bound
  105. if (test_eval_list_expression("0w,1h,2c,1d", blobs, 0) != -1)
  106. return -1;
  107. // expect error divide by zero
  108. if (test_eval_list_expression("//(0w,-(0c,1c))", blobs, 0) != -1)
  109. return -1;
  110. // expect error malformed token
  111. if (test_eval_list_expression("1c,#(0w,1)", blobs, 0) != -1)
  112. return -1;
  113. if (test_eval_list_expression("1c,+(qwq,1w)", blobs, 0) != -1)
  114. return -1;
  115. return 0;
  116. }
  117. int main()
  118. {
  119. return 0
  120. || test_expression_0()
  121. || test_expression_1()
  122. || test_expression_2();
  123. }