You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

post_process.py 9.6 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. """
  2. /**
  3. * Copyright 2020 Zhejiang Lab. All Rights Reserved.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. * =============================================================
  17. */
  18. """
  19. import os
  20. from collections import defaultdict
  21. import numpy as np
  22. def removeSmallOrBigBbox(
  23. fid_pid_tlwhc,
  24. area_min_scale=0.25,
  25. area_max_scale=1.75,
  26. conf_thresh=0.95,
  27. cp_count_thresh=100,
  28. height=1080,
  29. split=10,
  30. show_delete_items=False):
  31. selected_ind = []
  32. filtered_ind = []
  33. ys = np.linspace(0, height, split + 1).astype(np.int)
  34. y_range2cps = {}
  35. for i in range(len(ys) - 1):
  36. y_range2cps[(ys[i], ys[i + 1])] = []
  37. filter_info = defaultdict(list)
  38. filter_counter = 0
  39. person_id2count = defaultdict(int)
  40. for ind, (frame_id, person_id, x, y, w, h,
  41. conf) in enumerate(fid_pid_tlwhc):
  42. if (ind + 1) % (len(fid_pid_tlwhc) // 10) == 0:
  43. print('filtering %d/%d...' % (ind + 1, len(fid_pid_tlwhc)))
  44. cp_x, cp_y = x + w / 2, y + h / 2
  45. for (y0, y1), cps in y_range2cps.items():
  46. if y0 <= cp_y <= y1:
  47. # print('y0: {}, y1: {}, cp_y: {}'.format(y0, y1, cp_y))
  48. # print((y0, y1), cps)
  49. if len(cps) >= cp_count_thresh:
  50. cps_arr = np.array(cps)
  51. area_mean = np.mean(cps_arr[:, 3])
  52. # sorted_area_ind = np.argsort(cps_arr[:, 3])
  53. # sorted_areas = cps_arr[:, 3][sorted_area_ind]
  54. # area_min = sorted_areas[int(len(sorted_areas) * 0.05)]
  55. # area_max = sorted_areas[int(len(sorted_areas) * 0.95)]
  56. area_min = area_mean * area_min_scale
  57. area_max = area_mean * area_max_scale
  58. # print('area_per_20: {}, area_per_80: {}, area: {}'.format(area_per_20, area_per_80, w*h))
  59. if area_min <= w * \
  60. h <= area_max or conf > conf_thresh: # or person_id2count[person_id] < 10
  61. person_id2count[person_id] += 1
  62. cps.append([person_id, w, h, w * h])
  63. selected_ind.append(ind)
  64. else:
  65. filter_info[frame_id].append(person_id)
  66. filter_counter += 1
  67. filtered_ind.append(ind)
  68. # print('filter line: %d, frame id: %d, person id: %d' % (ind+1, frame_id, person_id))
  69. else:
  70. cps.append([person_id, w, h, w * h])
  71. selected_ind.append(ind)
  72. break
  73. # print('delete totally %d items' % filter_counter)
  74. result = []
  75. for ind, data in enumerate(fid_pid_tlwhc):
  76. if ind in selected_ind:
  77. result.append(data)
  78. if show_delete_items:
  79. for f_id, p_ids in filter_info.items():
  80. print('filter frame id: {}, person ids: {}'.format(f_id, p_ids))
  81. return np.array(result)
  82. def writeResult(fid_pid_bbox_xywhc, save_file_pth, with_conf=False):
  83. with open(save_file_pth, 'w') as f:
  84. frame_max = int(np.max(fid_pid_bbox_xywhc[:, 0]))
  85. for frame_id in range(frame_max):
  86. frame_data = fid_pid_bbox_xywhc[fid_pid_bbox_xywhc[:, 0]
  87. == frame_id + 1]
  88. tempdata = frame_data[np.argsort(frame_data[:, 1])]
  89. for d in tempdata:
  90. if with_conf:
  91. msg = '%d,%d,%.2f,%.2f,%.2f,%.2f,%.3f\n' % (
  92. d[0], d[1], d[2], d[3], d[4], d[5], d[6])
  93. else:
  94. msg = '%d,%d,%.2f,%.2f,%.2f,%.2f\n' % (
  95. d[0], d[1], d[2], d[3], d[4], d[5])
  96. f.write(msg)
  97. def seleteDateByIndex(index_, data):
  98. temp = []
  99. for pid in index_:
  100. temp.append(data[data[:, 1] == pid])
  101. res = np.vstack(temp)
  102. return res
  103. def removeDateByIndex(index_, data):
  104. temp = []
  105. # index_ = index_.tolist()
  106. for item in data:
  107. if np.any(item[1] == index_):
  108. continue
  109. temp.append(item)
  110. res = np.vstack(temp)
  111. return res
  112. def printResult(bbox_xyxyc):
  113. for i, d in enumerate(bbox_xyxyc):
  114. pid, conf_max, conf_mean, mid_point_stdx, mid_point_stdy, start_frame, frame_len = d
  115. print(
  116. 'pid:{:.0f},cof_max:{:.2f},conf_mean:{:.2f},mid_point_stdx:{:.2f},mid_point_stdy:{:.2f},start_frame:{:.0f},frame_len:{:.0f}'.format(
  117. pid,
  118. conf_max,
  119. conf_mean,
  120. mid_point_stdx,
  121. mid_point_stdy,
  122. start_frame,
  123. frame_len))
  124. def getMindPointArray(xywhcs_np):
  125. xywhcs_np_temp = xywhcs_np.copy()
  126. cxcy = xywhcs_np_temp[:, 1:3] + xywhcs_np_temp[:, 3:4] / 2
  127. return cxcy
  128. def getTrackItemInfo(track_data):
  129. res = []
  130. pid_max = int(np.max(data[:, 1]))
  131. for pid in range(pid_max):
  132. data_p = data[data[:, 1] == pid]
  133. if len(data_p) > 0:
  134. confs = data_p[:, -1]
  135. conf_max = np.max(confs)
  136. conf_mean = np.mean(confs)
  137. start_frame = np.min(data_p[:, 0])
  138. cxcy = getMindPointArray(data_p)
  139. mid_point_std = np.std(cxcy, axis=0)
  140. frame_len = len(data_p)
  141. res.append([pid, conf_max, conf_mean, mid_point_std[0],
  142. mid_point_std[1], start_frame, frame_len])
  143. print(
  144. 'pid:{:d},cof_max:{:.2f},conf_mean:{:.2f},mid_point_std:{},start_frame:{:.0f},frame_len:{:d}'.format(
  145. pid,
  146. conf_max,
  147. conf_mean,
  148. mid_point_std,
  149. start_frame,
  150. frame_len))
  151. res = np.array(res)
  152. return res
  153. def getUnMoveLowConfObjInfo(fid_pid_tlwhc, std_th=1., conf_th=0.5):
  154. res = []
  155. pid_max = int(np.max(fid_pid_tlwhc[:, 1]))
  156. for pid in range(pid_max):
  157. data_p = fid_pid_tlwhc[fid_pid_tlwhc[:, 1] == pid]
  158. if len(data_p) > 0:
  159. confs = data_p[:, -1]
  160. conf_max = np.max(confs)
  161. conf_mean = np.mean(confs)
  162. start_frame = np.min(data_p[:, 0])
  163. cxcy = getMindPointArray(data_p)
  164. mid_point_std = np.std(cxcy, axis=0)
  165. frame_len = len(data_p)
  166. res.append([pid, conf_max, conf_mean, mid_point_std[0],
  167. mid_point_std[1], start_frame, frame_len])
  168. print(
  169. 'pid:{:d},cof_max:{:.2f},conf_mean:{:.2f},mid_point_std:{},start_frame:{:.0f},frame_len:{:d}'.format(
  170. pid,
  171. conf_max,
  172. conf_mean,
  173. mid_point_std,
  174. start_frame,
  175. frame_len))
  176. res = np.array(res)
  177. res_conf_le_0_5_T = res[(res[:, 2] < conf_th) & (
  178. res[:, 3] < std_th) & (res[:, 4] < std_th)]
  179. return res_conf_le_0_5_T
  180. def removeUnMoveLowConfObj(fid_pid_tlwhc, std_th=1, conf_th=0.5):
  181. res_info = []
  182. pid_max = int(np.max(fid_pid_tlwhc[:, 1]))
  183. for pid in range(pid_max):
  184. data_p = fid_pid_tlwhc[fid_pid_tlwhc[:, 1] == pid]
  185. if len(data_p) > 0:
  186. confs = data_p[:, -1]
  187. conf_max = np.max(confs)
  188. conf_mean = np.mean(confs)
  189. start_frame = np.min(data_p[:, 0])
  190. cxcy = getMindPointArray(data_p)
  191. mid_point_std = np.std(cxcy, axis=0)
  192. frame_len = len(data_p)
  193. res_info.append([pid,
  194. conf_max,
  195. conf_mean,
  196. mid_point_std[0],
  197. mid_point_std[1],
  198. start_frame,
  199. frame_len])
  200. # print('pid:{:d},cof_max:{:.2f},conf_mean:{:.2f},mid_point_std:{},start_frame:{:.0f},frame_len:{:d}'.format(
  201. # pid, conf_max, conf_mean, mid_point_std, start_frame, frame_len))
  202. res = np.array(res_info)
  203. res_conf_le_0_5 = res[(res[:, 2] < conf_th) & (
  204. res[:, 3] < std_th) & (res[:, 4] < std_th)]
  205. remove_inds = res_conf_le_0_5[:, 0]
  206. selet_data = removeDateByIndex(remove_inds, fid_pid_tlwhc)
  207. return selet_data
  208. if __name__ == "__main__":
  209. result = './xxx.txt'
  210. res_dir, res_name = os.path.split(result)
  211. data = np.loadtxt(result, delimiter=',')
  212. res1 = getTrackItemInfo(data)
  213. res2 = getUnMoveLowConfObjInfo(data)
  214. res3 = removeUnMoveLowConfObj(data)
  215. print('##################################')
  216. printResult(res1)
  217. print('##################################')
  218. printResult(res2)
  219. data_conf_le0_5 = seleteDateByIndex(res1[:, 0], data)
  220. writeResult(
  221. data_conf_le0_5,
  222. os.path.join(
  223. res_dir,
  224. 'A1_' +
  225. res_name),
  226. with_conf=True)
  227. data_conf_le0_5 = seleteDateByIndex(res2[:, 0], data)
  228. writeResult(
  229. data_conf_le0_5,
  230. os.path.join(
  231. res_dir,
  232. 'A2_' +
  233. res_name),
  234. with_conf=True)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能

Contributors (1)