You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

postprocess.py 5.5 kB

4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """unet 310 infer."""
  16. import os
  17. import argparse
  18. import cv2
  19. import numpy as np
  20. from src.config import cfg_unet
  21. class dice_coeff():
  22. def __init__(self):
  23. self.clear()
  24. def clear(self):
  25. self._dice_coeff_sum = 0
  26. self._iou_sum = 0
  27. self._samples_num = 0
  28. def update(self, *inputs):
  29. if len(inputs) != 2:
  30. raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs)))
  31. y = np.array(inputs[1])
  32. self._samples_num += y.shape[0]
  33. y = y.transpose(0, 2, 3, 1)
  34. b, h, w, c = y.shape
  35. if b != 1:
  36. raise ValueError('Batch size should be 1 when in evaluation.')
  37. y = y.reshape((h, w, c))
  38. if cfg_unet["eval_activate"].lower() == "softmax":
  39. y_softmax = np.squeeze(inputs[0][0], axis=0)
  40. if cfg_unet["eval_resize"]:
  41. y_pred = []
  42. for m in range(cfg_unet["num_classes"]):
  43. y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, m] * 255), (w, h)) / 255)
  44. y_pred = np.stack(y_pred, axis=-1)
  45. else:
  46. y_pred = y_softmax
  47. elif cfg_unet["eval_activate"].lower() == "argmax":
  48. y_argmax = np.squeeze(inputs[0][1], axis=0)
  49. y_pred = []
  50. for n in range(cfg_unet["num_classes"]):
  51. if cfg_unet["eval_resize"]:
  52. y_pred.append(cv2.resize(np.uint8(y_argmax == n), (w, h), interpolation=cv2.INTER_NEAREST))
  53. else:
  54. y_pred.append(np.float32(y_argmax == n))
  55. y_pred = np.stack(y_pred, axis=-1)
  56. else:
  57. raise ValueError('config eval_activate should be softmax or argmax.')
  58. y_pred = y_pred.astype(np.float32)
  59. inter = np.dot(y_pred.flatten(), y.flatten())
  60. union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
  61. single_dice_coeff = 2*float(inter)/float(union+1e-6)
  62. single_iou = single_dice_coeff / (2 - single_dice_coeff)
  63. print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou))
  64. self._dice_coeff_sum += single_dice_coeff
  65. self._iou_sum += single_iou
  66. def eval(self):
  67. if self._samples_num == 0:
  68. raise RuntimeError('Total samples num must not be 0.')
  69. return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))
  70. def get_args():
  71. parser = argparse.ArgumentParser(description='Test the UNet on images and target masks',
  72. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  73. parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
  74. help='data directory')
  75. parser.add_argument('-p', '--rst_path', dest='rst_path', type=str, default='./result_Files/',
  76. help='infer result path')
  77. return parser.parse_args()
  78. if __name__ == '__main__':
  79. args = get_args()
  80. rst_path = args.rst_path
  81. metrics = dice_coeff()
  82. if 'dataset' in cfg_unet and cfg_unet['dataset'] == "Cell_nuclei":
  83. img_size = tuple(cfg_unet['img_size'])
  84. for i, bin_name in enumerate(os.listdir('./preprocess_Result/')):
  85. f = bin_name.replace(".png", "")
  86. bin_name_softmax = f + "_0.bin"
  87. bin_name_argmax = f + "_1.bin"
  88. file_name_sof = rst_path + bin_name_softmax
  89. file_name_arg = rst_path + bin_name_argmax
  90. softmax_out = np.fromfile(file_name_sof, np.float32).reshape(1, 96, 96, 2)
  91. argmax_out = np.fromfile(file_name_arg, np.float32).reshape(1, 96, 96)
  92. mask = cv2.imread(os.path.join(args.data_url, f, "mask.png"), cv2.IMREAD_GRAYSCALE)
  93. mask = cv2.resize(mask, img_size)
  94. mask = mask.astype(np.float32) / 255
  95. mask = (mask > 0.5).astype(np.int)
  96. mask = (np.arange(2) == mask[..., None]).astype(int)
  97. mask = mask.transpose(2, 0, 1).astype(np.float32)
  98. label = mask.reshape(1, 2, 96, 96)
  99. metrics.update((softmax_out, argmax_out), label)
  100. else:
  101. label_list = np.load('label.npy')
  102. for j in range(len(os.listdir('./preprocess_Result/'))):
  103. file_name_sof = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin"
  104. file_name_arg = rst_path + "ISBI_test_bs_1_" + str(j) + "_1" + ".bin"
  105. softmax_out = np.fromfile(file_name_sof, np.float32).reshape(1, 576, 576, 2)
  106. argmax_out = np.fromfile(file_name_arg, np.float32).reshape(1, 576, 576)
  107. label = label_list[j]
  108. metrics.update((softmax_out, argmax_out), label)
  109. eval_score = metrics.eval()
  110. print("============== Cross valid dice coeff is:", eval_score[0])
  111. print("============== Cross valid IOU is:", eval_score[1])