You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_thread.cpp 2.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. #include "testutil.h"
  2. #include "thread.h"
  3. class TestLayer : public ncnn::Layer
  4. {
  5. public:
  6. virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt)
  7. {
  8. ThreadWorkspace workspace;
  9. workspace.layer = (Layer*)this;
  10. MutilThread thread(workspace, opt);
  11. std::vector<Mat> workspace_blobs;
  12. workspace_blobs.push_back(bottom_top_blob);
  13. thread.join(workspace_blobs);
  14. return 0;
  15. }
  16. virtual int forward_thread(void* workspace)
  17. {
  18. ThreadInfoExc* info = (ThreadInfoExc*)workspace;
  19. Mat& bottom_top_blob = info->mats->at(0);
  20. if (bottom_top_blob.elemsize == 1)
  21. {
  22. int8_t* ptr = (int8_t*)bottom_top_blob.data;
  23. const int8_t flag = 1 << 7;
  24. for (size_t i = info->start_index; i < info->end_index; i++)
  25. {
  26. if (ptr[i] & flag)
  27. {
  28. ptr[i] = -ptr[i];
  29. }
  30. }
  31. }
  32. else if (bottom_top_blob.elemsize == 2)
  33. {
  34. int16_t* ptr = (int16_t*)bottom_top_blob.data;
  35. const int16_t flag = 1 << 15;
  36. for (size_t i = info->start_index; i < info->end_index; i++)
  37. {
  38. if (ptr[i] & flag)
  39. {
  40. ptr[i] = -ptr[i];
  41. }
  42. }
  43. }
  44. else
  45. {
  46. float* ptr = (float*)bottom_top_blob.data;
  47. for (size_t i = info->start_index; i < info->end_index; i++)
  48. {
  49. if (ptr[i] < 0)
  50. {
  51. ptr[i] = -ptr[i];
  52. }
  53. }
  54. }
  55. return 0;
  56. }
  57. };
  58. static int test_thread(const ncnn::Mat& a)
  59. {
  60. ncnn::ParamDict pd;
  61. std::vector<ncnn::Mat> weights(0);
  62. int ret = test_layer("TestLayer", pd, weights, a);
  63. if (ret != 0)
  64. {
  65. fprintf(stderr, "test_thread failed a.dims=%d a=(%d %d %d %d)\n", a.dims, a.w, a.h, a.d, a.c);
  66. }
  67. return ret;
  68. }
  69. static int test_thread_0()
  70. {
  71. return 0
  72. || test_thread(RandomMat(5, 6, 7, 24))
  73. || test_thread(RandomMat(5, 6, 7, 12))
  74. || test_thread(RandomMat(5, 6, 7, 13));
  75. }
  76. static int test_thread_1()
  77. {
  78. return 0
  79. || test_thread(RandomMat(5, 7, 24))
  80. || test_thread(RandomMat(5, 6, 24))
  81. || test_thread(RandomMat(7, 9, 24));
  82. }
  83. static int test_thread_2()
  84. {
  85. return 0
  86. || test_thread(RandomMat(7, 12))
  87. || test_thread(RandomMat(5, 12))
  88. || test_thread(RandomMat(9, 12));
  89. }
  90. static int test_thread_3()
  91. {
  92. return 0
  93. || test_thread(RandomMat(7))
  94. || test_thread(RandomMat(128))
  95. || test_thread(RandomMat(256));
  96. }
  97. int main()
  98. {
  99. return 0
  100. || test_thread_0()
  101. || test_thread_1()
  102. || test_thread_2()
  103. || test_thread_3();
  104. }