You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

postprocess.py 2.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """post process for 310 inference"""
  16. import os
  17. import argparse
  18. import numpy as np
  19. from PIL import Image
  20. from src.config import config
  21. from src.eval_utils import metrics
  22. batch_size = 1
  23. parser = argparse.ArgumentParser(description="ssd_mobilenet_v1_fpn inference")
  24. parser.add_argument("--result_path", type=str, required=True, help="result files path.")
  25. parser.add_argument("--img_path", type=str, required=True, help="image file path.")
  26. args = parser.parse_args()
  27. def get_imgSize(file_name):
  28. img = Image.open(file_name)
  29. return img.size
  30. def get_result(result_path, img_id_file_path):
  31. anno_json = os.path.join(config.coco_root, config.instances_set.format(config.val_data_type))
  32. files = os.listdir(img_id_file_path)
  33. pred_data = []
  34. for file in files:
  35. img_ids_name = file.split('.')[0]
  36. img_id = int(np.squeeze(img_ids_name))
  37. img_size = get_imgSize(os.path.join(img_id_file_path, file))
  38. image_shape = np.array([img_size[1], img_size[0]])
  39. result_path_0 = os.path.join(result_path, img_ids_name + "_0.bin")
  40. result_path_1 = os.path.join(result_path, img_ids_name + "_1.bin")
  41. boxes = np.fromfile(result_path_0, dtype=np.float32).reshape(51150, 4)
  42. box_scores = np.fromfile(result_path_1, dtype=np.float32).reshape(51150, 81)
  43. pred_data.append({
  44. "boxes": boxes,
  45. "box_scores": box_scores,
  46. "img_id": img_id,
  47. "image_shape": image_shape
  48. })
  49. mAP = metrics(pred_data, anno_json)
  50. print(f" mAP:{mAP}")
  51. if __name__ == '__main__':
  52. get_result(args.result_path, args.img_path)