You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

yolact.cpp 14 kB

(line-number gutter: lines 1–492)
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include <stdio.h>
  15. #include <vector>
  16. #include <opencv2/core/core.hpp>
  17. #include <opencv2/highgui/highgui.hpp>
  18. #include <opencv2/imgproc/imgproc.hpp>
  19. #include "platform.h"
  20. #include "net.h"
  21. #if NCNN_VULKAN
  22. #include "gpu.h"
  23. #endif // NCNN_VULKAN
  24. struct Object
  25. {
  26. cv::Rect_<float> rect;
  27. int label;
  28. float prob;
  29. std::vector<float> maskdata;
  30. cv::Mat mask;
  31. };
  32. static inline float intersection_area(const Object& a, const Object& b)
  33. {
  34. cv::Rect_<float> inter = a.rect & b.rect;
  35. return inter.area();
  36. }
  37. static void qsort_descent_inplace(std::vector<Object>& objects, int left, int right)
  38. {
  39. int i = left;
  40. int j = right;
  41. float p = objects[(left + right) / 2].prob;
  42. while (i <= j)
  43. {
  44. while (objects[i].prob > p)
  45. i++;
  46. while (objects[j].prob < p)
  47. j--;
  48. if (i <= j)
  49. {
  50. // swap
  51. std::swap(objects[i], objects[j]);
  52. i++;
  53. j--;
  54. }
  55. }
  56. #pragma omp parallel sections
  57. {
  58. #pragma omp section
  59. {
  60. if (left < j) qsort_descent_inplace(objects, left, j);
  61. }
  62. #pragma omp section
  63. {
  64. if (i < right) qsort_descent_inplace(objects, i, right);
  65. }
  66. }
  67. }
  68. static void qsort_descent_inplace(std::vector<Object>& objects)
  69. {
  70. if (objects.empty())
  71. return;
  72. qsort_descent_inplace(objects, 0, objects.size() - 1);
  73. }
  74. static void nms_sorted_bboxes(const std::vector<Object>& objects, std::vector<int>& picked, float nms_threshold)
  75. {
  76. picked.clear();
  77. const int n = objects.size();
  78. std::vector<float> areas(n);
  79. for (int i = 0; i < n; i++)
  80. {
  81. areas[i] = objects[i].rect.area();
  82. }
  83. for (int i = 0; i < n; i++)
  84. {
  85. const Object& a = objects[i];
  86. int keep = 1;
  87. for (int j = 0; j < (int)picked.size(); j++)
  88. {
  89. const Object& b = objects[picked[j]];
  90. // intersection over union
  91. float inter_area = intersection_area(a, b);
  92. float union_area = areas[i] + areas[picked[j]] - inter_area;
  93. // float IoU = inter_area / union_area
  94. if (inter_area / union_area > nms_threshold)
  95. keep = 0;
  96. }
  97. if (keep)
  98. picked.push_back(i);
  99. }
  100. }
  101. static int detect_yolact(const cv::Mat& bgr, std::vector<Object>& objects)
  102. {
  103. ncnn::Net yolact;
  104. #if NCNN_VULKAN
  105. yolact.opt.use_vulkan_compute = true;
  106. #endif // NCNN_VULKAN
  107. // original model converted from https://github.com/dbolya/yolact
  108. // yolact_resnet50_54_800000.pth
  109. // the ncnn model https://github.com/nihui/ncnn-assets/tree/master/models
  110. yolact.load_param("yolact.param");
  111. yolact.load_model("yolact.bin");
  112. const int target_size = 550;
  113. int img_w = bgr.cols;
  114. int img_h = bgr.rows;
  115. ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, img_w, img_h, target_size, target_size);
  116. const float mean_vals[3] = {123.68f, 116.78f, 103.94f};
  117. const float norm_vals[3] = {1.0/58.40f, 1.0/57.12f, 1.0/57.38f};
  118. in.substract_mean_normalize(mean_vals, norm_vals);
  119. ncnn::Extractor ex = yolact.create_extractor();
  120. // ex.set_num_threads(4);
  121. ex.input("input.1", in);
  122. ncnn::Mat maskmaps;
  123. ncnn::Mat location;
  124. ncnn::Mat mask;
  125. ncnn::Mat confidence;
  126. ex.extract("619", maskmaps);// 138x138 x 32
  127. ex.extract("816", location);// 4 x 19248
  128. ex.extract("818", mask);// maskdim 32 x 19248
  129. ex.extract("820", confidence);// 81 x 19248
  130. int num_class = confidence.w;
  131. int num_priors = confidence.h;
  132. // make priorbox
  133. ncnn::Mat priorbox(4, num_priors);
  134. {
  135. const int conv_ws[5] = {69, 35, 18, 9, 5};
  136. const int conv_hs[5] = {69, 35, 18, 9, 5};
  137. const float aspect_ratios[3] = {1.f, 0.5f, 2.f};
  138. const float scales[5] = {24.f, 48.f, 96.f, 192.f, 384.f};
  139. float* pb = priorbox;
  140. for (int p = 0; p < 5; p++)
  141. {
  142. int conv_w = conv_ws[p];
  143. int conv_h = conv_hs[p];
  144. float scale = scales[p];
  145. for (int i = 0; i < conv_h; i++)
  146. {
  147. for (int j = 0; j < conv_w; j++)
  148. {
  149. // +0.5 because priors are in center-size notation
  150. float cx = (j + 0.5f) / conv_w;
  151. float cy = (i + 0.5f) / conv_h;
  152. for (int k = 0; k < 3; k++)
  153. {
  154. float ar = aspect_ratios[k];
  155. ar = sqrt(ar);
  156. float w = scale * ar / 550;
  157. float h = scale / ar / 550;
  158. // This is for backward compatability with a bug where I made everything square by accident
  159. // cfg.backbone.use_square_anchors:
  160. h = w;
  161. pb[0] = cx;
  162. pb[1] = cy;
  163. pb[2] = w;
  164. pb[3] = h;
  165. pb += 4;
  166. }
  167. }
  168. }
  169. }
  170. }
  171. const float confidence_thresh = 0.05f;
  172. const float nms_threshold = 0.5f;
  173. const int keep_top_k = 200;
  174. std::vector< std::vector<Object> > class_candidates;
  175. class_candidates.resize(num_class);
  176. for (int i = 0; i < num_priors; i++)
  177. {
  178. const float* conf = confidence.row(i);
  179. const float* loc = location.row(i);
  180. const float* pb = priorbox.row(i);
  181. const float* maskdata = mask.row(i);
  182. // find class id with highest score
  183. // start from 1 to skip background
  184. int label = 0;
  185. float score = 0.f;
  186. for (int j=1; j<num_class; j++)
  187. {
  188. float class_score = conf[j];
  189. if (class_score > score)
  190. {
  191. label = j;
  192. score = class_score;
  193. }
  194. }
  195. // ignore background or low score
  196. if (label == 0 || score <= confidence_thresh)
  197. continue;
  198. // CENTER_SIZE
  199. float var[4] = {0.1f, 0.1f, 0.2f, 0.2f};
  200. float pb_cx = pb[0];
  201. float pb_cy = pb[1];
  202. float pb_w = pb[2];
  203. float pb_h = pb[3];
  204. float bbox_cx = var[0] * loc[0] * pb_w + pb_cx;
  205. float bbox_cy = var[1] * loc[1] * pb_h + pb_cy;
  206. float bbox_w = (float)(exp(var[2] * loc[2]) * pb_w);
  207. float bbox_h = (float)(exp(var[3] * loc[3]) * pb_h);
  208. float obj_x1 = bbox_cx - bbox_w * 0.5f;
  209. float obj_y1 = bbox_cy - bbox_h * 0.5f;
  210. float obj_x2 = bbox_cx + bbox_w * 0.5f;
  211. float obj_y2 = bbox_cy + bbox_h * 0.5f;
  212. // clip
  213. obj_x1 = std::max(std::min(obj_x1 * bgr.cols, (float)(bgr.cols - 1)), 0.f);
  214. obj_y1 = std::max(std::min(obj_y1 * bgr.rows, (float)(bgr.rows - 1)), 0.f);
  215. obj_x2 = std::max(std::min(obj_x2 * bgr.cols, (float)(bgr.cols - 1)), 0.f);
  216. obj_y2 = std::max(std::min(obj_y2 * bgr.rows, (float)(bgr.rows - 1)), 0.f);
  217. // append object
  218. Object obj;
  219. obj.rect = cv::Rect_<float>(obj_x1, obj_y1, obj_x2-obj_x1+1, obj_y2-obj_y1+1);
  220. obj.label = label;
  221. obj.prob = score;
  222. obj.maskdata = std::vector<float>(maskdata, maskdata + mask.w);
  223. class_candidates[label].push_back(obj);
  224. }
  225. objects.clear();
  226. for (int i = 0; i < (int)class_candidates.size(); i++)
  227. {
  228. std::vector<Object>& candidates = class_candidates[i];
  229. qsort_descent_inplace(candidates);
  230. std::vector<int> picked;
  231. nms_sorted_bboxes(candidates, picked, nms_threshold);
  232. for (int j = 0; j < (int)picked.size(); j++)
  233. {
  234. int z = picked[j];
  235. objects.push_back(candidates[z]);
  236. }
  237. }
  238. qsort_descent_inplace(objects);
  239. // keep_top_k
  240. if (keep_top_k < (int)objects.size())
  241. {
  242. objects.resize(keep_top_k);
  243. }
  244. // generate mask
  245. for (int i=0; i<objects.size(); i++)
  246. {
  247. Object& obj = objects[i];
  248. cv::Mat mask(maskmaps.h, maskmaps.w, CV_32FC1);
  249. {
  250. mask = cv::Scalar(0.f);
  251. for (int p=0; p<maskmaps.c; p++)
  252. {
  253. const float* maskmap = maskmaps.channel(p);
  254. float coeff = obj.maskdata[p];
  255. float* mp = (float*)mask.data;
  256. // mask += m * coeff
  257. for (int j=0; j<maskmaps.w * maskmaps.h; j++)
  258. {
  259. mp[j] += maskmap[j] * coeff;
  260. }
  261. }
  262. }
  263. cv::Mat mask2;
  264. cv::resize(mask, mask2, cv::Size(img_w, img_h));
  265. // crop obj box and binarize
  266. obj.mask = cv::Mat(img_h, img_w, CV_8UC1);
  267. {
  268. obj.mask = cv::Scalar(0);
  269. for (int y=0; y<img_h; y++)
  270. {
  271. if (y < obj.rect.y || y > obj.rect.y + obj.rect.height)
  272. continue;
  273. const float* mp2 = mask2.ptr<const float>(y);
  274. uchar* bmp = obj.mask.ptr<uchar>(y);
  275. for (int x=0; x<img_w; x++)
  276. {
  277. if (x < obj.rect.x || x > obj.rect.x + obj.rect.width)
  278. continue;
  279. bmp[x] = mp2[x] > 0.5f ? 255 : 0;
  280. }
  281. }
  282. }
  283. }
  284. return 0;
  285. }
  286. static void draw_objects(const cv::Mat& bgr, const std::vector<Object>& objects)
  287. {
  288. static const char* class_names[] = {"background",
  289. "person", "bicycle", "car", "motorcycle", "airplane", "bus",
  290. "train", "truck", "boat", "traffic light", "fire hydrant",
  291. "stop sign", "parking meter", "bench", "bird", "cat", "dog",
  292. "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
  293. "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
  294. "skis", "snowboard", "sports ball", "kite", "baseball bat",
  295. "baseball glove", "skateboard", "surfboard", "tennis racket",
  296. "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
  297. "banana", "apple", "sandwich", "orange", "broccoli", "carrot",
  298. "hot dog", "pizza", "donut", "cake", "chair", "couch",
  299. "potted plant", "bed", "dining table", "toilet", "tv", "laptop",
  300. "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
  301. "toaster", "sink", "refrigerator", "book", "clock", "vase",
  302. "scissors", "teddy bear", "hair drier", "toothbrush"};
  303. static const unsigned char colors[19][3] = {
  304. {244, 67, 54},
  305. {233, 30, 99},
  306. {156, 39, 176},
  307. {103, 58, 183},
  308. { 63, 81, 181},
  309. { 33, 150, 243},
  310. { 3, 169, 244},
  311. { 0, 188, 212},
  312. { 0, 150, 136},
  313. { 76, 175, 80},
  314. {139, 195, 74},
  315. {205, 220, 57},
  316. {255, 235, 59},
  317. {255, 193, 7},
  318. {255, 152, 0},
  319. {255, 87, 34},
  320. {121, 85, 72},
  321. {158, 158, 158},
  322. { 96, 125, 139}
  323. };
  324. cv::Mat image = bgr.clone();
  325. int color_index = 0;
  326. for (size_t i = 0; i < objects.size(); i++)
  327. {
  328. const Object& obj = objects[i];
  329. if (obj.prob < 0.15)
  330. continue;
  331. fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob,
  332. obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height);
  333. const unsigned char* color = colors[color_index++];
  334. cv::rectangle(image, obj.rect, cv::Scalar(color[0], color[1], color[2]));
  335. char text[256];
  336. sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100);
  337. int baseLine = 0;
  338. cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
  339. int x = obj.rect.x;
  340. int y = obj.rect.y - label_size.height - baseLine;
  341. if (y < 0)
  342. y = 0;
  343. if (x + label_size.width > image.cols)
  344. x = image.cols - label_size.width;
  345. cv::rectangle(image, cv::Rect(cv::Point(x, y),
  346. cv::Size(label_size.width, label_size.height + baseLine)),
  347. cv::Scalar(255, 255, 255), -1);
  348. cv::putText(image, text, cv::Point(x, y + label_size.height),
  349. cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0));
  350. // draw mask
  351. for (int y=0; y<image.rows; y++)
  352. {
  353. const uchar* mp = obj.mask.ptr(y);
  354. uchar* p = image.ptr(y);
  355. for (int x=0; x<image.cols; x++)
  356. {
  357. if (mp[x] == 255)
  358. {
  359. p[0] = cv::saturate_cast<uchar>(p[0] * 0.5 + color[0] * 0.5);
  360. p[1] = cv::saturate_cast<uchar>(p[1] * 0.5 + color[1] * 0.5);
  361. p[2] = cv::saturate_cast<uchar>(p[2] * 0.5 + color[2] * 0.5);
  362. }
  363. p += 3;
  364. }
  365. }
  366. }
  367. cv::imwrite("result.png", image);
  368. cv::imshow("image", image);
  369. cv::waitKey(0);
  370. }
  371. int main(int argc, char** argv)
  372. {
  373. if (argc != 2)
  374. {
  375. fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
  376. return -1;
  377. }
  378. const char* imagepath = argv[1];
  379. cv::Mat m = cv::imread(imagepath, 1);
  380. if (m.empty())
  381. {
  382. fprintf(stderr, "cv::imread %s failed\n", imagepath);
  383. return -1;
  384. }
  385. #if NCNN_VULKAN
  386. ncnn::create_gpu_instance();
  387. #endif // NCNN_VULKAN
  388. std::vector<Object> objects;
  389. detect_yolact(m, objects);
  390. #if NCNN_VULKAN
  391. ncnn::destroy_gpu_instance();
  392. #endif // NCNN_VULKAN
  393. draw_objects(m, objects);
  394. return 0;
  395. }