You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_softmax_oom.cpp 1.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. // Copyright 2024 Tencent
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. #include "testutil.h"
  4. static int test_softmax_oom(const ncnn::Mat& a, int axis)
  5. {
  6. ncnn::ParamDict pd;
  7. pd.set(0, axis); // axis
  8. pd.set(1, 1); // fixbug0
  9. std::vector<ncnn::Mat> weights(0);
  10. int ret = test_layer_oom("Softmax", pd, weights, a);
  11. if (ret != 0)
  12. {
  13. fprintf(stderr, "test_softmax_oom failed a.dims=%d a=(%d %d %d %d) axis=%d\n", a.dims, a.w, a.h, a.d, a.c, axis);
  14. }
  15. return ret;
  16. }
  17. static int test_softmax_0()
  18. {
  19. ncnn::Mat a = RandomMat(18, 17, 19, 32);
  20. return test_softmax_oom(a, 0) || test_softmax_oom(a, 1) || test_softmax_oom(a, 2) || test_softmax_oom(a, 3);
  21. }
  22. static int test_softmax_1()
  23. {
  24. ncnn::Mat a = RandomMat(25, 27, 32);
  25. return test_softmax_oom(a, 0) || test_softmax_oom(a, 1) || test_softmax_oom(a, 2);
  26. }
  27. static int test_softmax_2()
  28. {
  29. ncnn::Mat a = RandomMat(25, 32);
  30. return test_softmax_oom(a, 0) || test_softmax_oom(a, 1);
  31. }
  32. static int test_softmax_3()
  33. {
  34. ncnn::Mat a = RandomMat(128);
  35. return test_softmax_oom(a, 0);
  36. }
  37. int main()
  38. {
  39. SRAND(7767517);
  40. return 0
  41. || test_softmax_0()
  42. || test_softmax_1()
  43. || test_softmax_2()
  44. || test_softmax_3();
  45. }