You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

postprocess.py 3.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """post process for 310 inference"""
  16. import os
  17. import argparse
  18. import numpy as np
  19. from PIL import Image
  20. from src.config import config
  21. from src.eval_utils import metrics
  22. batch_size = 1
  23. parser = argparse.ArgumentParser(description="ssd acc calculation")
  24. parser.add_argument("--result_path", type=str, required=True, help="result files path.")
  25. parser.add_argument("--img_path", type=str, required=True, help="image file path.")
  26. parser.add_argument("--drop", action="store_true", help="drop iscrowd images or not.")
  27. args = parser.parse_args()
  28. def get_imgSize(file_name):
  29. img = Image.open(file_name)
  30. return img.size
  31. def get_result(result_path, img_id_file_path):
  32. anno_json = os.path.join(config.coco_root, config.instances_set.format(config.val_data_type))
  33. if args.drop:
  34. from pycocotools.coco import COCO
  35. train_cls = config.classes
  36. train_cls_dict = {}
  37. for i, cls in enumerate(train_cls):
  38. train_cls_dict[cls] = i
  39. coco = COCO(anno_json)
  40. classs_dict = {}
  41. cat_ids = coco.loadCats(coco.getCatIds())
  42. for cat in cat_ids:
  43. classs_dict[cat["id"]] = cat["name"]
  44. files = os.listdir(img_id_file_path)
  45. pred_data = []
  46. for file in files:
  47. img_ids_name = file.split('.')[0]
  48. img_id = int(np.squeeze(img_ids_name))
  49. if args.drop:
  50. anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
  51. anno = coco.loadAnns(anno_ids)
  52. annos = []
  53. iscrowd = False
  54. for label in anno:
  55. bbox = label["bbox"]
  56. class_name = classs_dict[label["category_id"]]
  57. iscrowd = iscrowd or label["iscrowd"]
  58. if class_name in train_cls:
  59. x_min, x_max = bbox[0], bbox[0] + bbox[2]
  60. y_min, y_max = bbox[1], bbox[1] + bbox[3]
  61. annos.append(list(map(round, [y_min, x_min, y_max, x_max])) + [train_cls_dict[class_name]])
  62. if iscrowd or (not annos):
  63. continue
  64. img_size = get_imgSize(os.path.join(img_id_file_path, file))
  65. image_shape = np.array([img_size[1], img_size[0]])
  66. result_path_0 = os.path.join(result_path, img_ids_name + "_0.bin")
  67. result_path_1 = os.path.join(result_path, img_ids_name + "_1.bin")
  68. boxes = np.fromfile(result_path_0, dtype=np.float32).reshape(config.num_ssd_boxes, 4)
  69. box_scores = np.fromfile(result_path_1, dtype=np.float32).reshape(config.num_ssd_boxes, config.num_classes)
  70. pred_data.append({
  71. "boxes": boxes,
  72. "box_scores": box_scores,
  73. "img_id": img_id,
  74. "image_shape": image_shape
  75. })
  76. mAP = metrics(pred_data, anno_json)
  77. print(f" mAP:{mAP}")
  78. if __name__ == '__main__':
  79. get_result(args.result_path, args.img_path)