You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

eval.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """YoloV4 eval."""
  16. import os
  17. import argparse
  18. import datetime
  19. import time
  20. import sys
  21. from collections import defaultdict
  22. import numpy as np
  23. from pycocotools.coco import COCO
  24. from pycocotools.cocoeval import COCOeval
  25. from mindspore import Tensor
  26. from mindspore.context import ParallelMode
  27. from mindspore import context
  28. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  29. import mindspore as ms
  30. from src.yolo import YOLOV4CspDarkNet53
  31. from src.logger import get_logger
  32. from src.yolo_dataset import create_yolo_dataset
  33. from src.config import ConfigYOLOV4CspDarkNet53
  # Command-line interface of the evaluation script. Parsing happens at import
  # time and unknown flags are tolerated (parse_known_args).
  parser = argparse.ArgumentParser('mindspore coco testing')

  # device related
  parser.add_argument('--device_target', type=str, default='Ascend',
                      help='device where the code will be implemented. (Default: Ascend)')

  # dataset related
  parser.add_argument('--data_dir', type=str, default='', help='train data dir')
  parser.add_argument('--per_batch_size', default=1, type=int, help='batch size for per gpu')

  # network related
  parser.add_argument('--pretrained', default='', type=str, help='model_path, local pretrained model to load')

  # logging related
  parser.add_argument('--log_path', type=str, default='outputs/', help='checkpoint save location')

  # detect_related
  parser.add_argument('--nms_thresh', type=float, default=0.5, help='threshold for NMS')
  parser.add_argument('--ann_file', type=str, default='', help='path to annotation')
  parser.add_argument('--testing_shape', type=str, default='', help='shape for test ')
  parser.add_argument('--ignore_threshold', type=float, default=0.001, help='threshold to throw low quality boxes')

  args, _ = parser.parse_known_args()

  # Images are read from <data_dir>/val2017.
  args.data_root = os.path.join(args.data_dir, 'val2017')
  # Only fall back to the annotation file under data_dir when --ann_file was not
  # given; overwriting it unconditionally made the flag a no-op.
  if not args.ann_file:
      args.ann_file = os.path.join(args.data_dir, 'annotations/instances_val2017.json')
  class Redirct:
      """Minimal write-only text sink that accumulates everything written to it.

      Exposes the ``write``/``flush`` pair so an instance can be assigned to
      ``sys.stdout`` to capture printed output in ``self.content``.
      """

      def __init__(self):
          # Text captured so far.
          self.content = ""

      def write(self, content):
          """Append ``content`` to the captured text."""
          self.content += content

      def flush(self):
          """Discard the captured text.

          NOTE: unlike a regular stream, flushing here resets ``content`` to an
          empty string, so anything captured before a flush is lost.
          """
          self.content = ""
  class DetectionEngine:
      """Collect network detections, suppress duplicates and score them with COCO.

      Typical flow: call :meth:`detect` once per batch, then
      :meth:`do_nms_for_results`, :meth:`write_result` and finally
      :meth:`get_eval_result`.
      """

      def __init__(self, args_detection):
          """Read settings from ``args_detection`` and load the annotation file.

          Args:
              args_detection: object exposing ``ignore_threshold``,
                  ``outputs_dir``, ``ann_file`` and ``nms_thresh`` attributes.
          """
          self.ignore_threshold = args_detection.ignore_threshold
          self.labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
                         'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
                         'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
                         'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                         'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
                         'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                         'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
                         'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                         'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
                         'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
          self.num_classes = len(self.labels)
          # image id -> {category id -> list of [x, y, w, h, score]}
          self.results = {}
          # Path of the last JSON written by write_result().
          self.file_path = ''
          self.save_prefix = args_detection.outputs_dir
          self.ann_file = args_detection.ann_file
          self._coco = COCO(self.ann_file)
          self._img_ids = sorted(self._coco.imgs.keys())
          # Flat list of result dicts filled by do_nms_for_results().
          self.det_boxes = []
          self.nms_thresh = args_detection.nms_thresh
          # Position i holds the category id used for class index i.
          self.coco_catids = self._coco.getCatIds()

      def do_nms_for_results(self):
          """Run DIoU-NMS per image and per category and extend ``det_boxes``.

          NOTE: the suppression threshold used here is the fixed value 0.6;
          ``self.nms_thresh`` is not consulted.
          """
          for img_id, per_category in self.results.items():
              for clsi, boxes in per_category.items():
                  dets = np.array(boxes)
                  keep_index = self._diou_nms(dets, thresh=0.6)
                  keep_box = [{'image_id': int(img_id),
                               'category_id': int(clsi),
                               'bbox': list(dets[i][:4].astype(float)),
                               'score': dets[i][4].astype(float)}
                              for i in keep_index]
                  self.det_boxes.extend(keep_box)

      def _nms(self, predicts, threshold):
          """Greedy IoU-based non-maximum suppression.

          Args:
              predicts: 2-D array; columns 0-3 are read as x, y, width, height
                  and column 4 as the score.
              threshold: boxes whose IoU with a kept box exceeds this value are
                  dropped.

          Returns:
              List of row indices into ``predicts`` that survive, ordered by
              descending score.
          """
          # convert xywh -> xmin ymin xmax ymax
          x1 = predicts[:, 0]
          y1 = predicts[:, 1]
          x2 = x1 + predicts[:, 2]
          y2 = y1 + predicts[:, 3]
          scores = predicts[:, 4]

          areas = (x2 - x1 + 1) * (y2 - y1 + 1)
          order = scores.argsort()[::-1]

          reserved_boxes = []
          while order.size > 0:
              i = order[0]
              reserved_boxes.append(i)
              rest = order[1:]

              max_x1 = np.maximum(x1[i], x1[rest])
              max_y1 = np.maximum(y1[i], y1[rest])
              min_x2 = np.minimum(x2[i], x2[rest])
              min_y2 = np.minimum(y2[i], y2[rest])

              intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1)
              intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1)
              intersect_area = intersect_w * intersect_h
              ovr = intersect_area / (areas[i] + areas[rest] - intersect_area)

              # Positions are relative to `rest`, hence the +1 to map back
              # into `order`.
              indexes = np.where(ovr <= threshold)[0]
              order = order[indexes + 1]
          return reserved_boxes

      def _diou_nms(self, dets, thresh=0.5):
          """Greedy non-maximum suppression using a distance-penalised IoU.

          The overlap measure is IoU minus the squared centre distance divided
          by the squared diagonal of the smallest box enclosing both boxes,
          clipped to [-1, 1].

          Args:
              dets: 2-D array; columns 0-3 are read as x, y, width, height and
                  column 4 as the score.
              thresh: boxes whose measure against a kept box exceeds this value
                  are dropped.

          Returns:
              List of row indices into ``dets`` that survive, ordered by
              descending score.
          """
          # convert xywh -> xmin ymin xmax ymax
          x1 = dets[:, 0]
          y1 = dets[:, 1]
          x2 = x1 + dets[:, 2]
          y2 = y1 + dets[:, 3]
          scores = dets[:, 4]

          areas = (x2 - x1 + 1) * (y2 - y1 + 1)
          order = scores.argsort()[::-1]

          keep = []
          while order.size > 0:
              i = order[0]
              keep.append(i)
              rest = order[1:]

              # Plain IoU between the current best box and the remaining ones.
              xx1 = np.maximum(x1[i], x1[rest])
              yy1 = np.maximum(y1[i], y1[rest])
              xx2 = np.minimum(x2[i], x2[rest])
              yy2 = np.minimum(y2[i], y2[rest])
              w = np.maximum(0.0, xx2 - xx1 + 1)
              h = np.maximum(0.0, yy2 - yy1 + 1)
              inter = w * h
              ovr = inter / (areas[i] + areas[rest] - inter)

              # Squared distance between box centres.
              center_x1 = (x1[i] + x2[i]) / 2
              center_x2 = (x1[rest] + x2[rest]) / 2
              center_y1 = (y1[i] + y2[i]) / 2
              center_y2 = (y1[rest] + y2[rest]) / 2
              inter_diag = (center_x2 - center_x1) ** 2 + (center_y2 - center_y1) ** 2

              # Squared diagonal of the smallest box enclosing both boxes.
              out_max_x = np.maximum(x2[i], x2[rest])
              out_max_y = np.maximum(y2[i], y2[rest])
              out_min_x = np.minimum(x1[i], x1[rest])
              out_min_y = np.minimum(y1[i], y1[rest])
              outer_diag = (out_max_x - out_min_x) ** 2 + (out_max_y - out_min_y) ** 2

              diou = ovr - inter_diag / outer_diag
              diou = np.clip(diou, -1, 1)

              # Positions are relative to `rest`, hence the +1 to map back
              # into `order`.
              inds = np.where(diou <= thresh)[0]
              order = order[inds + 1]
          return keep

      def write_result(self):
          """Dump ``det_boxes`` as JSON to a timestamped file under ``save_prefix``.

          Returns:
              The path of the file that was written (also stored in
              ``self.file_path``).

          Raises:
              RuntimeError: if the file cannot be opened or written.
          """
          import json
          t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S')
          self.file_path = self.save_prefix + '/predict' + t + '.json'
          try:
              # The context manager closes the handle on every exit path,
              # including serialisation errors raised by json.dump.
              with open(self.file_path, 'w') as f:
                  json.dump(self.det_boxes, f)
          except IOError as e:
              raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e))) from e
          return self.file_path

      def get_eval_result(self):
          """Evaluate the written result file and return the printed summary.

          Returns:
              The text that ``COCOeval.summarize()`` printed to stdout.
          """
          coco_gt = COCO(self.ann_file)
          coco_dt = coco_gt.loadRes(self.file_path)
          coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
          coco_eval.evaluate()
          coco_eval.accumulate()

          # summarize() reports via print, so capture stdout while it runs and
          # always put the real stream back, even if it raises.
          rdct = Redirct()
          stdout = sys.stdout
          sys.stdout = rdct
          try:
              coco_eval.summarize()
          finally:
              sys.stdout = stdout
          return rdct.content

      def detect(self, outputs, batch, image_shape, image_id):
          """Decode network outputs into per-image, per-category candidate boxes.

          Args:
              outputs: sequence of arrays indexed as ``[batch, ..., 5 + classes]``.
                  Along the last axis, entries 0-3 are x, y, w, h (multiplied
                  by the original image width/height here), entry 4 is the
                  objectness confidence and the remainder are class scores.
              batch: number of leading batch entries to process.
              image_shape: per-sample pairs unpacked as ``(ori_w, ori_h)``.
              image_id: per-sample image ids (converted with ``int``).

          Boxes whose ``class score * confidence`` is below
          ``self.ignore_threshold`` are dropped; the rest are appended to
          ``self.results[img_id][category_id]`` as ``[x, y, w, h, score]`` with
          ``x, y`` the top-left corner.
          """
          for batch_id in range(batch):
              ori_w, ori_h = image_shape[batch_id]
              img_id = int(image_id[batch_id])

              for out_item in outputs:
                  # Drop the batch axis: [..., 5 + classes] for this sample.
                  out_item_single = out_item[batch_id, :]

                  x = (out_item_single[..., 0] * ori_w).reshape(-1)
                  y = (out_item_single[..., 1] * ori_h).reshape(-1)
                  w = (out_item_single[..., 2] * ori_w).reshape(-1)
                  h = (out_item_single[..., 3] * ori_h).reshape(-1)
                  conf = out_item_single[..., 4:5].reshape(-1)
                  cls_emb = out_item_single[..., 5:]
                  cls_argmax = np.argmax(cls_emb, axis=-1).reshape(-1)
                  cls_emb = cls_emb.reshape(-1, self.num_classes)

                  # centre -> top-left corner
                  x_top_left = x - w / 2.
                  y_top_left = y - h / 2.

                  # Score of the best class for every candidate, scaled by its
                  # objectness confidence.
                  best_cls_score = cls_emb[np.arange(cls_emb.shape[0]), cls_argmax]
                  confidence = best_cls_score * conf

                  # Filter before the Python loop; written as a negation so
                  # the comparison semantics of `confi < threshold` are kept.
                  keep = ~(confidence < self.ignore_threshold)
                  if not keep.any():
                      continue

                  if img_id not in self.results:
                      self.results[img_id] = defaultdict(list)
                  per_category = self.results[img_id]

                  for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left[keep], y_top_left[keep],
                                                                   w[keep], h[keep],
                                                                   confidence[keep], cls_argmax[keep]):
                      x_lefti = max(0, x_lefti)
                      y_lefti = max(0, y_lefti)
                      wi = min(wi, ori_w)
                      hi = min(hi, ori_h)
                      # transform catId to match coco
                      coco_clsi = self.coco_catids[clsi]
                      per_category[coco_clsi].append([x_lefti, y_lefti, wi, hi, confi])
  def convert_testing_shape(args_testing_shape):
      """Convert testing shape to list.

      Args:
          args_testing_shape: value accepted by ``int`` (e.g. the string
              ``"416"``) giving the side length.

      Returns:
          A two-element list holding that integer twice.

      Raises:
          ValueError: if the value cannot be parsed as an integer.
      """
      side = int(args_testing_shape)
      return [side, side]
  if __name__ == "__main__":
      start_time = time.time()

      # DEVICE_ID / RANK_ID come from the environment; unset or empty means 0.
      device_id = int(os.getenv('DEVICE_ID') or 0)
      context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=device_id)

      # logger
      args.outputs_dir = os.path.join(args.log_path,
                                      datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
      rank_id = int(os.environ.get('RANK_ID') or 0)
      args.logger = get_logger(args.outputs_dir, rank_id)

      context.reset_auto_parallel_context()
      parallel_mode = ParallelMode.STAND_ALONE
      context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=1)

      args.logger.info('Creating Network....')
      network = YOLOV4CspDarkNet53(is_training=False)

      args.logger.info(args.pretrained)
      if not os.path.isfile(args.pretrained):
          # The original `assert FileNotFoundError(...)` asserted a truthy
          # exception instance and therefore never fired; raise it for real so
          # the run stops with a non-zero exit status and a clear message.
          message = '{} not exists or not a pre-trained file'.format(args.pretrained)
          args.logger.info(message)
          raise FileNotFoundError(message)

      param_dict = load_checkpoint(args.pretrained)
      param_dict_new = {}
      network_prefix = 'yolo_network.'
      for key, values in param_dict.items():
          if key.startswith('moments.'):
              # Skipped on purpose; presumably optimizer state that is not
              # needed for inference -- TODO confirm.
              continue
          if key.startswith(network_prefix):
              # Strip the wrapper prefix so names match the bare network.
              param_dict_new[key[len(network_prefix):]] = values
          else:
              param_dict_new[key] = values
      load_param_into_net(network, param_dict_new)
      args.logger.info('load_model {} success'.format(args.pretrained))

      data_root = args.data_root
      ann_file = args.ann_file

      config = ConfigYOLOV4CspDarkNet53()
      if args.testing_shape:
          config.test_img_shape = convert_testing_shape(args.testing_shape)

      ds, data_size = create_yolo_dataset(data_root, ann_file, is_training=False, batch_size=args.per_batch_size,
                                          max_epoch=1, device_num=1, rank=rank_id, shuffle=False,
                                          config=config)

      args.logger.info('testing shape : {}'.format(config.test_img_shape))
      args.logger.info('totol {} images to eval'.format(data_size))

      network.set_train(False)

      # init detection engine
      detection = DetectionEngine(args)

      input_shape = Tensor(tuple(config.test_img_shape), ms.float32)
      args.logger.info('Start inference....')
      for index, data in enumerate(ds.create_dict_iterator(num_epochs=1)):
          image = data["image"]
          image_shape_ = data["image_shape"]
          image_id_ = data["img_id"]

          prediction = network(image, input_shape)
          output_big, output_me, output_small = prediction
          output_big = output_big.asnumpy()
          output_me = output_me.asnumpy()
          output_small = output_small.asnumpy()
          image_id_ = image_id_.asnumpy()
          image_shape_ = image_shape_.asnumpy()

          detection.detect([output_small, output_me, output_big], args.per_batch_size, image_shape_, image_id_)
          # Progress line every 1000 batches.
          if index % 1000 == 0:
              args.logger.info('Processing... {:.2f}% '.format(index * args.per_batch_size / data_size * 100))

      args.logger.info('Calculating mAP...')
      detection.do_nms_for_results()
      result_file_path = detection.write_result()
      args.logger.info('result file path: {}'.format(result_file_path))
      eval_result = detection.get_eval_result()

      cost_time = time.time() - start_time
      args.logger.info('\n=============coco eval reulst=========\n' + eval_result)
      args.logger.info('testing cost time {:.2f}h'.format(cost_time / 3600.))