You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

p2pnet.cpp 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "net.h"
  15. #if defined(USE_NCNN_SIMPLEOCV)
  16. #include "simpleocv.h"
  17. #else
  18. #include <opencv2/core/core.hpp>
  19. #include <opencv2/highgui/highgui.hpp>
  20. #include <opencv2/imgproc/imgproc.hpp>
  21. #endif
  22. #include <stdlib.h>
  23. #include <float.h>
  24. #include <stdio.h>
  25. #include <vector>
  26. struct CrowdPoint
  27. {
  28. cv::Point pt;
  29. float prob;
  30. };
  // Replicates a set of per-cell anchor offsets over every cell of a w x h grid.
  //
  // The centre of cell (col, row) is ((col + 0.5) * stride, (row + 0.5) * stride).
  // anchor_points holds interleaved (x, y) offsets; every pair is added to every
  // cell centre. The result is written to shifted_anchor_points as interleaved
  // (x, y) values, ordered row-major by cell and, within one cell, in the order
  // of anchor_points, i.e. w * h * (anchor_points.size() / 2) points in total.
  //
  // Any previous content of shifted_anchor_points is discarded. A non-positive
  // w or h produces an empty result instead of requesting a wrapped-around
  // allocation size. A trailing unpaired value in anchor_points is ignored.
  static void shift(int w, int h, int stride, std::vector<float> anchor_points, std::vector<float>& shifted_anchor_points)
  {
      shifted_anchor_points.clear();
      if (w <= 0 || h <= 0)
          return;

      const size_t num_anchors = anchor_points.size() / 2;

      // All index arithmetic is done in size_t so large grids cannot overflow int.
      shifted_anchor_points.resize((size_t)w * h * num_anchors * 2, 0.f);

      size_t out = 0;
      for (int row = 0; row < h; row++)
      {
          const float cy = (row + 0.5) * stride;
          for (int col = 0; col < w; col++)
          {
              const float cx = (col + 0.5) * stride;
              for (size_t k = 0; k < num_anchors; k++)
              {
                  shifted_anchor_points[out++] = anchor_points[k * 2] + cx;
                  shifted_anchor_points[out++] = anchor_points[k * 2 + 1] + cy;
              }
          }
      }
  }
  // Generates the anchor offsets for a single grid cell of side length stride.
  //
  // The cell is divided into row x line sub-cells and one anchor is placed at the
  // centre of each sub-cell, expressed relative to the centre of the cell (so the
  // offsets are symmetric around 0). anchor_points receives row * line interleaved
  // (x, y) pairs, ordered row by row with x varying fastest.
  //
  // Any previous content of anchor_points is discarded. A non-positive row or line
  // produces an empty result (avoiding a division by zero / wrapped allocation).
  static void generate_anchor_points(int stride, int row, int line, std::vector<float>& anchor_points)
  {
      anchor_points.clear();
      if (row <= 0 || line <= 0)
          return;

      const float row_step = (float)stride / row;
      const float line_step = (float)stride / line;

      // Half the stride as a real number: the integer expression stride / 2
      // truncates for odd strides (e.g. stride 1) and shifts every anchor by 0.5.
      const double half_stride = stride * 0.5;

      anchor_points.reserve((size_t)row * line * 2);
      for (int r = 1; r <= row; r++)
      {
          const float y = (r - 0.5) * row_step - half_stride;
          for (int c = 1; c <= line; c++)
          {
              const float x = (c - 0.5) * line_step - half_stride;
              anchor_points.push_back(x);
              anchor_points.push_back(y);
          }
      }
  }
  116. static void generate_anchor_points(int img_w, int img_h, std::vector<int> pyramid_levels, int row, int line, std::vector<float>& all_anchor_points)
  117. {
  118. std::vector<std::pair<int, int> > image_shapes;
  119. std::vector<int> strides;
  120. for (int i = 0; i < pyramid_levels.size(); i++)
  121. {
  122. int new_h = std::floor((img_h + std::pow(2, pyramid_levels[i]) - 1) / std::pow(2, pyramid_levels[i]));
  123. int new_w = std::floor((img_w + std::pow(2, pyramid_levels[i]) - 1) / std::pow(2, pyramid_levels[i]));
  124. image_shapes.push_back(std::make_pair(new_w, new_h));
  125. strides.push_back(std::pow(2, pyramid_levels[i]));
  126. }
  127. all_anchor_points.clear();
  128. for (int i = 0; i < pyramid_levels.size(); i++)
  129. {
  130. std::vector<float> anchor_points;
  131. generate_anchor_points(std::pow(2, pyramid_levels[i]), row, line, anchor_points);
  132. std::vector<float> shifted_anchor_points;
  133. shift(image_shapes[i].first, image_shapes[i].second, strides[i], anchor_points, shifted_anchor_points);
  134. all_anchor_points.insert(all_anchor_points.end(), shifted_anchor_points.begin(), shifted_anchor_points.end());
  135. }
  136. }
  137. static int detect_crowd(const cv::Mat& bgr, std::vector<CrowdPoint>& crowd_points)
  138. {
  139. ncnn::Option opt;
  140. opt.num_threads = 4;
  141. opt.use_vulkan_compute = false;
  142. opt.use_bf16_storage = false;
  143. ncnn::Net net;
  144. net.opt = opt;
  145. // model is converted from
  146. // https://github.com/TencentYoutuResearch/CrowdCounting-P2PNet
  147. // the ncnn model https://pan.baidu.com/s/1O1CBgvY6yJkrK8Npxx3VMg pwd: ezhx
  148. if (net.load_param("p2pnet.param"))
  149. exit(-1);
  150. if (net.load_model("p2pnet.bin"))
  151. exit(-1);
  152. int width = bgr.cols;
  153. int height = bgr.rows;
  154. int new_width = width / 128 * 128;
  155. int new_height = height / 128 * 128;
  156. ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, width, height, new_width, new_height);
  157. std::vector<int> pyramid_levels(1, 3);
  158. std::vector<float> all_anchor_points;
  159. generate_anchor_points(in.w, in.h, pyramid_levels, 2, 2, all_anchor_points);
  160. ncnn::Mat anchor_points = ncnn::Mat(2, all_anchor_points.size() / 2, all_anchor_points.data());
  161. ncnn::Extractor ex = net.create_extractor();
  162. const float mean_vals1[3] = {123.675f, 116.28f, 103.53f};
  163. const float norm_vals1[3] = {0.01712475f, 0.0175f, 0.01742919f};
  164. in.substract_mean_normalize(mean_vals1, norm_vals1);
  165. ex.input("input", in);
  166. ex.input("anchor", anchor_points);
  167. ncnn::Mat score, points;
  168. ex.extract("pred_scores", score);
  169. ex.extract("pred_points", points);
  170. for (int i = 0; i < points.h; i++)
  171. {
  172. float* score_data = score.row(i);
  173. float* points_data = points.row(i);
  174. CrowdPoint cp;
  175. int x = points_data[0] / new_width * width;
  176. int y = points_data[1] / new_height * height;
  177. cp.pt = cv::Point(x, y);
  178. cp.prob = score_data[1];
  179. crowd_points.push_back(cp);
  180. }
  181. return 0;
  182. }
  183. static void draw_result(const cv::Mat& bgr, const std::vector<CrowdPoint>& crowd_points)
  184. {
  185. cv::Mat image = bgr.clone();
  186. const float threshold = 0.5f;
  187. for (int i = 0; i < crowd_points.size(); i++)
  188. {
  189. if (crowd_points[i].prob > threshold)
  190. {
  191. cv::circle(image, crowd_points[i].pt, 4, cv::Scalar(0, 0, 255), -1, 8, 0);
  192. }
  193. }
  194. cv::imshow("image", image);
  195. cv::waitKey();
  196. }
  197. int main(int argc, char** argv)
  198. {
  199. if (argc != 2)
  200. {
  201. fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
  202. return -1;
  203. }
  204. const char* imagepath = argv[1];
  205. cv::Mat bgr = cv::imread(imagepath, 1);
  206. if (bgr.empty())
  207. {
  208. fprintf(stderr, "cv::imread %s failed\n", imagepath);
  209. return -1;
  210. }
  211. std::vector<CrowdPoint> crowd_points;
  212. detect_crowd(bgr, crowd_points);
  213. draw_result(bgr, crowd_points);
  214. return 0;
  215. }