You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

recall.py 6.3 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Recall."""
  16. import sys
  17. import numpy as np
  18. from mindspore._checkparam import Validator as validator
  19. from ._evaluation import EvaluationBase
  class Recall(EvaluationBase):
      r"""
      Calculates recall for classification and multilabel data.

      The metric accumulates, across calls to `update`, the number of true positives and the number
      of actual positives. Actual positives are the positive labels in `y`, i.e.
      :math:`\text{true_positive} + \text{false_negative}`. `eval` divides the former by the latter:

      .. math::

          \text{recall} = \frac{\text{true_positive}}{\text{true_positive} + \text{false_negative}}

      For 'classification' the recall is reported per class. For 'multilabel' it is reported per
      sample.

      Note:
          In the multi-label cases, the elements of :math:`y` and :math:`y_{pred}` must be 0 or 1.

      Args:
          eval_type (str): Metric to calculate the recall over a dataset, for classification or
                           multilabel. Default: 'classification'.

      Examples:
          >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
          >>> y = Tensor(np.array([1, 0, 1]))
          >>> metric = nn.Recall('classification')
          >>> metric.clear()
          >>> metric.update(x, y)
          >>> recall = metric.eval()
          >>> print(recall)
          [1. 0.5]
      """

      def __init__(self, eval_type='classification'):
          super(Recall, self).__init__(eval_type)
          # Smallest positive float: keeps every denominator non-zero without measurably changing
          # it, so a class/sample with no positive labels gets recall 0 instead of a 0/0 division.
          self.eps = sys.float_info.min
          self.clear()

      def clear(self):
          """Clears the internal evaluation result."""
          # 0 means "not determined yet"; the first `update` fixes the number of classes.
          self._class_num = 0
          if self._type == "multilabel":
              # Per-sample counts, concatenated across updates.
              self._true_positives = np.empty(0)
              self._actual_positives = np.empty(0)
              # Despite the names: running sum of per-sample recall values and running count of
              # samples, used by `eval(average=True)`.
              self._true_positives_average = 0
              self._actual_positives_average = 0
          else:
              # Per-class counts; they become arrays on the first `update`.
              self._true_positives = 0
              self._actual_positives = 0

      def update(self, *inputs):
          """
          Updates the internal evaluation result with `y_pred` and `y`.

          Args:
              inputs: Input `y_pred` and `y`. `y_pred` and `y` are a `Tensor`, a list or an array.
                  For 'classification' evaluation type, `y_pred` is in most cases (not strictly) a list
                  of floating numbers in range :math:`[0, 1]`
                  and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C`
                  is the number of categories. Shape of `y` can be :math:`(N, C)` with values 0 and 1 if one-hot
                  encoding is used or the shape is :math:`(N,)` with integer values if index of category is used.
                  For 'multilabel' evaluation type, `y_pred` and `y` can only be one-hot encoding with
                  values 0 or 1. Indices with 1 indicate positive category. The shape of `y_pred` and `y`
                  are both :math:`(N, C)`.

          Raises:
              ValueError: If the number of input is not 2.
              ValueError: If the number of classes in `y_pred` differs from the number seen in an
                  earlier update since the last `clear`.
              ValueError: For 'classification', if `y` contains a class index that is negative or
                  not smaller than the number of classes in `y_pred`.
          """
          if len(inputs) != 2:
              raise ValueError('Recall need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
          y_pred = self._convert_data(inputs[0])
          y = self._convert_data(inputs[1])
          # In classification mode a one-hot `y` is turned back into class indices.
          if self._type == 'classification' and y_pred.ndim == y.ndim and self._check_onehot_data(y):
              y = y.argmax(axis=1)
          self._check_shape(y_pred, y)
          self._check_value(y_pred, y)

          # The first batch fixes the class count; later batches must agree with it.
          if self._class_num == 0:
              self._class_num = y_pred.shape[1]
          elif y_pred.shape[1] != self._class_num:
              raise ValueError('Class number not match, last input data contain {} classes, but current data contain {} '
                               'classes'.format(self._class_num, y_pred.shape[1]))

          class_num = self._class_num
          if self._type == "classification":
              if y.max() + 1 > class_num:
                  raise ValueError('y_pred contains {} classes less than y contains {} classes.'.
                                   format(class_num, y.max() + 1))
              # Negative indices would silently wrap around in `np.eye(...)[...]` below and be
              # counted as the last classes, corrupting the per-class counts.
              if y.min() < 0:
                  raise ValueError('y contains a negative class index {}, but class indices must be in '
                                   'range [0, {}).'.format(y.min(), class_num))
              # One-hot encode both the labels and the predicted class (argmax over y_pred).
              y = np.eye(class_num)[y.reshape(-1)]
              indices = y_pred.argmax(axis=1).reshape(-1)
              y_pred = np.eye(class_num)[indices]
          elif self._type == "multilabel":
              # Move the class axis to the front so that the sums over axis 0 below are taken
              # over classes, i.e. produce per-sample counts.
              y_pred = y_pred.swapaxes(1, 0).reshape(class_num, -1)
              y = y.swapaxes(1, 0).reshape(class_num, -1)
          # classification: per-class sums over samples; multilabel: per-sample sums over classes.
          actual_positives = y.sum(axis=0)
          true_positives = (y * y_pred).sum(axis=0)
          if self._type == "multilabel":
              # Running sum of per-sample recall and running sample count for eval(average=True).
              self._true_positives_average += np.sum(true_positives / (actual_positives + self.eps))
              self._actual_positives_average += len(actual_positives)
              self._true_positives = np.concatenate((self._true_positives, true_positives), axis=0)
              self._actual_positives = np.concatenate((self._actual_positives, actual_positives), axis=0)
          else:
              self._true_positives += true_positives
              self._actual_positives += actual_positives

      def eval(self, average=False):
          """
          Computes the recall.

          Args:
              average (bool): Specify whether calculate the average recall. Default value is False.
                  It is validated to be a bool via `validator.check_value_type`.

          Returns:
              numpy.ndarray or float. When `average` is False, an array of recall values: one per
              class for 'classification', one per sample for 'multilabel'. When `average` is True,
              the mean of those values.

          Raises:
              RuntimeError: If `update` has not been called since the last `clear`.
          """
          if self._class_num == 0:
              raise RuntimeError('Input number of samples can not be 0.')
          validator.check_value_type("average", average, [bool], self.__class__.__name__)
          result = self._true_positives / (self._actual_positives + self.eps)
          if average:
              if self._type == "multilabel":
                  # Sum of per-sample recall divided by the number of samples.
                  result = self._true_positives_average / (self._actual_positives_average + self.eps)
              return result.mean()
          return result