You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

eval.py 8.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Face detection eval."""
  16. import os
  17. import argparse
  18. import matplotlib.pyplot as plt
  19. from mindspore import context
  20. from mindspore import Tensor
  21. from mindspore.context import ParallelMode
  22. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  23. from mindspore.common import dtype as mstype
  24. import mindspore.dataset as de
  25. from src.data_preprocess import SingleScaleTrans
  26. from src.config import config
  27. from src.FaceDetection.yolov3 import HwYolov3 as backbone_HwYolov3
  28. from src.FaceDetection import voc_wrapper
  29. from src.network_define import BuildTestNetwork, get_bounding_boxes, tensor_to_brambox, \
  30. parse_gt_from_anno, parse_rets, calc_recall_precision_ap
# Use the non-interactive Agg backend so the PR-curve figure can be rendered
# and saved with plt.savefig() on a headless machine (no display required).
plt.switch_backend('agg')
  32. def parse_args():
  33. '''parse_args'''
  34. parser = argparse.ArgumentParser('Yolov3 Face Detection')
  35. parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "CPU"),
  36. help="run platform, support Ascend and CPU.")
  37. parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')
  38. parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
  39. parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
  40. parser.add_argument('--world_size', type=int, default=1, help='current process number to support distributed')
  41. arg, _ = parser.parse_known_args()
  42. return arg
  43. if __name__ == "__main__":
  44. args = parse_args()
  45. devid = int(os.getenv('DEVICE_ID', '0')) if args.run_platform != 'CPU' else 0
  46. context.set_context(mode=context.GRAPH_MODE, device_target=args.run_platform, save_graphs=False, device_id=devid)
  47. print('=============yolov3 start evaluating==================')
  48. # logger
  49. args.batch_size = config.batch_size
  50. args.input_shape = config.input_shape
  51. args.result_path = config.result_path
  52. args.conf_thresh = config.conf_thresh
  53. args.nms_thresh = config.nms_thresh
  54. context.set_auto_parallel_context(parallel_mode=ParallelMode.STAND_ALONE, device_num=args.world_size,
  55. gradients_mean=True)
  56. mindrecord_path = args.mindrecord_path
  57. print('Loading data from {}'.format(mindrecord_path))
  58. num_classes = config.num_classes
  59. if num_classes > 1:
  60. raise NotImplementedError('num_classes > 1: Yolov3 postprocess not implemented!')
  61. anchors = config.anchors
  62. anchors_mask = config.anchors_mask
  63. num_anchors_list = [len(x) for x in anchors_mask]
  64. reduction_0 = 64.0
  65. reduction_1 = 32.0
  66. reduction_2 = 16.0
  67. labels = ['face']
  68. classes = {0: 'face'}
  69. # dataloader
  70. ds = de.MindDataset(mindrecord_path + "0", columns_list=["image", "annotation", "image_name", "image_size"])
  71. single_scale_trans = SingleScaleTrans(resize=args.input_shape)
  72. ds = ds.batch(args.batch_size, per_batch_map=single_scale_trans,
  73. input_columns=["image", "annotation", "image_name", "image_size"], num_parallel_workers=8)
  74. args.steps_per_epoch = ds.get_dataset_size()
  75. # backbone
  76. network = backbone_HwYolov3(num_classes, num_anchors_list, args)
  77. # load pretrain model
  78. if os.path.isfile(args.pretrained):
  79. param_dict = load_checkpoint(args.pretrained)
  80. param_dict_new = {}
  81. for key, values in param_dict.items():
  82. if key.startswith('moments.'):
  83. continue
  84. elif key.startswith('network.'):
  85. param_dict_new[key[8:]] = values
  86. else:
  87. param_dict_new[key] = values
  88. load_param_into_net(network, param_dict_new)
  89. print('load model {} success'.format(args.pretrained))
  90. else:
  91. print('load model {} failed, please check the path of model, evaluating end'.format(args.pretrained))
  92. exit(0)
  93. ds = ds.repeat(1)
  94. det = {}
  95. img_size = {}
  96. img_anno = {}
  97. model_name = args.pretrained.split('/')[-1].replace('.ckpt', '')
  98. result_path = os.path.join(args.result_path, model_name)
  99. if os.path.exists(result_path):
  100. pass
  101. if not os.path.isdir(result_path):
  102. os.makedirs(result_path, exist_ok=True)
  103. # result file
  104. ret_files_set = {
  105. 'face': os.path.join(result_path, 'comp4_det_test_face_rm5050.txt'),
  106. }
  107. test_net = BuildTestNetwork(network, reduction_0, reduction_1, reduction_2, anchors, anchors_mask, num_classes,
  108. args)
  109. print('conf_thresh:', args.conf_thresh)
  110. eval_times = 0
  111. for data in ds.create_tuple_iterator(output_numpy=True):
  112. batch_images = data[0]
  113. batch_labels = data[1]
  114. batch_image_name = data[2]
  115. batch_image_size = data[3]
  116. eval_times += 1
  117. img_tensor = Tensor(batch_images, mstype.float32)
  118. dets = []
  119. tdets = []
  120. coords_0, cls_scores_0, coords_1, cls_scores_1, coords_2, cls_scores_2 = test_net(img_tensor)
  121. boxes_0, boxes_1, boxes_2 = get_bounding_boxes(coords_0, cls_scores_0, coords_1, cls_scores_1, coords_2,
  122. cls_scores_2, args.conf_thresh, args.input_shape,
  123. num_classes)
  124. converted_boxes_0, converted_boxes_1, converted_boxes_2 = tensor_to_brambox(boxes_0, boxes_1, boxes_2,
  125. args.input_shape, labels)
  126. tdets.append(converted_boxes_0)
  127. tdets.append(converted_boxes_1)
  128. tdets.append(converted_boxes_2)
  129. batch = len(tdets[0])
  130. for b in range(batch):
  131. single_dets = []
  132. for op in range(3):
  133. single_dets.extend(tdets[op][b])
  134. dets.append(single_dets)
  135. det.update({batch_image_name[k].decode('UTF-8'): v for k, v in enumerate(dets)})
  136. img_size.update({batch_image_name[k].decode('UTF-8'): v for k, v in enumerate(batch_image_size)})
  137. img_anno.update({batch_image_name[k].decode('UTF-8'): v for k, v in enumerate(batch_labels)})
  138. print('eval times:', eval_times)
  139. print('batch size: ', args.batch_size)
  140. netw, neth = args.input_shape
  141. reorg_dets = voc_wrapper.reorg_detection(det, netw, neth, img_size)
  142. voc_wrapper.gen_results(reorg_dets, result_path, img_size, args.nms_thresh)
  143. # compute mAP
  144. ground_truth = parse_gt_from_anno(img_anno, classes)
  145. ret_list = parse_rets(ret_files_set)
  146. iou_thr = 0.5
  147. evaluate = calc_recall_precision_ap(ground_truth, ret_list, iou_thr)
  148. aps_str = ''
  149. for cls in evaluate:
  150. per_line, = plt.plot(evaluate[cls]['recall'], evaluate[cls]['precision'], 'b-')
  151. per_line.set_label('%s:AP=%.3f' % (cls, evaluate[cls]['ap']))
  152. aps_str += '_%s_AP_%.3f' % (cls, evaluate[cls]['ap'])
  153. plt.plot([i / 1000.0 for i in range(1, 1001)], [i / 1000.0 for i in range(1, 1001)], 'y--')
  154. plt.axis([0, 1.2, 0, 1.2])
  155. plt.xlabel('recall')
  156. plt.ylabel('precision')
  157. plt.grid()
  158. plt.legend()
  159. plt.title('PR')
  160. # save mAP
  161. ap_save_path = os.path.join(result_path, result_path.replace('/', '_') + aps_str + '.png')
  162. print('Saving {}'.format(ap_save_path))
  163. plt.savefig(ap_save_path)
  164. print('=============yolov3 evaluating finished==================')