You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

postprocess.py 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """unet 310 infer."""
  16. import os
  17. import argparse
  18. import cv2
  19. import numpy as np
  20. from src.data_loader import create_dataset, create_cell_nuclei_dataset
  21. from src.config import cfg_unet
  22. class dice_coeff():
  23. def __init__(self):
  24. self.clear()
  25. def clear(self):
  26. self._dice_coeff_sum = 0
  27. self._iou_sum = 0
  28. self._samples_num = 0
  29. def update(self, *inputs):
  30. if len(inputs) != 2:
  31. raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs)))
  32. y = np.array(inputs[1])
  33. self._samples_num += y.shape[0]
  34. y = y.transpose(0, 2, 3, 1)
  35. b, h, w, c = y.shape
  36. if b != 1:
  37. raise ValueError('Batch size should be 1 when in evaluation.')
  38. y = y.reshape((h, w, c))
  39. if cfg_unet["eval_activate"].lower() == "softmax":
  40. y_softmax = np.squeeze(inputs[0][0], axis=0)
  41. if cfg_unet["eval_resize"]:
  42. y_pred = []
  43. for m in range(cfg_unet["num_classes"]):
  44. y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, m] * 255), (w, h)) / 255)
  45. y_pred = np.stack(y_pred, axis=-1)
  46. else:
  47. y_pred = y_softmax
  48. elif cfg_unet["eval_activate"].lower() == "argmax":
  49. y_argmax = np.squeeze(inputs[0][1], axis=0)
  50. y_pred = []
  51. for n in range(cfg_unet["num_classes"]):
  52. if cfg_unet["eval_resize"]:
  53. y_pred.append(cv2.resize(np.uint8(y_argmax == n), (w, h), interpolation=cv2.INTER_NEAREST))
  54. else:
  55. y_pred.append(np.float32(y_argmax == n))
  56. y_pred = np.stack(y_pred, axis=-1)
  57. else:
  58. raise ValueError('config eval_activate should be softmax or argmax.')
  59. y_pred = y_pred.astype(np.float32)
  60. inter = np.dot(y_pred.flatten(), y.flatten())
  61. union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
  62. single_dice_coeff = 2*float(inter)/float(union+1e-6)
  63. single_iou = single_dice_coeff / (2 - single_dice_coeff)
  64. print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou))
  65. self._dice_coeff_sum += single_dice_coeff
  66. self._iou_sum += single_iou
  67. def eval(self):
  68. if self._samples_num == 0:
  69. raise RuntimeError('Total samples num must not be 0.')
  70. return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))
  71. def test_net(data_dir,
  72. cross_valid_ind=1,
  73. cfg=None):
  74. if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
  75. valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False,
  76. eval_resize=cfg["eval_resize"], split=0.8)
  77. else:
  78. _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'],
  79. img_size=cfg['img_size'])
  80. labels_list = []
  81. for data in valid_dataset:
  82. labels_list.append(data[1].asnumpy())
  83. return labels_list
  84. def get_args():
  85. parser = argparse.ArgumentParser(description='Test the UNet on images and target masks',
  86. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  87. parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
  88. help='data directory')
  89. parser.add_argument('-p', '--rst_path', dest='rst_path', type=str, default='./result_Files/',
  90. help='infer result path')
  91. return parser.parse_args()
  92. if __name__ == '__main__':
  93. args = get_args()
  94. label_list = test_net(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet)
  95. rst_path = args.rst_path
  96. metrics = dice_coeff()
  97. if 'dataset' in cfg_unet and cfg_unet['dataset'] == "Cell_nuclei":
  98. for i, bin_name in enumerate(os.listdir('./preprocess_Result/')):
  99. bin_name_softmax = bin_name.replace(".png", "") + "_0.bin"
  100. bin_name_argmax = bin_name.replace(".png", "") + "_1.bin"
  101. file_name_sof = rst_path + bin_name_softmax
  102. file_name_arg = rst_path + bin_name_argmax
  103. softmax_out = np.fromfile(file_name_sof, np.float32).reshape(1, 96, 96, 2)
  104. argmax_out = np.fromfile(file_name_arg, np.float32).reshape(1, 96, 96)
  105. label = label_list[i]
  106. metrics.update((softmax_out, argmax_out), label)
  107. else:
  108. for j in range(len(os.listdir('./preprocess_Result/'))):
  109. file_name_sof = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin"
  110. file_name_arg = rst_path + "ISBI_test_bs_1_" + str(j) + "_1" + ".bin"
  111. softmax_out = np.fromfile(file_name_sof, np.float32).reshape(1, 576, 576, 2)
  112. argmax_out = np.fromfile(file_name_arg, np.float32).reshape(1, 576, 576)
  113. label = label_list[j]
  114. metrics.update((softmax_out, argmax_out), label)
  115. eval_score = metrics.eval()
  116. print("============== Cross valid dice coeff is:", eval_score[0])
  117. print("============== Cross valid IOU is:", eval_score[1])