You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

postprocess.py 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """post process for 310 inference"""
  16. import os
  17. import argparse
  18. import numpy as np
  19. import cv2
  20. from eval import cal_hist, pre_process
  21. def parse_args():
  22. parser = argparse.ArgumentParser(description="deeplabv3 accuracy calculation")
  23. parser.add_argument('--data_root', type=str, default='', help='root path of val data')
  24. parser.add_argument('--data_lst', type=str, default='', help='list of val data')
  25. parser.add_argument('--batch_size', type=int, default=1, help='batch size')
  26. parser.add_argument('--crop_size', type=int, default=513, help='crop size')
  27. parser.add_argument('--scales', type=float, action='append', help='scales of evaluation')
  28. parser.add_argument('--flip', action='store_true', help='perform left-right flip')
  29. parser.add_argument('--ignore_label', type=int, default=255, help='ignore label')
  30. parser.add_argument('--num_classes', type=int, default=21, help='number of classes')
  31. parser.add_argument('--result_path', type=str, default='./result_Files', help='result Files path')
  32. args, _ = parser.parse_known_args()
  33. return args
  34. def eval_batch(args, result_file, img_lst, crop_size=513, flip=True):
  35. result_lst = []
  36. batch_size = len(img_lst)
  37. batch_img = np.zeros((args.batch_size, 3, crop_size, crop_size), dtype=np.float32)
  38. resize_hw = []
  39. for l in range(batch_size):
  40. img_ = img_lst[l]
  41. img_, resize_h, resize_w = pre_process(args, img_, crop_size)
  42. batch_img[l] = img_
  43. resize_hw.append([resize_h, resize_w])
  44. batch_img = np.ascontiguousarray(batch_img)
  45. net_out = np.fromfile(result_file, np.float32).reshape(args.batch_size, args.num_classes, crop_size, crop_size)
  46. for bs in range(batch_size):
  47. probs_ = net_out[bs][:, :resize_hw[bs][0], :resize_hw[bs][1]].transpose((1, 2, 0))
  48. ori_h, ori_w = img_lst[bs].shape[0], img_lst[bs].shape[1]
  49. probs_ = cv2.resize(probs_, (ori_w, ori_h))
  50. result_lst.append(probs_)
  51. return result_lst
  52. def eval_batch_scales(args, eval_net, img_lst, scales,
  53. base_crop_size=513, flip=True):
  54. sizes_ = [int((base_crop_size - 1) * sc) + 1 for sc in scales]
  55. probs_lst = eval_batch(args, eval_net, img_lst, crop_size=sizes_[0], flip=flip)
  56. print(sizes_)
  57. for crop_size_ in sizes_[1:]:
  58. probs_lst_tmp = eval_batch(args, eval_net, img_lst, crop_size=crop_size_, flip=flip)
  59. for pl, _ in enumerate(probs_lst):
  60. probs_lst[pl] += probs_lst_tmp[pl]
  61. result_msk = []
  62. for i in probs_lst:
  63. result_msk.append(i.argmax(axis=2))
  64. return result_msk
  65. def acc_cal():
  66. args = parse_args()
  67. args.image_mean = [103.53, 116.28, 123.675]
  68. args.image_std = [57.375, 57.120, 58.395]
  69. # data list
  70. with open(args.data_lst) as f:
  71. img_lst = f.readlines()
  72. # evaluate
  73. hist = np.zeros((args.num_classes, args.num_classes))
  74. batch_img_lst = []
  75. batch_msk_lst = []
  76. bi = 0
  77. image_num = 0
  78. for i, line in enumerate(img_lst):
  79. img_path, msk_path = line.strip().split(' ')
  80. result_file = os.path.join(args.result_path, os.path.basename(img_path).split('.jpg')[0] + '_0.bin')
  81. img_path = os.path.join(args.data_root, img_path)
  82. msk_path = os.path.join(args.data_root, msk_path)
  83. img_ = cv2.imread(img_path)
  84. msk_ = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
  85. batch_img_lst.append(img_)
  86. batch_msk_lst.append(msk_)
  87. bi += 1
  88. if bi == args.batch_size:
  89. batch_res = eval_batch_scales(args, result_file, batch_img_lst, scales=args.scales,
  90. base_crop_size=args.crop_size, flip=args.flip)
  91. for mi in range(args.batch_size):
  92. hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
  93. bi = 0
  94. batch_img_lst = []
  95. batch_msk_lst = []
  96. print('processed {} images'.format(i+1))
  97. image_num = i
  98. if bi > 0:
  99. batch_res = eval_batch_scales(args, result_file, batch_img_lst, scales=args.scales,
  100. base_crop_size=args.crop_size, flip=args.flip)
  101. for mi in range(bi):
  102. hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
  103. print('processed {} images'.format(image_num + 1))
  104. print(hist)
  105. iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
  106. print('per-class IoU', iu)
  107. print('mean IoU', np.nanmean(iu))
  108. if __name__ == '__main__':
  109. acc_cal()