You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

postprocess.py 3.7 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """YoloV4 310 infer."""
  16. import os
  17. import argparse
  18. import datetime
  19. import time
  20. import numpy as np
  21. from pycocotools.coco import COCO
  22. from src.logger import get_logger
  23. from eval import DetectionEngine
  24. parser = argparse.ArgumentParser('mindspore coco testing')
  25. # dataset related
  26. parser.add_argument('--per_batch_size', default=1, type=int, help='batch size for per gpu')
  27. # logging related
  28. parser.add_argument('--log_path', type=str, default='outputs/', help='checkpoint save location')
  29. # detect_related
  30. parser.add_argument('--nms_thresh', type=float, default=0.5, help='threshold for NMS')
  31. parser.add_argument('--ann_file', type=str, default='', help='path to annotation')
  32. parser.add_argument('--ignore_threshold', type=float, default=0.001, help='threshold to throw low quality boxes')
  33. parser.add_argument('--img_id_file_path', type=str, default='', help='path of image dataset')
  34. parser.add_argument('--result_files', type=str, default='./result_Files', help='path to 310 infer result floder')
  35. args, _ = parser.parse_known_args()
  36. class Redirct:
  37. def __init__(self):
  38. self.content = ""
  39. def write(self, content):
  40. self.content += content
  41. def flush(self):
  42. self.content = ""
  43. if __name__ == "__main__":
  44. start_time = time.time()
  45. args.outputs_dir = os.path.join(args.log_path,
  46. datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
  47. args.logger = get_logger(args.outputs_dir, 0)
  48. # init detection engine
  49. detection = DetectionEngine(args)
  50. coco = COCO(args.ann_file)
  51. result_path = args.result_files
  52. files = os.listdir(args.img_id_file_path)
  53. for file in files:
  54. img_ids_name = file.split('.')[0]
  55. img_id = int(np.squeeze(img_ids_name))
  56. imgIds = coco.getImgIds(imgIds=[img_id])
  57. img = coco.loadImgs(imgIds[np.random.randint(0, len(imgIds))])[0]
  58. image_shape = ((img['width'], img['height']),)
  59. img_id = (np.squeeze(img_ids_name),)
  60. result_path_0 = os.path.join(result_path, img_ids_name + "_0.bin")
  61. result_path_1 = os.path.join(result_path, img_ids_name + "_1.bin")
  62. result_path_2 = os.path.join(result_path, img_ids_name + "_2.bin")
  63. output_small = np.fromfile(result_path_0, dtype=np.float32).reshape(1, 19, 19, 3, 85)
  64. output_me = np.fromfile(result_path_1, dtype=np.float32).reshape(1, 38, 38, 3, 85)
  65. output_big = np.fromfile(result_path_2, dtype=np.float32).reshape(1, 76, 76, 3, 85)
  66. detection.detect([output_small, output_me, output_big], args.per_batch_size, image_shape, img_id)
  67. args.logger.info('Calculating mAP...')
  68. detection.do_nms_for_results()
  69. result_file_path = detection.write_result()
  70. args.logger.info('result file path: {}'.format(result_file_path))
  71. eval_result = detection.get_eval_result()
  72. cost_time = time.time() - start_time
  73. args.logger.info('\n=============coco eval reulst=========\n' + eval_result)
  74. args.logger.info('testing cost time {:.2f}h'.format(cost_time / 3600.))