You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

yolact.cpp 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "net.h"
  15. #if defined(USE_NCNN_SIMPLEOCV)
  16. #include "simpleocv.h"
  17. #else
  18. #include <opencv2/core/core.hpp>
  19. #include <opencv2/highgui/highgui.hpp>
  20. #include <opencv2/imgproc/imgproc.hpp>
  21. #endif
  22. #include <stdio.h>
  23. #include <vector>
  24. struct Object
  25. {
  26. cv::Rect_<float> rect;
  27. int label;
  28. float prob;
  29. std::vector<float> maskdata;
  30. cv::Mat mask;
  31. };
  32. static inline float intersection_area(const Object& a, const Object& b)
  33. {
  34. cv::Rect_<float> inter = a.rect & b.rect;
  35. return inter.area();
  36. }
  37. static void qsort_descent_inplace(std::vector<Object>& objects, int left, int right)
  38. {
  39. int i = left;
  40. int j = right;
  41. float p = objects[(left + right) / 2].prob;
  42. while (i <= j)
  43. {
  44. while (objects[i].prob > p)
  45. i++;
  46. while (objects[j].prob < p)
  47. j--;
  48. if (i <= j)
  49. {
  50. // swap
  51. std::swap(objects[i], objects[j]);
  52. i++;
  53. j--;
  54. }
  55. }
  56. #pragma omp parallel sections
  57. {
  58. #pragma omp section
  59. {
  60. if (left < j) qsort_descent_inplace(objects, left, j);
  61. }
  62. #pragma omp section
  63. {
  64. if (i < right) qsort_descent_inplace(objects, i, right);
  65. }
  66. }
  67. }
  68. static void qsort_descent_inplace(std::vector<Object>& objects)
  69. {
  70. if (objects.empty())
  71. return;
  72. qsort_descent_inplace(objects, 0, objects.size() - 1);
  73. }
  74. static void nms_sorted_bboxes(const std::vector<Object>& faceobjects, std::vector<int>& picked, float nms_threshold, bool agnostic = false)
  75. {
  76. picked.clear();
  77. const int n = faceobjects.size();
  78. std::vector<float> areas(n);
  79. for (int i = 0; i < n; i++)
  80. {
  81. areas[i] = faceobjects[i].rect.area();
  82. }
  83. for (int i = 0; i < n; i++)
  84. {
  85. const Object& a = faceobjects[i];
  86. int keep = 1;
  87. for (int j = 0; j < (int)picked.size(); j++)
  88. {
  89. const Object& b = faceobjects[picked[j]];
  90. if (!agnostic && a.label != b.label)
  91. continue;
  92. // intersection over union
  93. float inter_area = intersection_area(a, b);
  94. float union_area = areas[i] + areas[picked[j]] - inter_area;
  95. // float IoU = inter_area / union_area
  96. if (inter_area / union_area > nms_threshold)
  97. keep = 0;
  98. }
  99. if (keep)
  100. picked.push_back(i);
  101. }
  102. }
  103. static int detect_yolact(const cv::Mat& bgr, std::vector<Object>& objects)
  104. {
  105. ncnn::Net yolact;
  106. yolact.opt.use_vulkan_compute = true;
  107. // original model converted from https://github.com/dbolya/yolact
  108. // yolact_resnet50_54_800000.pth
  109. // the ncnn model https://github.com/nihui/ncnn-assets/tree/master/models
  110. if (yolact.load_param("yolact.param"))
  111. exit(-1);
  112. if (yolact.load_model("yolact.bin"))
  113. exit(-1);
  114. const int target_size = 550;
  115. int img_w = bgr.cols;
  116. int img_h = bgr.rows;
  117. ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, img_w, img_h, target_size, target_size);
  118. const float mean_vals[3] = {123.68f, 116.78f, 103.94f};
  119. const float norm_vals[3] = {1.0 / 58.40f, 1.0 / 57.12f, 1.0 / 57.38f};
  120. in.substract_mean_normalize(mean_vals, norm_vals);
  121. ncnn::Extractor ex = yolact.create_extractor();
  122. ex.input("input.1", in);
  123. ncnn::Mat maskmaps;
  124. ncnn::Mat location;
  125. ncnn::Mat mask;
  126. ncnn::Mat confidence;
  127. ex.extract("619", maskmaps); // 138x138 x 32
  128. ex.extract("816", location); // 4 x 19248
  129. ex.extract("818", mask); // maskdim 32 x 19248
  130. ex.extract("820", confidence); // 81 x 19248
  131. int num_class = confidence.w;
  132. int num_priors = confidence.h;
  133. // make priorbox
  134. ncnn::Mat priorbox(4, num_priors);
  135. {
  136. const int conv_ws[5] = {69, 35, 18, 9, 5};
  137. const int conv_hs[5] = {69, 35, 18, 9, 5};
  138. const float aspect_ratios[3] = {1.f, 0.5f, 2.f};
  139. const float scales[5] = {24.f, 48.f, 96.f, 192.f, 384.f};
  140. float* pb = priorbox;
  141. for (int p = 0; p < 5; p++)
  142. {
  143. int conv_w = conv_ws[p];
  144. int conv_h = conv_hs[p];
  145. float scale = scales[p];
  146. for (int i = 0; i < conv_h; i++)
  147. {
  148. for (int j = 0; j < conv_w; j++)
  149. {
  150. // +0.5 because priors are in center-size notation
  151. float cx = (j + 0.5f) / conv_w;
  152. float cy = (i + 0.5f) / conv_h;
  153. for (int k = 0; k < 3; k++)
  154. {
  155. float ar = aspect_ratios[k];
  156. ar = sqrt(ar);
  157. float w = scale * ar / 550;
  158. float h = scale / ar / 550;
  159. // This is for backward compatibility with a bug where I made everything square by accident
  160. // cfg.backbone.use_square_anchors:
  161. h = w;
  162. pb[0] = cx;
  163. pb[1] = cy;
  164. pb[2] = w;
  165. pb[3] = h;
  166. pb += 4;
  167. }
  168. }
  169. }
  170. }
  171. }
  172. const float confidence_thresh = 0.05f;
  173. const float nms_threshold = 0.5f;
  174. const int keep_top_k = 200;
  175. std::vector<std::vector<Object> > class_candidates;
  176. class_candidates.resize(num_class);
  177. for (int i = 0; i < num_priors; i++)
  178. {
  179. const float* conf = confidence.row(i);
  180. const float* loc = location.row(i);
  181. const float* pb = priorbox.row(i);
  182. const float* maskdata = mask.row(i);
  183. // find class id with highest score
  184. // start from 1 to skip background
  185. int label = 0;
  186. float score = 0.f;
  187. for (int j = 1; j < num_class; j++)
  188. {
  189. float class_score = conf[j];
  190. if (class_score > score)
  191. {
  192. label = j;
  193. score = class_score;
  194. }
  195. }
  196. // ignore background or low score
  197. if (label == 0 || score <= confidence_thresh)
  198. continue;
  199. // CENTER_SIZE
  200. float var[4] = {0.1f, 0.1f, 0.2f, 0.2f};
  201. float pb_cx = pb[0];
  202. float pb_cy = pb[1];
  203. float pb_w = pb[2];
  204. float pb_h = pb[3];
  205. float bbox_cx = var[0] * loc[0] * pb_w + pb_cx;
  206. float bbox_cy = var[1] * loc[1] * pb_h + pb_cy;
  207. float bbox_w = (float)(exp(var[2] * loc[2]) * pb_w);
  208. float bbox_h = (float)(exp(var[3] * loc[3]) * pb_h);
  209. float obj_x1 = bbox_cx - bbox_w * 0.5f;
  210. float obj_y1 = bbox_cy - bbox_h * 0.5f;
  211. float obj_x2 = bbox_cx + bbox_w * 0.5f;
  212. float obj_y2 = bbox_cy + bbox_h * 0.5f;
  213. // clip
  214. obj_x1 = std::max(std::min(obj_x1 * bgr.cols, (float)(bgr.cols - 1)), 0.f);
  215. obj_y1 = std::max(std::min(obj_y1 * bgr.rows, (float)(bgr.rows - 1)), 0.f);
  216. obj_x2 = std::max(std::min(obj_x2 * bgr.cols, (float)(bgr.cols - 1)), 0.f);
  217. obj_y2 = std::max(std::min(obj_y2 * bgr.rows, (float)(bgr.rows - 1)), 0.f);
  218. // append object
  219. Object obj;
  220. obj.rect = cv::Rect_<float>(obj_x1, obj_y1, obj_x2 - obj_x1 + 1, obj_y2 - obj_y1 + 1);
  221. obj.label = label;
  222. obj.prob = score;
  223. obj.maskdata = std::vector<float>(maskdata, maskdata + mask.w);
  224. class_candidates[label].push_back(obj);
  225. }
  226. objects.clear();
  227. for (int i = 0; i < (int)class_candidates.size(); i++)
  228. {
  229. std::vector<Object>& candidates = class_candidates[i];
  230. qsort_descent_inplace(candidates);
  231. std::vector<int> picked;
  232. nms_sorted_bboxes(candidates, picked, nms_threshold);
  233. for (int j = 0; j < (int)picked.size(); j++)
  234. {
  235. int z = picked[j];
  236. objects.push_back(candidates[z]);
  237. }
  238. }
  239. qsort_descent_inplace(objects);
  240. // keep_top_k
  241. if (keep_top_k < (int)objects.size())
  242. {
  243. objects.resize(keep_top_k);
  244. }
  245. // generate mask
  246. for (int i = 0; i < (int)objects.size(); i++)
  247. {
  248. Object& obj = objects[i];
  249. cv::Mat mask(maskmaps.h, maskmaps.w, CV_32FC1);
  250. {
  251. mask = cv::Scalar(0.f);
  252. for (int p = 0; p < maskmaps.c; p++)
  253. {
  254. const float* maskmap = maskmaps.channel(p);
  255. float coeff = obj.maskdata[p];
  256. float* mp = (float*)mask.data;
  257. // mask += m * coeff
  258. for (int j = 0; j < maskmaps.w * maskmaps.h; j++)
  259. {
  260. mp[j] += maskmap[j] * coeff;
  261. }
  262. }
  263. }
  264. cv::Mat mask2;
  265. cv::resize(mask, mask2, cv::Size(img_w, img_h));
  266. // crop obj box and binarize
  267. obj.mask = cv::Mat(img_h, img_w, CV_8UC1);
  268. {
  269. obj.mask = cv::Scalar(0);
  270. for (int y = 0; y < img_h; y++)
  271. {
  272. if (y < obj.rect.y || y > obj.rect.y + obj.rect.height)
  273. continue;
  274. const float* mp2 = mask2.ptr<const float>(y);
  275. uchar* bmp = obj.mask.ptr<uchar>(y);
  276. for (int x = 0; x < img_w; x++)
  277. {
  278. if (x < obj.rect.x || x > obj.rect.x + obj.rect.width)
  279. continue;
  280. bmp[x] = mp2[x] > 0.5f ? 255 : 0;
  281. }
  282. }
  283. }
  284. }
  285. return 0;
  286. }
  287. static void draw_objects(const cv::Mat& bgr, const std::vector<Object>& objects)
  288. {
  289. static const char* class_names[] = {"background",
  290. "person", "bicycle", "car", "motorcycle", "airplane", "bus",
  291. "train", "truck", "boat", "traffic light", "fire hydrant",
  292. "stop sign", "parking meter", "bench", "bird", "cat", "dog",
  293. "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
  294. "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
  295. "skis", "snowboard", "sports ball", "kite", "baseball bat",
  296. "baseball glove", "skateboard", "surfboard", "tennis racket",
  297. "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
  298. "banana", "apple", "sandwich", "orange", "broccoli", "carrot",
  299. "hot dog", "pizza", "donut", "cake", "chair", "couch",
  300. "potted plant", "bed", "dining table", "toilet", "tv", "laptop",
  301. "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
  302. "toaster", "sink", "refrigerator", "book", "clock", "vase",
  303. "scissors", "teddy bear", "hair drier", "toothbrush"
  304. };
  305. static const unsigned char colors[81][3] = {
  306. {56, 0, 255},
  307. {226, 255, 0},
  308. {0, 94, 255},
  309. {0, 37, 255},
  310. {0, 255, 94},
  311. {255, 226, 0},
  312. {0, 18, 255},
  313. {255, 151, 0},
  314. {170, 0, 255},
  315. {0, 255, 56},
  316. {255, 0, 75},
  317. {0, 75, 255},
  318. {0, 255, 169},
  319. {255, 0, 207},
  320. {75, 255, 0},
  321. {207, 0, 255},
  322. {37, 0, 255},
  323. {0, 207, 255},
  324. {94, 0, 255},
  325. {0, 255, 113},
  326. {255, 18, 0},
  327. {255, 0, 56},
  328. {18, 0, 255},
  329. {0, 255, 226},
  330. {170, 255, 0},
  331. {255, 0, 245},
  332. {151, 255, 0},
  333. {132, 255, 0},
  334. {75, 0, 255},
  335. {151, 0, 255},
  336. {0, 151, 255},
  337. {132, 0, 255},
  338. {0, 255, 245},
  339. {255, 132, 0},
  340. {226, 0, 255},
  341. {255, 37, 0},
  342. {207, 255, 0},
  343. {0, 255, 207},
  344. {94, 255, 0},
  345. {0, 226, 255},
  346. {56, 255, 0},
  347. {255, 94, 0},
  348. {255, 113, 0},
  349. {0, 132, 255},
  350. {255, 0, 132},
  351. {255, 170, 0},
  352. {255, 0, 188},
  353. {113, 255, 0},
  354. {245, 0, 255},
  355. {113, 0, 255},
  356. {255, 188, 0},
  357. {0, 113, 255},
  358. {255, 0, 0},
  359. {0, 56, 255},
  360. {255, 0, 113},
  361. {0, 255, 188},
  362. {255, 0, 94},
  363. {255, 0, 18},
  364. {18, 255, 0},
  365. {0, 255, 132},
  366. {0, 188, 255},
  367. {0, 245, 255},
  368. {0, 169, 255},
  369. {37, 255, 0},
  370. {255, 0, 151},
  371. {188, 0, 255},
  372. {0, 255, 37},
  373. {0, 255, 0},
  374. {255, 0, 170},
  375. {255, 0, 37},
  376. {255, 75, 0},
  377. {0, 0, 255},
  378. {255, 207, 0},
  379. {255, 0, 226},
  380. {255, 245, 0},
  381. {188, 255, 0},
  382. {0, 255, 18},
  383. {0, 255, 75},
  384. {0, 255, 151},
  385. {255, 56, 0},
  386. {245, 255, 0}
  387. };
  388. cv::Mat image = bgr.clone();
  389. int color_index = 0;
  390. for (size_t i = 0; i < objects.size(); i++)
  391. {
  392. const Object& obj = objects[i];
  393. if (obj.prob < 0.15)
  394. continue;
  395. fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob,
  396. obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height);
  397. const unsigned char* color = colors[color_index % 81];
  398. color_index++;
  399. cv::rectangle(image, obj.rect, cv::Scalar(color[0], color[1], color[2]));
  400. char text[256];
  401. sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100);
  402. int baseLine = 0;
  403. cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
  404. int x = obj.rect.x;
  405. int y = obj.rect.y - label_size.height - baseLine;
  406. if (y < 0)
  407. y = 0;
  408. if (x + label_size.width > image.cols)
  409. x = image.cols - label_size.width;
  410. cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)),
  411. cv::Scalar(255, 255, 255), -1);
  412. cv::putText(image, text, cv::Point(x, y + label_size.height),
  413. cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0));
  414. // draw mask
  415. for (int y = 0; y < image.rows; y++)
  416. {
  417. const uchar* mp = obj.mask.ptr(y);
  418. uchar* p = image.ptr(y);
  419. for (int x = 0; x < image.cols; x++)
  420. {
  421. if (mp[x] == 255)
  422. {
  423. p[0] = cv::saturate_cast<uchar>(p[0] * 0.5 + color[0] * 0.5);
  424. p[1] = cv::saturate_cast<uchar>(p[1] * 0.5 + color[1] * 0.5);
  425. p[2] = cv::saturate_cast<uchar>(p[2] * 0.5 + color[2] * 0.5);
  426. }
  427. p += 3;
  428. }
  429. }
  430. }
  431. cv::imwrite("result.png", image);
  432. cv::imshow("image", image);
  433. cv::waitKey(0);
  434. }
  435. int main(int argc, char** argv)
  436. {
  437. if (argc != 2)
  438. {
  439. fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
  440. return -1;
  441. }
  442. const char* imagepath = argv[1];
  443. cv::Mat m = cv::imread(imagepath, 1);
  444. if (m.empty())
  445. {
  446. fprintf(stderr, "cv::imread %s failed\n", imagepath);
  447. return -1;
  448. }
  449. std::vector<Object> objects;
  450. detect_yolact(m, objects);
  451. draw_objects(m, objects);
  452. return 0;
  453. }