You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

yolact.cpp 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. // Copyright 2020 Tencent
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. #include "net.h"
  4. #if defined(USE_NCNN_SIMPLEOCV)
  5. #include "simpleocv.h"
  6. #else
  7. #include <opencv2/core/core.hpp>
  8. #include <opencv2/highgui/highgui.hpp>
  9. #include <opencv2/imgproc/imgproc.hpp>
  10. #endif
  11. #include <stdio.h>
  12. #include <vector>
  13. struct Object
  14. {
  15. cv::Rect_<float> rect;
  16. int label;
  17. float prob;
  18. std::vector<float> maskdata;
  19. cv::Mat mask;
  20. };
  21. static inline float intersection_area(const Object& a, const Object& b)
  22. {
  23. cv::Rect_<float> inter = a.rect & b.rect;
  24. return inter.area();
  25. }
  26. static void qsort_descent_inplace(std::vector<Object>& objects, int left, int right)
  27. {
  28. int i = left;
  29. int j = right;
  30. float p = objects[(left + right) / 2].prob;
  31. while (i <= j)
  32. {
  33. while (objects[i].prob > p)
  34. i++;
  35. while (objects[j].prob < p)
  36. j--;
  37. if (i <= j)
  38. {
  39. // swap
  40. std::swap(objects[i], objects[j]);
  41. i++;
  42. j--;
  43. }
  44. }
  45. #pragma omp parallel sections
  46. {
  47. #pragma omp section
  48. {
  49. if (left < j) qsort_descent_inplace(objects, left, j);
  50. }
  51. #pragma omp section
  52. {
  53. if (i < right) qsort_descent_inplace(objects, i, right);
  54. }
  55. }
  56. }
  57. static void qsort_descent_inplace(std::vector<Object>& objects)
  58. {
  59. if (objects.empty())
  60. return;
  61. qsort_descent_inplace(objects, 0, objects.size() - 1);
  62. }
  63. static void nms_sorted_bboxes(const std::vector<Object>& faceobjects, std::vector<int>& picked, float nms_threshold, bool agnostic = false)
  64. {
  65. picked.clear();
  66. const int n = faceobjects.size();
  67. std::vector<float> areas(n);
  68. for (int i = 0; i < n; i++)
  69. {
  70. areas[i] = faceobjects[i].rect.area();
  71. }
  72. for (int i = 0; i < n; i++)
  73. {
  74. const Object& a = faceobjects[i];
  75. int keep = 1;
  76. for (int j = 0; j < (int)picked.size(); j++)
  77. {
  78. const Object& b = faceobjects[picked[j]];
  79. if (!agnostic && a.label != b.label)
  80. continue;
  81. // intersection over union
  82. float inter_area = intersection_area(a, b);
  83. float union_area = areas[i] + areas[picked[j]] - inter_area;
  84. // float IoU = inter_area / union_area
  85. if (inter_area / union_area > nms_threshold)
  86. keep = 0;
  87. }
  88. if (keep)
  89. picked.push_back(i);
  90. }
  91. }
  92. static int detect_yolact(const cv::Mat& bgr, std::vector<Object>& objects)
  93. {
  94. ncnn::Net yolact;
  95. yolact.opt.use_vulkan_compute = true;
  96. // original model converted from https://github.com/dbolya/yolact
  97. // yolact_resnet50_54_800000.pth
  98. // the ncnn model https://github.com/nihui/ncnn-assets/tree/master/models
  99. if (yolact.load_param("yolact.param"))
  100. exit(-1);
  101. if (yolact.load_model("yolact.bin"))
  102. exit(-1);
  103. const int target_size = 550;
  104. int img_w = bgr.cols;
  105. int img_h = bgr.rows;
  106. ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, img_w, img_h, target_size, target_size);
  107. const float mean_vals[3] = {123.68f, 116.78f, 103.94f};
  108. const float norm_vals[3] = {1.0 / 58.40f, 1.0 / 57.12f, 1.0 / 57.38f};
  109. in.substract_mean_normalize(mean_vals, norm_vals);
  110. ncnn::Extractor ex = yolact.create_extractor();
  111. ex.input("input.1", in);
  112. ncnn::Mat maskmaps;
  113. ncnn::Mat location;
  114. ncnn::Mat mask;
  115. ncnn::Mat confidence;
  116. ex.extract("619", maskmaps); // 138x138 x 32
  117. ex.extract("816", location); // 4 x 19248
  118. ex.extract("818", mask); // maskdim 32 x 19248
  119. ex.extract("820", confidence); // 81 x 19248
  120. int num_class = confidence.w;
  121. int num_priors = confidence.h;
  122. // make priorbox
  123. ncnn::Mat priorbox(4, num_priors);
  124. {
  125. const int conv_ws[5] = {69, 35, 18, 9, 5};
  126. const int conv_hs[5] = {69, 35, 18, 9, 5};
  127. const float aspect_ratios[3] = {1.f, 0.5f, 2.f};
  128. const float scales[5] = {24.f, 48.f, 96.f, 192.f, 384.f};
  129. float* pb = priorbox;
  130. for (int p = 0; p < 5; p++)
  131. {
  132. int conv_w = conv_ws[p];
  133. int conv_h = conv_hs[p];
  134. float scale = scales[p];
  135. for (int i = 0; i < conv_h; i++)
  136. {
  137. for (int j = 0; j < conv_w; j++)
  138. {
  139. // +0.5 because priors are in center-size notation
  140. float cx = (j + 0.5f) / conv_w;
  141. float cy = (i + 0.5f) / conv_h;
  142. for (int k = 0; k < 3; k++)
  143. {
  144. float ar = aspect_ratios[k];
  145. ar = sqrt(ar);
  146. float w = scale * ar / 550;
  147. float h = scale / ar / 550;
  148. // This is for backward compatibility with a bug where I made everything square by accident
  149. // cfg.backbone.use_square_anchors:
  150. h = w;
  151. pb[0] = cx;
  152. pb[1] = cy;
  153. pb[2] = w;
  154. pb[3] = h;
  155. pb += 4;
  156. }
  157. }
  158. }
  159. }
  160. }
  161. const float confidence_thresh = 0.05f;
  162. const float nms_threshold = 0.5f;
  163. const int keep_top_k = 200;
  164. std::vector<std::vector<Object> > class_candidates;
  165. class_candidates.resize(num_class);
  166. for (int i = 0; i < num_priors; i++)
  167. {
  168. const float* conf = confidence.row(i);
  169. const float* loc = location.row(i);
  170. const float* pb = priorbox.row(i);
  171. const float* maskdata = mask.row(i);
  172. // find class id with highest score
  173. // start from 1 to skip background
  174. int label = 0;
  175. float score = 0.f;
  176. for (int j = 1; j < num_class; j++)
  177. {
  178. float class_score = conf[j];
  179. if (class_score > score)
  180. {
  181. label = j;
  182. score = class_score;
  183. }
  184. }
  185. // ignore background or low score
  186. if (label == 0 || score <= confidence_thresh)
  187. continue;
  188. // CENTER_SIZE
  189. float var[4] = {0.1f, 0.1f, 0.2f, 0.2f};
  190. float pb_cx = pb[0];
  191. float pb_cy = pb[1];
  192. float pb_w = pb[2];
  193. float pb_h = pb[3];
  194. float bbox_cx = var[0] * loc[0] * pb_w + pb_cx;
  195. float bbox_cy = var[1] * loc[1] * pb_h + pb_cy;
  196. float bbox_w = (float)(exp(var[2] * loc[2]) * pb_w);
  197. float bbox_h = (float)(exp(var[3] * loc[3]) * pb_h);
  198. float obj_x1 = bbox_cx - bbox_w * 0.5f;
  199. float obj_y1 = bbox_cy - bbox_h * 0.5f;
  200. float obj_x2 = bbox_cx + bbox_w * 0.5f;
  201. float obj_y2 = bbox_cy + bbox_h * 0.5f;
  202. // clip
  203. obj_x1 = std::max(std::min(obj_x1 * bgr.cols, (float)(bgr.cols - 1)), 0.f);
  204. obj_y1 = std::max(std::min(obj_y1 * bgr.rows, (float)(bgr.rows - 1)), 0.f);
  205. obj_x2 = std::max(std::min(obj_x2 * bgr.cols, (float)(bgr.cols - 1)), 0.f);
  206. obj_y2 = std::max(std::min(obj_y2 * bgr.rows, (float)(bgr.rows - 1)), 0.f);
  207. // append object
  208. Object obj;
  209. obj.rect = cv::Rect_<float>(obj_x1, obj_y1, obj_x2 - obj_x1 + 1, obj_y2 - obj_y1 + 1);
  210. obj.label = label;
  211. obj.prob = score;
  212. obj.maskdata = std::vector<float>(maskdata, maskdata + mask.w);
  213. class_candidates[label].push_back(obj);
  214. }
  215. objects.clear();
  216. for (int i = 0; i < (int)class_candidates.size(); i++)
  217. {
  218. std::vector<Object>& candidates = class_candidates[i];
  219. qsort_descent_inplace(candidates);
  220. std::vector<int> picked;
  221. nms_sorted_bboxes(candidates, picked, nms_threshold);
  222. for (int j = 0; j < (int)picked.size(); j++)
  223. {
  224. int z = picked[j];
  225. objects.push_back(candidates[z]);
  226. }
  227. }
  228. qsort_descent_inplace(objects);
  229. // keep_top_k
  230. if (keep_top_k < (int)objects.size())
  231. {
  232. objects.resize(keep_top_k);
  233. }
  234. // generate mask
  235. for (int i = 0; i < (int)objects.size(); i++)
  236. {
  237. Object& obj = objects[i];
  238. cv::Mat mask(maskmaps.h, maskmaps.w, CV_32FC1);
  239. {
  240. mask = cv::Scalar(0.f);
  241. for (int p = 0; p < maskmaps.c; p++)
  242. {
  243. const float* maskmap = maskmaps.channel(p);
  244. float coeff = obj.maskdata[p];
  245. float* mp = (float*)mask.data;
  246. // mask += m * coeff
  247. for (int j = 0; j < maskmaps.w * maskmaps.h; j++)
  248. {
  249. mp[j] += maskmap[j] * coeff;
  250. }
  251. }
  252. }
  253. cv::Mat mask2;
  254. cv::resize(mask, mask2, cv::Size(img_w, img_h));
  255. // crop obj box and binarize
  256. obj.mask = cv::Mat(img_h, img_w, CV_8UC1);
  257. {
  258. obj.mask = cv::Scalar(0);
  259. for (int y = 0; y < img_h; y++)
  260. {
  261. if (y < obj.rect.y || y > obj.rect.y + obj.rect.height)
  262. continue;
  263. const float* mp2 = mask2.ptr<const float>(y);
  264. uchar* bmp = obj.mask.ptr<uchar>(y);
  265. for (int x = 0; x < img_w; x++)
  266. {
  267. if (x < obj.rect.x || x > obj.rect.x + obj.rect.width)
  268. continue;
  269. bmp[x] = mp2[x] > 0.5f ? 255 : 0;
  270. }
  271. }
  272. }
  273. }
  274. return 0;
  275. }
  276. static void draw_objects(const cv::Mat& bgr, const std::vector<Object>& objects)
  277. {
  278. static const char* class_names[] = {"background",
  279. "person", "bicycle", "car", "motorcycle", "airplane", "bus",
  280. "train", "truck", "boat", "traffic light", "fire hydrant",
  281. "stop sign", "parking meter", "bench", "bird", "cat", "dog",
  282. "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
  283. "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
  284. "skis", "snowboard", "sports ball", "kite", "baseball bat",
  285. "baseball glove", "skateboard", "surfboard", "tennis racket",
  286. "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
  287. "banana", "apple", "sandwich", "orange", "broccoli", "carrot",
  288. "hot dog", "pizza", "donut", "cake", "chair", "couch",
  289. "potted plant", "bed", "dining table", "toilet", "tv", "laptop",
  290. "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
  291. "toaster", "sink", "refrigerator", "book", "clock", "vase",
  292. "scissors", "teddy bear", "hair drier", "toothbrush"
  293. };
  294. static const unsigned char colors[81][3] = {
  295. {56, 0, 255},
  296. {226, 255, 0},
  297. {0, 94, 255},
  298. {0, 37, 255},
  299. {0, 255, 94},
  300. {255, 226, 0},
  301. {0, 18, 255},
  302. {255, 151, 0},
  303. {170, 0, 255},
  304. {0, 255, 56},
  305. {255, 0, 75},
  306. {0, 75, 255},
  307. {0, 255, 169},
  308. {255, 0, 207},
  309. {75, 255, 0},
  310. {207, 0, 255},
  311. {37, 0, 255},
  312. {0, 207, 255},
  313. {94, 0, 255},
  314. {0, 255, 113},
  315. {255, 18, 0},
  316. {255, 0, 56},
  317. {18, 0, 255},
  318. {0, 255, 226},
  319. {170, 255, 0},
  320. {255, 0, 245},
  321. {151, 255, 0},
  322. {132, 255, 0},
  323. {75, 0, 255},
  324. {151, 0, 255},
  325. {0, 151, 255},
  326. {132, 0, 255},
  327. {0, 255, 245},
  328. {255, 132, 0},
  329. {226, 0, 255},
  330. {255, 37, 0},
  331. {207, 255, 0},
  332. {0, 255, 207},
  333. {94, 255, 0},
  334. {0, 226, 255},
  335. {56, 255, 0},
  336. {255, 94, 0},
  337. {255, 113, 0},
  338. {0, 132, 255},
  339. {255, 0, 132},
  340. {255, 170, 0},
  341. {255, 0, 188},
  342. {113, 255, 0},
  343. {245, 0, 255},
  344. {113, 0, 255},
  345. {255, 188, 0},
  346. {0, 113, 255},
  347. {255, 0, 0},
  348. {0, 56, 255},
  349. {255, 0, 113},
  350. {0, 255, 188},
  351. {255, 0, 94},
  352. {255, 0, 18},
  353. {18, 255, 0},
  354. {0, 255, 132},
  355. {0, 188, 255},
  356. {0, 245, 255},
  357. {0, 169, 255},
  358. {37, 255, 0},
  359. {255, 0, 151},
  360. {188, 0, 255},
  361. {0, 255, 37},
  362. {0, 255, 0},
  363. {255, 0, 170},
  364. {255, 0, 37},
  365. {255, 75, 0},
  366. {0, 0, 255},
  367. {255, 207, 0},
  368. {255, 0, 226},
  369. {255, 245, 0},
  370. {188, 255, 0},
  371. {0, 255, 18},
  372. {0, 255, 75},
  373. {0, 255, 151},
  374. {255, 56, 0},
  375. {245, 255, 0}
  376. };
  377. cv::Mat image = bgr.clone();
  378. int color_index = 0;
  379. for (size_t i = 0; i < objects.size(); i++)
  380. {
  381. const Object& obj = objects[i];
  382. if (obj.prob < 0.15)
  383. continue;
  384. fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob,
  385. obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height);
  386. const unsigned char* color = colors[color_index % 81];
  387. color_index++;
  388. cv::rectangle(image, obj.rect, cv::Scalar(color[0], color[1], color[2]));
  389. char text[256];
  390. sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100);
  391. int baseLine = 0;
  392. cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
  393. int x = obj.rect.x;
  394. int y = obj.rect.y - label_size.height - baseLine;
  395. if (y < 0)
  396. y = 0;
  397. if (x + label_size.width > image.cols)
  398. x = image.cols - label_size.width;
  399. cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)),
  400. cv::Scalar(255, 255, 255), -1);
  401. cv::putText(image, text, cv::Point(x, y + label_size.height),
  402. cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0));
  403. // draw mask
  404. for (int y = 0; y < image.rows; y++)
  405. {
  406. const uchar* mp = obj.mask.ptr(y);
  407. uchar* p = image.ptr(y);
  408. for (int x = 0; x < image.cols; x++)
  409. {
  410. if (mp[x] == 255)
  411. {
  412. p[0] = cv::saturate_cast<uchar>(p[0] * 0.5 + color[0] * 0.5);
  413. p[1] = cv::saturate_cast<uchar>(p[1] * 0.5 + color[1] * 0.5);
  414. p[2] = cv::saturate_cast<uchar>(p[2] * 0.5 + color[2] * 0.5);
  415. }
  416. p += 3;
  417. }
  418. }
  419. }
  420. cv::imwrite("result.png", image);
  421. cv::imshow("image", image);
  422. cv::waitKey(0);
  423. }
  424. int main(int argc, char** argv)
  425. {
  426. if (argc != 2)
  427. {
  428. fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
  429. return -1;
  430. }
  431. const char* imagepath = argv[1];
  432. cv::Mat m = cv::imread(imagepath, 1);
  433. if (m.empty())
  434. {
  435. fprintf(stderr, "cv::imread %s failed\n", imagepath);
  436. return -1;
  437. }
  438. std::vector<Object> objects;
  439. detect_yolact(m, objects);
  440. draw_objects(m, objects);
  441. return 0;
  442. }