You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

dice.py 4.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Dice"""
  16. import numpy as np
  17. from mindspore._checkparam import Validator as validator
  18. from .metric import Metric
  19. class Dice(Metric):
  20. r"""
  21. The Dice coefficient is a set similarity metric. It is used to calculate the similarity between two samples. The
  22. value of the Dice coefficient is 1 when the segmentation result is the best and 0 when the segmentation result
  23. is the worst. The Dice coefficient indicates the ratio of the area between two objects to the total area.
  24. The function is shown as follows:
  25. .. math::
  26. dice = \frac{2 * (pred \bigcap true)}{pred \bigcup true}
  27. Args:
  28. smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0.
  29. Default: 1e-5.
  30. threshold (float): A threshold, which is used to compare with the input tensor. Default: 0.5.
  31. Examples:
  32. >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
  33. >>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))
  34. >>> metric = Dice(smooth=1e-5, threshold=0.5)
  35. >>> metric.clear()
  36. >>> metric.update(x, y)
  37. >>> dice = metric.eval()
  38. 0.22222926
  39. """
  40. def __init__(self, smooth=1e-5, threshold=0.5):
  41. super(Dice, self).__init__()
  42. self.smooth = validator.check_positive_float(smooth, "smooth")
  43. self.threshold = validator.check_value_type("threshold", threshold, [float])
  44. self.clear()
  45. def clear(self):
  46. """Clears the internal evaluation result."""
  47. self._dim = 0
  48. self.intersection = 0
  49. self.unionset = 0
  50. def update(self, *inputs):
  51. """
  52. Updates the internal evaluation result :math:`y_pred` and :math:`y`.
  53. Args:
  54. inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. `y_pred` is the
  55. predicted value, `y` is the true value. The shape of `y_pred` and `y` are both :math:`(N, C)`.
  56. Raises:
  57. ValueError: If the number of the inputs is not 2.
  58. """
  59. if len(inputs) != 2:
  60. raise ValueError('Dice need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
  61. y_pred = self._convert_data(inputs[0])
  62. y = self._convert_data(inputs[1])
  63. if y_pred.shape != y.shape:
  64. raise RuntimeError('y_pred and y should have same the dimension, but the shape of y_pred is{}, '
  65. 'the shape of y is {}.'.format(y_pred.shape, y.shape))
  66. y_pred = (y_pred > self.threshold).astype(int)
  67. self._dim = y.shape
  68. pred_flat = np.reshape(y_pred, (self._dim[0], -1))
  69. true_flat = np.reshape(y, (self._dim[0], -1))
  70. self.intersection = np.sum((pred_flat * true_flat), axis=1)
  71. self.unionset = np.sum(pred_flat, axis=1) + np.sum(true_flat, axis=1)
  72. def eval(self):
  73. r"""
  74. Computes the Dice.
  75. Returns:
  76. Float, the computed result.
  77. Raises:
  78. RuntimeError: If the sample size is 0.
  79. """
  80. if self._dim[0] == 0:
  81. raise RuntimeError('Dice can not be calculated, because the number of samples is 0.')
  82. dice = (2 * self.intersection + self.smooth) / (self.unionset + self.smooth)
  83. return np.sum(dice) / self._dim[0]