You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

postprocess.py 3.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """unet 310 infer."""
  16. import os
  17. import argparse
  18. import numpy as np
  19. from src.data_loader import create_dataset
  20. from src.config import cfg_unet
  21. from scipy.special import softmax
  22. class dice_coeff():
  23. def __init__(self):
  24. self.clear()
  25. def clear(self):
  26. self._dice_coeff_sum = 0
  27. self._samples_num = 0
  28. def update(self, *inputs):
  29. if len(inputs) != 2:
  30. raise ValueError('Mean dice coefficient need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
  31. y_pred = inputs[0]
  32. y = np.array(inputs[1])
  33. self._samples_num += y.shape[0]
  34. y_pred = y_pred.transpose(0, 2, 3, 1)
  35. y = y.transpose(0, 2, 3, 1)
  36. y_pred = softmax(y_pred, axis=3)
  37. inter = np.dot(y_pred.flatten(), y.flatten())
  38. union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
  39. single_dice_coeff = 2*float(inter)/float(union+1e-6)
  40. print("single dice coeff is:", single_dice_coeff)
  41. self._dice_coeff_sum += single_dice_coeff
  42. def eval(self):
  43. if self._samples_num == 0:
  44. raise RuntimeError('Total samples num must not be 0.')
  45. return self._dice_coeff_sum / float(self._samples_num)
  46. def test_net(data_dir,
  47. cross_valid_ind=1,
  48. cfg=None):
  49. _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'],
  50. img_size=cfg['img_size'])
  51. labels_list = []
  52. for data in valid_dataset:
  53. labels_list.append(data[1].asnumpy())
  54. return labels_list
  55. def get_args():
  56. parser = argparse.ArgumentParser(description='Test the UNet on images and target masks',
  57. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  58. parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
  59. help='data directory')
  60. parser.add_argument('-p', '--rst_path', dest='rst_path', type=str, default='./result_Files/',
  61. help='infer result path')
  62. return parser.parse_args()
  63. if __name__ == '__main__':
  64. args = get_args()
  65. label_list = test_net(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet)
  66. rst_path = args.rst_path
  67. metrics = dice_coeff()
  68. for j in range(len(os.listdir(rst_path))):
  69. file_name = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin"
  70. output = np.fromfile(file_name, np.float32).reshape(1, 2, 576, 576)
  71. label = label_list[j]
  72. metrics.update(output, label)
  73. print("Cross valid dice coeff is: ", metrics.eval())