You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """YoloV4 test-dev."""
  16. import os
  17. import sys
  18. import argparse
  19. import datetime
  20. from collections import defaultdict
  21. import json
  22. import numpy as np
  23. from mindspore import context
  24. from mindspore import Tensor
  25. from mindspore.context import ParallelMode
  26. from mindspore.communication.management import init, get_rank, get_group_size
  27. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  28. import mindspore as ms
  29. from src.yolo import YOLOV4CspDarkNet53
  30. from src.logger import get_logger
  31. from src.yolo_dataset import create_yolo_datasetv2
  32. from src.config import ConfigYOLOV4CspDarkNet53
  33. devid = int(os.getenv('DEVICE_ID'))
  34. context.set_context(mode=context.GRAPH_MODE, device_target="Davinci", save_graphs=False, device_id=devid)
  35. parser = argparse.ArgumentParser('mindspore coco testing')
  36. # dataset related
  37. parser.add_argument('--data_dir', type=str, default='', help='train data dir')
  38. parser.add_argument('--per_batch_size', default=1, type=int, help='batch size for per gpu')
  39. # network related
  40. parser.add_argument('--pretrained', default='', type=str, help='model_path, local pretrained model to load')
  41. # logging related
  42. parser.add_argument('--log_path', type=str, default='outputs/', help='checkpoint save location')
  43. # distributed related
  44. parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
  45. parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
  46. parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
  47. # detect_related
  48. parser.add_argument('--nms_thresh', type=float, default=0.45, help='threshold for NMS')
  49. parser.add_argument('--annFile', type=str, default='', help='path to annotation')
  50. parser.add_argument('--testing_shape', type=str, default='', help='shape for test ')
  51. parser.add_argument('--ignore_threshold', type=float, default=0.001, help='threshold to throw low quality boxes')
  52. args, _ = parser.parse_known_args()
  53. args.data_root = os.path.join(args.data_dir, 'test2017')
  54. class DetectionEngine():
  55. """Detection engine"""
  56. def __init__(self, args_engine):
  57. self.ignore_threshold = args_engine.ignore_threshold
  58. self.labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
  59. 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
  60. 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
  61. 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
  62. 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
  63. 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
  64. 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
  65. 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
  66. 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
  67. 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
  68. self.num_classes = len(self.labels)
  69. self.results = {} # img_id->class
  70. self.file_path = '' # path to save predict result
  71. self.save_prefix = args_engine.outputs_dir
  72. self.det_boxes = []
  73. self.nms_thresh = args_engine.nms_thresh
  74. self.coco_catids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27,
  75. 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53,
  76. 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80,
  77. 81, 82, 84, 85, 86, 87, 88, 89, 90]
  78. def do_nms_for_results(self):
  79. """nms result"""
  80. for img_id in self.results:
  81. for clsi in self.results[img_id]:
  82. dets = self.results[img_id][clsi]
  83. dets = np.array(dets)
  84. keep_index = self._diou_nms(dets, thresh=0.6)
  85. keep_box = [{'image_id': int(img_id),
  86. 'category_id': int(clsi),
  87. 'bbox': list(dets[i][:4].astype(float)),
  88. 'score': dets[i][4].astype(float)}
  89. for i in keep_index]
  90. self.det_boxes.extend(keep_box)
  91. def _nms(self, dets, thresh):
  92. """nms function"""
  93. # convert xywh -> xmin ymin xmax ymax
  94. x1 = dets[:, 0]
  95. y1 = dets[:, 1]
  96. x2 = x1 + dets[:, 2]
  97. y2 = y1 + dets[:, 3]
  98. scores = dets[:, 4]
  99. areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  100. order = scores.argsort()[::-1]
  101. keep = []
  102. while order.size > 0:
  103. i = order[0]
  104. keep.append(i)
  105. xx1 = np.maximum(x1[i], x1[order[1:]])
  106. yy1 = np.maximum(y1[i], y1[order[1:]])
  107. xx2 = np.minimum(x2[i], x2[order[1:]])
  108. yy2 = np.minimum(y2[i], y2[order[1:]])
  109. w = np.maximum(0.0, xx2 - xx1 + 1)
  110. h = np.maximum(0.0, yy2 - yy1 + 1)
  111. inter = w * h
  112. ovr = inter / (areas[i] + areas[order[1:]] - inter)
  113. inds = np.where(ovr <= thresh)[0]
  114. order = order[inds + 1]
  115. return keep
  116. def _diou_nms(self, dets, thresh=0.5):
  117. """convert xywh -> xmin ymin xmax ymax"""
  118. x1 = dets[:, 0]
  119. y1 = dets[:, 1]
  120. x2 = x1 + dets[:, 2]
  121. y2 = y1 + dets[:, 3]
  122. scores = dets[:, 4]
  123. areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  124. order = scores.argsort()[::-1]
  125. keep = []
  126. while order.size > 0:
  127. i = order[0]
  128. keep.append(i)
  129. xx1 = np.maximum(x1[i], x1[order[1:]])
  130. yy1 = np.maximum(y1[i], y1[order[1:]])
  131. xx2 = np.minimum(x2[i], x2[order[1:]])
  132. yy2 = np.minimum(y2[i], y2[order[1:]])
  133. w = np.maximum(0.0, xx2 - xx1 + 1)
  134. h = np.maximum(0.0, yy2 - yy1 + 1)
  135. inter = w * h
  136. ovr = inter / (areas[i] + areas[order[1:]] - inter)
  137. center_x1 = (x1[i] + x2[i]) / 2
  138. center_x2 = (x1[order[1:]] + x2[order[1:]]) / 2
  139. center_y1 = (y1[i] + y2[i]) / 2
  140. center_y2 = (y1[order[1:]] + y2[order[1:]]) / 2
  141. inter_diag = (center_x2 - center_x1) ** 2 + (center_y2 - center_y1) ** 2
  142. out_max_x = np.maximum(x2[i], x2[order[1:]])
  143. out_max_y = np.maximum(y2[i], y2[order[1:]])
  144. out_min_x = np.minimum(x1[i], x1[order[1:]])
  145. out_min_y = np.minimum(y1[i], y1[order[1:]])
  146. outer_diag = (out_max_x - out_min_x) ** 2 + (out_max_y - out_min_y) ** 2
  147. diou = ovr - inter_diag / outer_diag
  148. diou = np.clip(diou, -1, 1)
  149. inds = np.where(diou <= thresh)[0]
  150. order = order[inds + 1]
  151. return keep
  152. def write_result(self):
  153. """write result to json file"""
  154. t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S')
  155. try:
  156. self.file_path = self.save_prefix + '/predict' + t + '.json'
  157. f = open(self.file_path, 'w')
  158. json.dump(self.det_boxes, f)
  159. except IOError as e:
  160. raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e)))
  161. else:
  162. f.close()
  163. return self.file_path
  164. def detect(self, outputs, batch, image_shape, image_id):
  165. """post process"""
  166. outputs_num = len(outputs)
  167. # output [|32, 52, 52, 3, 85| ]
  168. for batch_id in range(batch):
  169. for out_id in range(outputs_num):
  170. # 32, 52, 52, 3, 85
  171. out_item = outputs[out_id]
  172. # 52, 52, 3, 85
  173. out_item_single = out_item[batch_id, :]
  174. # get number of items in one head, [B, gx, gy, anchors, 5+80]
  175. dimensions = out_item_single.shape[:-1]
  176. out_num = 1
  177. for d in dimensions:
  178. out_num *= d
  179. ori_w, ori_h = image_shape[batch_id]
  180. img_id = int(image_id[batch_id])
  181. x = out_item_single[..., 0] * ori_w
  182. y = out_item_single[..., 1] * ori_h
  183. w = out_item_single[..., 2] * ori_w
  184. h = out_item_single[..., 3] * ori_h
  185. conf = out_item_single[..., 4:5]
  186. cls_emb = out_item_single[..., 5:]
  187. cls_argmax = np.expand_dims(np.argmax(cls_emb, axis=-1), axis=-1)
  188. x = x.reshape(-1)
  189. y = y.reshape(-1)
  190. w = w.reshape(-1)
  191. h = h.reshape(-1)
  192. cls_emb = cls_emb.reshape(-1, 80)
  193. conf = conf.reshape(-1)
  194. cls_argmax = cls_argmax.reshape(-1)
  195. x_top_left = x - w / 2.
  196. y_top_left = y - h / 2.
  197. # create all False
  198. flag = np.random.random(cls_emb.shape) > sys.maxsize
  199. for i in range(flag.shape[0]):
  200. c = cls_argmax[i]
  201. flag[i, c] = True
  202. confidence = cls_emb[flag] * conf
  203. for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, w, h, confidence, cls_argmax):
  204. if confi < self.ignore_threshold:
  205. continue
  206. if img_id not in self.results:
  207. self.results[img_id] = defaultdict(list)
  208. x_lefti = max(0, x_lefti)
  209. y_lefti = max(0, y_lefti)
  210. wi = min(wi, ori_w)
  211. hi = min(hi, ori_h)
  212. # transform catId to match coco
  213. coco_clsi = self.coco_catids[clsi]
  214. self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi])
  215. def convert_testing_shape(args_test):
  216. testing_shape = [int(args_test.testing_shape), int(args_test.testing_shape)]
  217. return testing_shape
  218. def test():
  219. """test method"""
  220. # init distributed
  221. if args.is_distributed:
  222. init()
  223. args.rank = get_rank()
  224. args.group_size = get_group_size()
  225. # logger
  226. args.outputs_dir = os.path.join(args.log_path,
  227. datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
  228. args.logger = get_logger(args.outputs_dir, args.rank)
  229. context.reset_auto_parallel_context()
  230. if args.is_distributed:
  231. parallel_mode = ParallelMode.DATA_PARALLEL
  232. else:
  233. parallel_mode = ParallelMode.STAND_ALONE
  234. context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=1)
  235. args.logger.info('Creating Network....')
  236. network = YOLOV4CspDarkNet53(is_training=False)
  237. args.logger.info(args.pretrained)
  238. if os.path.isfile(args.pretrained):
  239. param_dict = load_checkpoint(args.pretrained)
  240. param_dict_new = {}
  241. for key, values in param_dict.items():
  242. if key.startswith('moments.'):
  243. continue
  244. elif key.startswith('yolo_network.'):
  245. param_dict_new[key[13:]] = values
  246. else:
  247. param_dict_new[key] = values
  248. load_param_into_net(network, param_dict_new)
  249. args.logger.info('load_model {} success'.format(args.pretrained))
  250. else:
  251. args.logger.info('{} not exists or not a pre-trained file'.format(args.pretrained))
  252. assert FileNotFoundError('{} not exists or not a pre-trained file'.format(args.pretrained))
  253. exit(1)
  254. data_root = args.data_root
  255. config = ConfigYOLOV4CspDarkNet53()
  256. if args.testing_shape:
  257. config.test_img_shape = convert_testing_shape(args)
  258. data_txt = os.path.join(args.data_dir, 'testdev2017.txt')
  259. ds, data_size = create_yolo_datasetv2(data_root, data_txt=data_txt, batch_size=args.per_batch_size,
  260. max_epoch=1, device_num=args.group_size, rank=args.rank, shuffle=False,
  261. config=config)
  262. args.logger.info('testing shape : {}'.format(config.test_img_shape))
  263. args.logger.info('totol {} images to eval'.format(data_size))
  264. network.set_train(False)
  265. # init detection engine
  266. detection = DetectionEngine(args)
  267. input_shape = Tensor(tuple(config.test_img_shape), ms.float32)
  268. args.logger.info('Start inference....')
  269. for i, data in enumerate(ds.create_dict_iterator()):
  270. image = Tensor(data["image"])
  271. image_shape = Tensor(data["image_shape"])
  272. image_id = Tensor(data["img_id"])
  273. prediction = network(image, input_shape)
  274. output_big, output_me, output_small = prediction
  275. output_big = output_big.asnumpy()
  276. output_me = output_me.asnumpy()
  277. output_small = output_small.asnumpy()
  278. image_id = image_id.asnumpy()
  279. image_shape = image_shape.asnumpy()
  280. detection.detect([output_small, output_me, output_big], args.per_batch_size, image_shape, image_id)
  281. if i % 1000 == 0:
  282. args.logger.info('Processing... {:.2f}% '.format(i * args.per_batch_size / data_size * 100))
  283. args.logger.info('Calculating mAP...')
  284. detection.do_nms_for_results()
  285. result_file_path = detection.write_result()
  286. args.logger.info('result file path: {}'.format(result_file_path))
  287. if __name__ == "__main__":
  288. test()