You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

p2pnet.cpp 6.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. // Copyright 2021 Tencent
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. #include "net.h"
  4. #if defined(USE_NCNN_SIMPLEOCV)
  5. #include "simpleocv.h"
  6. #else
  7. #include <opencv2/core/core.hpp>
  8. #include <opencv2/highgui/highgui.hpp>
  9. #include <opencv2/imgproc/imgproc.hpp>
  10. #endif
  11. #include <stdlib.h>
  12. #include <float.h>
  13. #include <stdio.h>
  14. #include <vector>
  15. struct CrowdPoint
  16. {
  17. cv::Point pt;
  18. float prob;
  19. };
  // Tile a set of anchor offsets over every cell of a w x h feature map.
  //
  // The centre of the cell at (row r, column c) is
  //   ((c + 0.5) * stride, (r + 0.5) * stride)
  // and each anchor offset (ax, ay) is added to that centre.
  //
  // w, h                   feature-map size in cells; non-positive sizes yield no points.
  // stride                 cell size in input pixels.
  // anchor_points          flat list of per-cell offsets: x0, y0, x1, y1, ...
  // shifted_anchor_points  output; resized to w * h * (anchor_points.size() / 2) * 2 floats,
  //                        ordered by cell (row-major), then by anchor, then x before y.
  static void shift(int w, int h, int stride, const std::vector<float>& anchor_points, std::vector<float>& shifted_anchor_points)
  {
      const size_t num_anchors = anchor_points.size() / 2;
      // Guard the size computation: a negative w or h must not wrap to a huge size_t.
      const size_t num_cells = (w > 0 && h > 0) ? (size_t)w * (size_t)h : 0;
      shifted_anchor_points.resize(num_cells * num_anchors * 2, 0);

      size_t out = 0;
      for (int r = 0; r < h; r++)
      {
          // Evaluated in double and then narrowed to float, exactly as before.
          const float cy = (float)((r + 0.5) * stride);
          for (int c = 0; c < w; c++)
          {
              const float cx = (float)((c + 0.5) * stride);
              for (size_t a = 0; a < num_anchors; a++)
              {
                  shifted_anchor_points[out++] = anchor_points[a * 2] + cx;
                  shifted_anchor_points[out++] = anchor_points[a * 2 + 1] + cy;
              }
          }
      }
  }
  // Build the anchor offsets for a single stride x stride feature-map cell.
  //
  // The cell is split into a row x line grid and one anchor is placed at the
  // centre of each sub-cell, expressed relative to the cell centre, so every
  // offset lies within [-stride / 2, stride / 2].
  //
  // anchor_points  output; resized to row * line * 2 floats (x0, y0, x1, y1, ...),
  //                ordered row-major: y is the outer index, x the inner one.
  //                Non-positive row or line yields an empty result.
  static void generate_anchor_points(int stride, int row, int line, std::vector<float>& anchor_points)
  {
      const float row_step = (float)stride / row;
      const float line_step = (float)stride / line;
      // Real half-stride. The previous integer expression `stride / 2` truncated
      // odd strides (stride 1 gave 0 instead of 0.5).
      const double half_stride = stride / 2.0;

      const size_t count = (row > 0 && line > 0) ? (size_t)row * (size_t)line : 0;
      anchor_points.resize(count * 2, 0);

      size_t out = 0;
      for (int r = 0; r < row; r++)
      {
          const float y = (float)((r + 0.5) * row_step - half_stride);
          for (int c = 0; c < line; c++)
          {
              anchor_points[out++] = (float)((c + 0.5) * line_step - half_stride);
              anchor_points[out++] = y;
          }
      }
  }
  105. static void generate_anchor_points(int img_w, int img_h, std::vector<int> pyramid_levels, int row, int line, std::vector<float>& all_anchor_points)
  106. {
  107. std::vector<std::pair<int, int> > image_shapes;
  108. std::vector<int> strides;
  109. for (int i = 0; i < pyramid_levels.size(); i++)
  110. {
  111. int new_h = std::floor((img_h + std::pow(2, pyramid_levels[i]) - 1) / std::pow(2, pyramid_levels[i]));
  112. int new_w = std::floor((img_w + std::pow(2, pyramid_levels[i]) - 1) / std::pow(2, pyramid_levels[i]));
  113. image_shapes.push_back(std::make_pair(new_w, new_h));
  114. strides.push_back(std::pow(2, pyramid_levels[i]));
  115. }
  116. all_anchor_points.clear();
  117. for (int i = 0; i < pyramid_levels.size(); i++)
  118. {
  119. std::vector<float> anchor_points;
  120. generate_anchor_points(std::pow(2, pyramid_levels[i]), row, line, anchor_points);
  121. std::vector<float> shifted_anchor_points;
  122. shift(image_shapes[i].first, image_shapes[i].second, strides[i], anchor_points, shifted_anchor_points);
  123. all_anchor_points.insert(all_anchor_points.end(), shifted_anchor_points.begin(), shifted_anchor_points.end());
  124. }
  125. }
  126. static int detect_crowd(const cv::Mat& bgr, std::vector<CrowdPoint>& crowd_points)
  127. {
  128. ncnn::Option opt;
  129. opt.num_threads = 4;
  130. opt.use_vulkan_compute = false;
  131. opt.use_bf16_storage = false;
  132. ncnn::Net net;
  133. net.opt = opt;
  134. // model is converted from
  135. // https://github.com/TencentYoutuResearch/CrowdCounting-P2PNet
  136. // the ncnn model https://pan.baidu.com/s/1O1CBgvY6yJkrK8Npxx3VMg pwd: ezhx
  137. if (net.load_param("p2pnet.param"))
  138. exit(-1);
  139. if (net.load_model("p2pnet.bin"))
  140. exit(-1);
  141. int width = bgr.cols;
  142. int height = bgr.rows;
  143. int new_width = width / 128 * 128;
  144. int new_height = height / 128 * 128;
  145. ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, width, height, new_width, new_height);
  146. std::vector<int> pyramid_levels(1, 3);
  147. std::vector<float> all_anchor_points;
  148. generate_anchor_points(in.w, in.h, pyramid_levels, 2, 2, all_anchor_points);
  149. ncnn::Mat anchor_points = ncnn::Mat(2, all_anchor_points.size() / 2, all_anchor_points.data());
  150. ncnn::Extractor ex = net.create_extractor();
  151. const float mean_vals1[3] = {123.675f, 116.28f, 103.53f};
  152. const float norm_vals1[3] = {0.01712475f, 0.0175f, 0.01742919f};
  153. in.substract_mean_normalize(mean_vals1, norm_vals1);
  154. ex.input("input", in);
  155. ex.input("anchor", anchor_points);
  156. ncnn::Mat score, points;
  157. ex.extract("pred_scores", score);
  158. ex.extract("pred_points", points);
  159. for (int i = 0; i < points.h; i++)
  160. {
  161. float* score_data = score.row(i);
  162. float* points_data = points.row(i);
  163. CrowdPoint cp;
  164. int x = points_data[0] / new_width * width;
  165. int y = points_data[1] / new_height * height;
  166. cp.pt = cv::Point(x, y);
  167. cp.prob = score_data[1];
  168. crowd_points.push_back(cp);
  169. }
  170. return 0;
  171. }
  172. static void draw_result(const cv::Mat& bgr, const std::vector<CrowdPoint>& crowd_points)
  173. {
  174. cv::Mat image = bgr.clone();
  175. const float threshold = 0.5f;
  176. for (int i = 0; i < crowd_points.size(); i++)
  177. {
  178. if (crowd_points[i].prob > threshold)
  179. {
  180. cv::circle(image, crowd_points[i].pt, 4, cv::Scalar(0, 0, 255), -1, 8, 0);
  181. }
  182. }
  183. cv::imshow("image", image);
  184. cv::waitKey();
  185. }
  186. int main(int argc, char** argv)
  187. {
  188. if (argc != 2)
  189. {
  190. fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
  191. return -1;
  192. }
  193. const char* imagepath = argv[1];
  194. cv::Mat bgr = cv::imread(imagepath, 1);
  195. if (bgr.empty())
  196. {
  197. fprintf(stderr, "cv::imread %s failed\n", imagepath);
  198. return -1;
  199. }
  200. std::vector<CrowdPoint> crowd_points;
  201. detect_crowd(bgr, crowd_points);
  202. draw_result(bgr, crowd_points);
  203. return 0;
  204. }