You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

error.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Error."""
  16. import numpy as np
  17. from .metric import Metric
  18. class MAE(Metric):
  19. r"""
  20. Calculates the mean absolute error.
  21. Creates a criterion that measures the mean absolute error (MAE)
  22. between each element in the input: :math:`x` and the target: :math:`y`.
  23. .. math::
  24. \text{MAE} = \frac{\sum_{i=1}^n \|y_i - x_i\|}{n}
  25. Here :math:`y_i` is the prediction and :math:`x_i` is the true value.
  26. Note:
  27. The method `update` must be called with the form `update(y_pred, y)`.
  28. Examples:
  29. >>> x = Tensor(np.array([0.1, 0.2, 0.6, 0.9]), mindspore.float32)
  30. >>> y = Tensor(np.array([0.1, 0.25, 0.7, 0.9]), mindspore.float32)
  31. >>> error = nn.MAE()
  32. >>> error.clear()
  33. >>> error.update(x, y)
  34. >>> result = error.eval()
  35. """
  36. def __init__(self):
  37. super(MAE, self).__init__()
  38. self.clear()
  39. def clear(self):
  40. """Clears the internal evaluation result."""
  41. self._abs_error_sum = 0
  42. self._samples_num = 0
  43. def update(self, *inputs):
  44. """
  45. Updates the internal evaluation result :math:`y_{pred}` and :math:`y`.
  46. Args:
  47. inputs: Input `y_pred` and `y` for calculating mean absolute error where the shape of
  48. `y_pred` and `y` are both N-D and the shape are the same.
  49. Raises:
  50. ValueError: If the number of the input is not 2.
  51. """
  52. if len(inputs) != 2:
  53. raise ValueError('Mean absolute error need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
  54. y_pred = self._convert_data(inputs[0])
  55. y = self._convert_data(inputs[1])
  56. abs_error_sum = np.abs(y.reshape(y_pred.shape) - y_pred)
  57. self._abs_error_sum += abs_error_sum.sum()
  58. self._samples_num += y.shape[0]
  59. def eval(self):
  60. """
  61. Computes the mean absolute error.
  62. Returns:
  63. Float, the computed result.
  64. Raises:
  65. RuntimeError: If the number of the total samples is 0.
  66. """
  67. if self._samples_num == 0:
  68. raise RuntimeError('Total samples num must not be 0.')
  69. return self._abs_error_sum / self._samples_num
  70. class MSE(Metric):
  71. r"""
  72. Measures the mean squared error.
  73. Creates a criterion that measures the mean squared error (squared L2
  74. norm) between each element in the input: :math:`x` and the target: :math:`y`.
  75. .. math::
  76. \text{MSE}(x,\ y) = \frac{\sum_{i=1}^n(y_i - x_i)^2}{n},
  77. where :math:`n` is batch size.
  78. Examples:
  79. >>> x = Tensor(np.array([0.1, 0.2, 0.6, 0.9]), mindspore.float32)
  80. >>> y = Tensor(np.array([0.1, 0.25, 0.5, 0.9]), mindspore.float32)
  81. >>> error = nn.MSE()
  82. >>> error.clear()
  83. >>> error.update(x, y)
  84. >>> result = error.eval()
  85. """
  86. def __init__(self):
  87. super(MSE, self).__init__()
  88. self.clear()
  89. def clear(self):
  90. """Clear the internal evaluation result."""
  91. self._squared_error_sum = 0
  92. self._samples_num = 0
  93. def update(self, *inputs):
  94. """
  95. Updates the internal evaluation result :math:`y_{pred}` and :math:`y`.
  96. Args:
  97. inputs: Input `y_pred` and `y` for calculating mean square error where the shape of
  98. `y_pred` and `y` are both N-D and the shape are the same.
  99. Raises:
  100. ValueError: If the number of input is not 2.
  101. """
  102. if len(inputs) != 2:
  103. raise ValueError('Mean squared error need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
  104. y_pred = self._convert_data(inputs[0])
  105. y = self._convert_data(inputs[1])
  106. squared_error_sum = np.power(y.reshape(y_pred.shape) - y_pred, 2)
  107. self._squared_error_sum += squared_error_sum.sum()
  108. self._samples_num += y.shape[0]
  109. def eval(self):
  110. """
  111. Compute the mean squared error.
  112. Returns:
  113. Float, the computed result.
  114. Raises:
  115. RuntimeError: If the number of samples is 0.
  116. """
  117. if self._samples_num == 0:
  118. raise RuntimeError('The number of input samples must not be 0.')
  119. return self._squared_error_sum / self._samples_num