You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

eval.py 7.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Face detection eval."""
  16. import os
  17. import argparse
  18. import matplotlib.pyplot as plt
  19. from mindspore import context
  20. from mindspore import Tensor
  21. from mindspore.context import ParallelMode
  22. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  23. from mindspore.common import dtype as mstype
  24. import mindspore.dataset as de
  25. from src.data_preprocess import SingleScaleTrans
  26. from src.config import config
  27. from src.FaceDetection.yolov3 import HwYolov3 as backbone_HwYolov3
  28. from src.FaceDetection import voc_wrapper
  29. from src.network_define import BuildTestNetwork, get_bounding_boxes, tensor_to_brambox, \
  30. parse_gt_from_anno, parse_rets, calc_recall_precision_ap
  31. plt.switch_backend('agg')
  32. devid = int(os.getenv('DEVICE_ID'))
  33. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
  34. def parse_args():
  35. '''parse_args'''
  36. parser = argparse.ArgumentParser('Yolov3 Face Detection')
  37. parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')
  38. parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
  39. parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
  40. parser.add_argument('--world_size', type=int, default=1, help='current process number to support distributed')
  41. arg, _ = parser.parse_known_args()
  42. return arg
  43. if __name__ == "__main__":
  44. args = parse_args()
  45. print('=============yolov3 start evaluating==================')
  46. # logger
  47. args.batch_size = config.batch_size
  48. args.input_shape = config.input_shape
  49. args.result_path = config.result_path
  50. args.conf_thresh = config.conf_thresh
  51. args.nms_thresh = config.nms_thresh
  52. context.set_auto_parallel_context(parallel_mode=ParallelMode.STAND_ALONE, device_num=args.world_size,
  53. gradients_mean=True)
  54. mindrecord_path = args.mindrecord_path
  55. print('Loading data from {}'.format(mindrecord_path))
  56. num_classes = config.num_classes
  57. if num_classes > 1:
  58. raise NotImplementedError('num_classes > 1: Yolov3 postprocess not implemented!')
  59. anchors = config.anchors
  60. anchors_mask = config.anchors_mask
  61. num_anchors_list = [len(x) for x in anchors_mask]
  62. reduction_0 = 64.0
  63. reduction_1 = 32.0
  64. reduction_2 = 16.0
  65. labels = ['face']
  66. classes = {0: 'face'}
  67. # dataloader
  68. ds = de.MindDataset(mindrecord_path + "0", columns_list=["image", "annotation", "image_name", "image_size"])
  69. single_scale_trans = SingleScaleTrans(resize=args.input_shape)
  70. ds = ds.batch(args.batch_size, per_batch_map=single_scale_trans,
  71. input_columns=["image", "annotation", "image_name", "image_size"], num_parallel_workers=8)
  72. args.steps_per_epoch = ds.get_dataset_size()
  73. # backbone
  74. network = backbone_HwYolov3(num_classes, num_anchors_list, args)
  75. # load pretrain model
  76. if os.path.isfile(args.pretrained):
  77. param_dict = load_checkpoint(args.pretrained)
  78. param_dict_new = {}
  79. for key, values in param_dict.items():
  80. if key.startswith('moments.'):
  81. continue
  82. elif key.startswith('network.'):
  83. param_dict_new[key[8:]] = values
  84. else:
  85. param_dict_new[key] = values
  86. load_param_into_net(network, param_dict_new)
  87. print('load model {} success'.format(args.pretrained))
  88. else:
  89. print('load model {} failed, please check the path of model, evaluating end'.format(args.pretrained))
  90. exit(0)
  91. ds = ds.repeat(1)
  92. det = {}
  93. img_size = {}
  94. img_anno = {}
  95. model_name = args.pretrained.split('/')[-1].replace('.ckpt', '')
  96. result_path = os.path.join(args.result_path, model_name)
  97. if os.path.exists(result_path):
  98. pass
  99. if not os.path.isdir(result_path):
  100. os.makedirs(result_path, exist_ok=True)
  101. # result file
  102. ret_files_set = {
  103. 'face': os.path.join(result_path, 'comp4_det_test_face_rm5050.txt'),
  104. }
  105. test_net = BuildTestNetwork(network, reduction_0, reduction_1, reduction_2, anchors, anchors_mask, num_classes,
  106. args)
  107. print('conf_thresh:', args.conf_thresh)
  108. eval_times = 0
  109. for data in ds.create_tuple_iterator(output_numpy=True):
  110. batch_images = data[0]
  111. batch_labels = data[1]
  112. batch_image_name = data[2]
  113. batch_image_size = data[3]
  114. eval_times += 1
  115. img_tensor = Tensor(batch_images, mstype.float32)
  116. dets = []
  117. tdets = []
  118. coords_0, cls_scores_0, coords_1, cls_scores_1, coords_2, cls_scores_2 = test_net(img_tensor)
  119. boxes_0, boxes_1, boxes_2 = get_bounding_boxes(coords_0, cls_scores_0, coords_1, cls_scores_1, coords_2,
  120. cls_scores_2, args.conf_thresh, args.input_shape,
  121. num_classes)
  122. converted_boxes_0, converted_boxes_1, converted_boxes_2 = tensor_to_brambox(boxes_0, boxes_1, boxes_2,
  123. args.input_shape, labels)
  124. tdets.append(converted_boxes_0)
  125. tdets.append(converted_boxes_1)
  126. tdets.append(converted_boxes_2)
  127. batch = len(tdets[0])
  128. for b in range(batch):
  129. single_dets = []
  130. for op in range(3):
  131. single_dets.extend(tdets[op][b])
  132. dets.append(single_dets)
  133. det.update({batch_image_name[k].decode('UTF-8'): v for k, v in enumerate(dets)})
  134. img_size.update({batch_image_name[k].decode('UTF-8'): v for k, v in enumerate(batch_image_size)})
  135. img_anno.update({batch_image_name[k].decode('UTF-8'): v for k, v in enumerate(batch_labels)})
  136. print('eval times:', eval_times)
  137. print('batch size: ', args.batch_size)
  138. netw, neth = args.input_shape
  139. reorg_dets = voc_wrapper.reorg_detection(det, netw, neth, img_size)
  140. voc_wrapper.gen_results(reorg_dets, result_path, img_size, args.nms_thresh)
  141. # compute mAP
  142. ground_truth = parse_gt_from_anno(img_anno, classes)
  143. ret_list = parse_rets(ret_files_set)
  144. iou_thr = 0.5
  145. evaluate = calc_recall_precision_ap(ground_truth, ret_list, iou_thr)
  146. aps_str = ''
  147. for cls in evaluate:
  148. per_line, = plt.plot(evaluate[cls]['recall'], evaluate[cls]['precision'], 'b-')
  149. per_line.set_label('%s:AP=%.3f' % (cls, evaluate[cls]['ap']))
  150. aps_str += '_%s_AP_%.3f' % (cls, evaluate[cls]['ap'])
  151. plt.plot([i / 1000.0 for i in range(1, 1001)], [i / 1000.0 for i in range(1, 1001)], 'y--')
  152. plt.axis([0, 1.2, 0, 1.2])
  153. plt.xlabel('recall')
  154. plt.ylabel('precision')
  155. plt.grid()
  156. plt.legend()
  157. plt.title('PR')
  158. # save mAP
  159. ap_save_path = os.path.join(result_path, result_path.replace('/', '_') + aps_str + '.png')
  160. print('Saving {}'.format(ap_save_path))
  161. plt.savefig(ap_save_path)
  162. print('=============yolov3 evaluating finished==================')