You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

detection_post_process.c 12 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "nnacl/fp32/detection_post_process.h"
  17. #include <math.h>
  18. #include "nnacl/errorcode.h"
  19. #include "nnacl/op_base.h"
  20. bool ScoreWithIndexCmp(ScoreWithIndex *pa, ScoreWithIndex *pb) {
  21. if (pa->score > pb->score) {
  22. return true;
  23. } else if (pa->score < pb->score) {
  24. return false;
  25. } else {
  26. return pa->index < pb->index;
  27. }
  28. }
  29. void PushHeap(ScoreWithIndex *root, int cur, int top_index, ScoreWithIndex value) {
  30. int parent = (cur - 1) / 2;
  31. while (cur > top_index && ScoreWithIndexCmp(root + parent, &value)) {
  32. *(root + cur) = root[parent];
  33. cur = parent;
  34. parent = (cur - 1) / 2;
  35. }
  36. *(root + cur) = value;
  37. }
  38. void AdjustHeap(ScoreWithIndex *root, int cur, int limit, ScoreWithIndex value) {
  39. int top_index = cur;
  40. int second_child = cur;
  41. while (second_child < (limit - 1) / 2) {
  42. second_child = 2 * (second_child + 1);
  43. if (ScoreWithIndexCmp(root + second_child, root + second_child - 1)) {
  44. second_child--;
  45. }
  46. *(root + cur) = *(root + second_child);
  47. cur = second_child;
  48. }
  49. if ((limit & 1) == 0 && second_child == (limit - 2) / 2) {
  50. second_child = 2 * (second_child + 1);
  51. *(root + cur) = *(root + second_child - 1);
  52. cur = second_child - 1;
  53. }
  54. PushHeap(root, cur, top_index, value);
  55. }
  56. void PopHeap(ScoreWithIndex *root, int limit, ScoreWithIndex *result) {
  57. ScoreWithIndex value = *result;
  58. *result = *root;
  59. AdjustHeap(root, 0, limit, value);
  60. }
  61. void MakeHeap(ScoreWithIndex *values, int limit) {
  62. if (limit < 2) return;
  63. int parent = (limit - 2) / 2;
  64. while (true) {
  65. AdjustHeap(values, parent, limit, values[parent]);
  66. if (parent == 0) {
  67. return;
  68. }
  69. parent--;
  70. }
  71. }
  72. void SortHeap(ScoreWithIndex *root, int limit) {
  73. while (limit > 1) {
  74. --limit;
  75. PopHeap(root, limit, root + limit);
  76. }
  77. }
  78. void HeapSelect(ScoreWithIndex *root, int cur, int limit) {
  79. MakeHeap(root, cur);
  80. for (int i = cur; i < limit; ++i) {
  81. if (ScoreWithIndexCmp(root + i, root)) {
  82. PopHeap(root, cur, root + i);
  83. }
  84. }
  85. }
  86. void PartialSort(ScoreWithIndex *values, int num_to_sort, int num_values) {
  87. HeapSelect(values, num_to_sort, num_values);
  88. SortHeap(values, num_to_sort);
  89. }
  90. float IntersectionOverUnion(const BboxCorner *a, const BboxCorner *b) {
  91. const float area_a = (a->ymax - a->ymin) * (a->xmax - a->xmin);
  92. const float area_b = (b->ymax - b->ymin) * (b->xmax - b->xmin);
  93. if (area_a <= 0 || area_b <= 0) {
  94. return 0.0f;
  95. }
  96. const float ymin = a->ymin > b->ymin ? a->ymin : b->ymin;
  97. const float xmin = a->xmin > b->xmin ? a->xmin : b->xmin;
  98. const float ymax = a->ymax < b->ymax ? a->ymax : b->ymax;
  99. const float xmax = a->xmax < b->xmax ? a->xmax : b->xmax;
  100. const float h = ymax - ymin > 0.0f ? ymax - ymin : 0.0f;
  101. const float w = xmax - xmin > 0.0f ? xmax - xmin : 0.0f;
  102. const float inter = h * w;
  103. return inter / (area_a + area_b - inter);
  104. }
  105. void DecodeBoxes(const int num_boxes, const float *input_boxes, const float *anchors, const BboxCenter scaler,
  106. float *decoded_boxes) {
  107. for (int i = 0; i < num_boxes; ++i) {
  108. BboxCenter *box = (BboxCenter *)(input_boxes) + i;
  109. BboxCenter *anchor = (BboxCenter *)(anchors) + i;
  110. BboxCorner *decoded_box = (BboxCorner *)(decoded_boxes) + i;
  111. float y_center = box->y / scaler.y * anchor->h + anchor->y;
  112. float x_center = box->x / scaler.x * anchor->w + anchor->x;
  113. const float h_half = 0.5f * expf(box->h / scaler.h) * anchor->h;
  114. const float w_half = 0.5f * expf(box->w / scaler.w) * anchor->w;
  115. decoded_box->ymin = y_center - h_half;
  116. decoded_box->xmin = x_center - w_half;
  117. decoded_box->ymax = y_center + h_half;
  118. decoded_box->xmax = x_center + w_half;
  119. }
  120. }
  121. int NmsSingleClass(const int candidate_num, const float *decoded_boxes, const int max_detections,
  122. ScoreWithIndex *score_with_index, int *selected, const DetectionPostProcessParameter *param) {
  123. uint8_t *nms_candidate = param->nms_candidate_;
  124. const int output_num = candidate_num < max_detections ? candidate_num : max_detections;
  125. int possible_candidate_num = candidate_num;
  126. int selected_num = 0;
  127. PartialSort(score_with_index, candidate_num, candidate_num);
  128. for (int i = 0; i < candidate_num; ++i) {
  129. nms_candidate[i] = 1;
  130. }
  131. for (int i = 0; i < candidate_num; ++i) {
  132. if (possible_candidate_num == 0 || selected_num >= output_num) {
  133. break;
  134. }
  135. if (nms_candidate[i] == 0) {
  136. continue;
  137. }
  138. selected[selected_num++] = score_with_index[i].index;
  139. nms_candidate[i] = 0;
  140. possible_candidate_num--;
  141. for (int t = i + 1; t < candidate_num; ++t) {
  142. if (nms_candidate[t] == 1) {
  143. const BboxCorner *bbox_i = (BboxCorner *)(decoded_boxes) + score_with_index[i].index;
  144. const BboxCorner *bbox_t = (BboxCorner *)(decoded_boxes) + score_with_index[t].index;
  145. const float iou = IntersectionOverUnion(bbox_i, bbox_t);
  146. if (iou > param->nms_iou_threshold_) {
  147. nms_candidate[t] = 0;
  148. possible_candidate_num--;
  149. }
  150. }
  151. }
  152. }
  153. return selected_num;
  154. }
  155. int NmsMultiClassesRegular(const int num_boxes, const int num_classes_with_bg, const float *decoded_boxes,
  156. const float *input_scores, float *output_boxes, float *output_classes, float *output_scores,
  157. const DetectionPostProcessParameter *param) {
  158. const int first_class_index = num_classes_with_bg - (int)(param->num_classes_);
  159. int *selected = (int *)(param->selected_);
  160. ScoreWithIndex *score_with_index_single = (ScoreWithIndex *)(param->score_with_class_);
  161. int all_classes_sorted_num = 0;
  162. int all_classes_output_num = 0;
  163. ScoreWithIndex *score_with_index_all = (ScoreWithIndex *)(param->score_with_class_all_);
  164. int *indexes = (int *)(param->indexes_);
  165. for (int j = first_class_index; j < num_classes_with_bg; ++j) {
  166. int candidate_num = 0;
  167. // process single class
  168. for (int i = 0; i < num_boxes; ++i) {
  169. const float score = input_scores[i * num_classes_with_bg + j];
  170. if (score >= param->nms_score_threshold_) {
  171. score_with_index_single[candidate_num].score = score;
  172. score_with_index_single[candidate_num++].index = i;
  173. }
  174. }
  175. int selected_num = NmsSingleClass(candidate_num, decoded_boxes, param->detections_per_class_,
  176. score_with_index_single, selected, param);
  177. for (int i = 0; i < all_classes_sorted_num; ++i) {
  178. indexes[i] = score_with_index_all[i].index;
  179. score_with_index_all[i].index = i;
  180. }
  181. // process all classes
  182. for (int i = 0; i < selected_num; ++i) {
  183. // store class to index
  184. indexes[all_classes_sorted_num] = selected[i] * num_classes_with_bg + j;
  185. score_with_index_all[all_classes_sorted_num].index = all_classes_sorted_num;
  186. score_with_index_all[all_classes_sorted_num++].score = input_scores[selected[i] * num_classes_with_bg + j];
  187. }
  188. all_classes_output_num =
  189. all_classes_sorted_num < param->max_detections_ ? all_classes_sorted_num : param->max_detections_;
  190. PartialSort(score_with_index_all, all_classes_output_num, all_classes_sorted_num);
  191. for (int i = 0; i < all_classes_output_num; ++i) {
  192. score_with_index_all[i].index = indexes[score_with_index_all[i].index];
  193. }
  194. all_classes_sorted_num = all_classes_output_num;
  195. }
  196. for (int i = 0; i < param->max_detections_ * param->max_classes_per_detection_; ++i) {
  197. if (i < all_classes_output_num) {
  198. const int box_index = score_with_index_all[i].index / num_classes_with_bg;
  199. const int class_index = score_with_index_all[i].index - box_index * num_classes_with_bg - first_class_index;
  200. *((BboxCorner *)(output_boxes) + i) = *((BboxCorner *)(decoded_boxes) + box_index);
  201. output_classes[i] = (float)class_index;
  202. output_scores[i] = score_with_index_all[i].score;
  203. } else {
  204. ((BboxCorner *)(output_boxes) + i)->ymin = 0;
  205. ((BboxCorner *)(output_boxes) + i)->xmin = 0;
  206. ((BboxCorner *)(output_boxes) + i)->ymax = 0;
  207. ((BboxCorner *)(output_boxes) + i)->xmax = 0;
  208. output_classes[i] = 0.0f;
  209. output_scores[i] = 0.0f;
  210. }
  211. }
  212. return all_classes_output_num;
  213. }
  214. int NmsMultiClassesFast(const int num_boxes, const int num_classes_with_bg, const float *decoded_boxes,
  215. const float *input_scores, float *output_boxes, float *output_classes, float *output_scores,
  216. const DetectionPostProcessParameter *param) {
  217. const int first_class_index = num_classes_with_bg - (int)(param->num_classes_);
  218. const int64_t max_classes_per_anchor =
  219. param->max_classes_per_detection_ < param->num_classes_ ? param->max_classes_per_detection_ : param->num_classes_;
  220. int candidate_num = 0;
  221. ScoreWithIndex *score_with_class_all = (ScoreWithIndex *)(param->score_with_class_all_);
  222. ScoreWithIndex *score_with_class = (ScoreWithIndex *)(param->score_with_class_);
  223. int *selected = (int *)(param->selected_);
  224. int selected_num;
  225. int output_num = 0;
  226. for (int i = 0; i < num_boxes; ++i) {
  227. for (int j = first_class_index; j < num_classes_with_bg; ++j) {
  228. float score_t = *(input_scores + i * num_classes_with_bg + j);
  229. score_with_class_all[i * param->num_classes_ + j - first_class_index].score = score_t;
  230. // save box and class info to index
  231. score_with_class_all[i * param->num_classes_ + j - first_class_index].index = i * num_classes_with_bg + j;
  232. }
  233. PartialSort(score_with_class_all + i * param->num_classes_, max_classes_per_anchor, param->num_classes_);
  234. const float score_max = (score_with_class_all + i * param->num_classes_)->score;
  235. if (score_max >= param->nms_score_threshold_) {
  236. score_with_class[candidate_num].index = i;
  237. score_with_class[candidate_num++].score = score_max;
  238. }
  239. }
  240. selected_num =
  241. NmsSingleClass(candidate_num, decoded_boxes, param->max_detections_, score_with_class, selected, param);
  242. for (int i = 0; i < selected_num; ++i) {
  243. const ScoreWithIndex *box_score_with_class = score_with_class_all + selected[i] * param->num_classes_;
  244. const int box_index = box_score_with_class->index / num_classes_with_bg;
  245. for (int j = 0; j < max_classes_per_anchor; ++j) {
  246. *((BboxCorner *)(output_boxes) + output_num) = *((BboxCorner *)(decoded_boxes) + box_index);
  247. output_scores[output_num] = (box_score_with_class + j)->score;
  248. output_classes[output_num++] =
  249. (float)((box_score_with_class + j)->index % num_classes_with_bg - first_class_index);
  250. }
  251. }
  252. for (int i = output_num; i < param->max_detections_ * param->max_classes_per_detection_; ++i) {
  253. ((BboxCorner *)(output_boxes) + i)->ymin = 0;
  254. ((BboxCorner *)(output_boxes) + i)->xmin = 0;
  255. ((BboxCorner *)(output_boxes) + i)->ymax = 0;
  256. ((BboxCorner *)(output_boxes) + i)->xmax = 0;
  257. output_scores[i] = 0;
  258. output_classes[i] = 0;
  259. }
  260. return output_num;
  261. }
  262. int DetectionPostProcess(const int num_boxes, const int num_classes_with_bg, float *input_boxes, float *input_scores,
  263. float *input_anchors, float *output_boxes, float *output_classes, float *output_scores,
  264. float *output_num, DetectionPostProcessParameter *param) {
  265. BboxCenter scaler;
  266. scaler.y = param->y_scale_;
  267. scaler.x = param->x_scale_;
  268. scaler.h = param->h_scale_;
  269. scaler.w = param->w_scale_;
  270. DecodeBoxes(num_boxes, input_boxes, input_anchors, scaler, param->decoded_boxes_);
  271. if (param->use_regular_nms_) {
  272. *output_num = NmsMultiClassesRegular(num_boxes, num_classes_with_bg, param->decoded_boxes_, input_scores,
  273. output_boxes, output_classes, output_scores, param);
  274. } else {
  275. *output_num = NmsMultiClassesFast(num_boxes, num_classes_with_bg, param->decoded_boxes_, input_scores, output_boxes,
  276. output_classes, output_scores, param);
  277. }
  278. return NNACL_OK;
  279. }