You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

predict_with_print_box.py 11 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. """
  2. /**
  3. * Copyright 2020 Zhejiang Lab. All Rights Reserved.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. * =============================================================
  17. */
  18. """
  19. import json
  20. import time
  21. import cv2
  22. import numpy as np
  23. import oneflow_yolov3
  24. from yolo_net import YoloPredictNet
  25. import oneflow as flow
'''Init oneflow config'''
# Directory holding the pretrained YOLOv3 weights; when empty, weights are
# randomly initialized instead (see YoloInference.__init__).
model_load_dir = "of_model/yolov3_model_python/"
# Text file with one COCO class name per line.
label_to_name_file = "data/coco.names"
use_tensorrt = 0  # non-zero enables TensorRT for the compiled job
gpu_num_per_node = 1
batch_size = 16  # fixed inference batch size; yolo_inference only runs full batches
image_height = 608  # network input resolution (square letterbox target)
image_width = 608
# Register the custom YOLOv3 ops shipped with the oneflow_yolov3 extension.
flow.config.load_library(oneflow_yolov3.lib_path())
func_config = flow.FunctionConfig()
func_config.default_distribute_strategy(flow.distribute.consistent_strategy())
func_config.default_data_type(flow.float)
if use_tensorrt != 0:
    func_config.use_tensorrt(True)
label_2_name = []
with open(label_to_name_file, 'r') as f:
    # Class-index -> name lookup; lines keep their trailing newlines.
    label_2_name = f.readlines()
nms = True  # selects which output layout batch_boxes expects from the net
print("nms:", nms)
# Static input blob definitions for the compiled inference job below.
input_blob_def_dict = {
    "images": flow.FixedTensorDef((batch_size, 3, image_height, image_width), dtype=flow.float),
    "origin_image_info": flow.FixedTensorDef((batch_size, 2), dtype=flow.int32),
}
  49. def xywh_2_x1y1x2y2(x, y, w, h, origin_image):
  50. """The format of box transform"""
  51. x1 = (x - w / 2.) * origin_image[1]
  52. x2 = (x + w / 2.) * origin_image[1]
  53. y1 = (y - h / 2.) * origin_image[0]
  54. y2 = (y + h / 2.) * origin_image[0]
  55. return x1, y1, x2, y2
  56. def batch_boxes(positions, probs, origin_image_info):
  57. """The images postprocessing"""
  58. batch_size = positions.shape[0]
  59. batch_list = []
  60. if nms == True:
  61. for k in range(batch_size):
  62. box_list = []
  63. for i in range(1, 81):
  64. for j in range(positions.shape[2]):
  65. if positions[k][i][j][2] != 0 and positions[k][i][j][3] != 0 and probs[k][i][j] != 0:
  66. x1, y1, x2, y2 = xywh_2_x1y1x2y2(positions[k][i][j][0], positions[k][i][j][1],
  67. positions[k][i][j][2], positions[k][i][j][3],
  68. origin_image_info[k])
  69. bbox = [i - 1, x1, y1, x2, y2, probs[k][i][j]]
  70. box_list.append(bbox)
  71. batch_list.append(np.asarray(box_list))
  72. else:
  73. for k in range(batch_size):
  74. box_list = []
  75. for j in range(positions.shape[1]):
  76. for i in range(1, 81):
  77. if positions[k][j][2] != 0 and positions[k][j][3] != 0 and probs[k][j][i] != 0:
  78. x1, y1, x2, y2 = xywh_2_x1y1x2y2(positions[k][j][0], positions[k][j][1], positions[k][j][2],
  79. positions[k][j][3], origin_image_info[k])
  80. bbox = [i - 1, x1, y1, x2, y2, probs[k][j][i]]
  81. box_list.append(bbox)
  82. batch_list.append(np.asarray(box_list))
  83. return batch_list
@flow.function(func_config)
def yolo_user_op_eval_job(images=input_blob_def_dict["images"],
                          origin_image_info=input_blob_def_dict["origin_image_info"]):
    """Compiled OneFlow inference job: run the YOLOv3 predict net on one batch.

    Returns (positions, probabilities, origin_image_info). The identity ops
    rename the outputs so they can be fetched by the "..._end" blob names.
    """
    yolo_pos_result, yolo_prob_result = YoloPredictNet(images, origin_image_info, trainable=False)
    yolo_pos_result = flow.identity(yolo_pos_result, name="yolo_pos_result_end")
    yolo_prob_result = flow.identity(yolo_prob_result, name="yolo_prob_result_end")
    return yolo_pos_result, yolo_prob_result, origin_image_info
  92. def yolo_show(image_path_list, batch_list):
  93. """Debug the result of Yolov3"""
  94. font = cv2.FONT_HERSHEY_SIMPLEX
  95. for img_path, batch in zip(image_path_list, batch_list):
  96. result_list = batch.tolist()
  97. img = cv2.imread(img_path)
  98. for result in result_list:
  99. cls = int(result[0])
  100. bbox = result[1:-1]
  101. score = result[-1]
  102. print('img_file:', img_path)
  103. print('cls:', cls)
  104. print('bbox:', bbox)
  105. c = ((int(bbox[0]) + int(bbox[2])) / 2, (int(bbox[1] + int(bbox[3])) / 2))
  106. cv2.rectangle(img, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), (0, 255, 255), 1)
  107. cv2.putText(img, str(cls), (int(c[0]), int(c[1])), font, 1, (0, 0, 255), 1)
  108. result_name = img_path.split('/')[-1]
  109. cv2.imwrite("data/results/" + result_name, img)
  110. def resize_image(img, origin_h, origin_w, image_height, image_width):
  111. """The resize of image preprocessing"""
  112. w = image_width
  113. h = image_height
  114. resized = np.zeros((3, image_height, image_width), dtype=np.float32)
  115. part = np.zeros((3, origin_h, image_width), dtype=np.float32)
  116. w_scale = (float)(origin_w - 1) / (w - 1)
  117. h_scale = (float)(origin_h - 1) / (h - 1)
  118. for c in range(w):
  119. if c == w - 1 or origin_w == 1:
  120. val = img[:, :, origin_w - 1]
  121. else:
  122. sx = c * w_scale
  123. ix = int(sx)
  124. dx = sx - ix
  125. val = (1 - dx) * img[:, :, ix] + dx * img[:, :, ix + 1]
  126. part[:, :, c] = val
  127. for r in range(h):
  128. sy = r * h_scale
  129. iy = int(sy)
  130. dy = sy - iy
  131. val = (1 - dy) * part[:, iy, :]
  132. resized[:, r, :] = val
  133. if r == h - 1 or origin_h == 1:
  134. continue
  135. resized[:, r, :] = resized[:, r, :] + dy * part[:, iy + 1, :]
  136. return resized
  137. def batch_image_preprocess_v2(img_path_list, image_height, image_width):
  138. """The images preprocessing"""
  139. result_list = []
  140. origin_info_list = []
  141. for img_path in img_path_list:
  142. img = cv2.imread(img_path, cv2.IMREAD_COLOR)
  143. img = img.transpose(2, 0, 1).astype(np.float32) # hwc->chw
  144. img = img / 255 # /255
  145. img[[0, 1, 2], :, :] = img[[2, 1, 0], :, :] # bgr2rgb
  146. w = image_width
  147. h = image_height
  148. origin_h = img.shape[1]
  149. origin_w = img.shape[2]
  150. new_w = origin_w
  151. new_h = origin_h
  152. if w / origin_w < h / origin_h:
  153. new_w = w
  154. new_h = origin_h * w // origin_w
  155. else:
  156. new_h = h
  157. new_w = origin_w * h // origin_h
  158. resize_img = resize_image(img, origin_h, origin_w, new_h, new_w)
  159. dw = (w - new_w) // 2
  160. dh = (h - new_h) // 2
  161. padh_before = int(dh)
  162. padh_after = int(h - new_h - padh_before)
  163. padw_before = int(dw)
  164. padw_after = int(w - new_w - padw_before)
  165. result = np.pad(resize_img, pad_width=((0, 0), (padh_before, padh_after), (padw_before, padw_after)),
  166. mode='constant', constant_values=0.5)
  167. origin_image_info = [origin_h, origin_w]
  168. result_list.append(result)
  169. origin_info_list.append(origin_image_info)
  170. results = np.asarray(result_list).astype(np.float32)
  171. origin_image_infos = np.asarray(origin_info_list).astype(np.int32)
  172. return results, origin_image_infos
  173. def coco_format(type_, id_list, file_list, result_list, label_list, coco_flag=0):
  174. """Transform the annotations to coco format"""
  175. annotations = []
  176. for i, result in enumerate(result_list):
  177. temp = {}
  178. id_name = id_list[i]
  179. file_path = file_list[i]
  180. temp['id'] = id_name
  181. temp['annotation'] = []
  182. im = cv2.imread(file_path)
  183. height, width, _ = im.shape
  184. if result.shape[0] == 0:
  185. temp['annotation'] = json.dumps(temp['annotation'])
  186. annotations.append(temp)
  187. continue
  188. else:
  189. for j in range(result.shape[0]):
  190. cls_id = int(result[j][0]) + 1 + coco_flag
  191. x1 = result[j][1]
  192. x2 = result[j][3]
  193. y1 = result[j][2]
  194. y2 = result[j][4]
  195. score = result[j][5]
  196. width = max(0, x2 - x1)
  197. height = max(0, y2 - y1)
  198. if cls_id in label_list:
  199. temp['annotation'].append({
  200. 'area': width * height,
  201. 'bbox': [x1, y1, width, height],
  202. 'category_id': cls_id,
  203. 'iscrowd': 0,
  204. 'segmentation': [[x1, y1, x2, y1, x2, y2, x1, y2]],
  205. 'score': score
  206. })
  207. if type_ == 2 and len(temp['annotation']) > 0:
  208. temp['annotation'] = [temp['annotation'][0]]
  209. temp['annotation'][0].pop('area')
  210. temp['annotation'][0].pop('bbox')
  211. temp['annotation'][0].pop('iscrowd')
  212. temp['annotation'][0].pop('segmentation')
  213. temp['annotation'] = json.dumps(temp['annotation'])
  214. annotations.append(temp)
  215. return annotations
  216. class YoloInference(object):
  217. """Yolov3 detection inference"""
  218. def __init__(self, label_log):
  219. self.label_log = label_log
  220. flow.config.gpu_device_num(gpu_num_per_node)
  221. flow.env.ctrl_port(9789)
  222. check_point = flow.train.CheckPoint()
  223. if not model_load_dir:
  224. check_point.init()
  225. else:
  226. check_point.load(model_load_dir)
  227. print("Load check_point success")
  228. self.label_log.info("Load check_point success")
  229. def yolo_inference(self, type_, id_list, image_path_list, label_list, coco_flag=0):
  230. annotations = []
  231. try:
  232. if len(image_path_list) == 16:
  233. t0 = time.time()
  234. images, origin_image_info = batch_image_preprocess_v2(image_path_list, image_height, image_width)
  235. yolo_pos, yolo_prob, origin_image_info = yolo_user_op_eval_job(images, origin_image_info).get()
  236. batch_list = batch_boxes(yolo_pos, yolo_prob, origin_image_info)
  237. annotations = coco_format(type_, id_list, image_path_list, batch_list, label_list, coco_flag)
  238. t1 = time.time()
  239. print('t1-t0:', t1 - t0)
  240. except:
  241. print("Forward Error")
  242. self.label_log.error("Forward Error")
  243. for i, image_path in enumerate(image_path_list):
  244. temp = {}
  245. id_name = id_list[i]
  246. temp['id'] = id_name
  247. temp['annotation'] = []
  248. temp['annotation'] = json.dumps(temp['annotation'])
  249. annotations.append(temp)
  250. return annotations

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能

Contributors (1)