| @@ -26,6 +26,8 @@ from .precision import Precision | |||||
| from .recall import Recall | from .recall import Recall | ||||
| from .fbeta import Fbeta, F1 | from .fbeta import Fbeta, F1 | ||||
| from .dice import Dice | from .dice import Dice | ||||
| from .roc import ROC | |||||
| from .auc import auc | |||||
| from .topk import TopKCategoricalAccuracy, Top1CategoricalAccuracy, Top5CategoricalAccuracy | from .topk import TopKCategoricalAccuracy, Top1CategoricalAccuracy, Top5CategoricalAccuracy | ||||
| from .loss import Loss | from .loss import Loss | ||||
| @@ -40,6 +42,8 @@ __all__ = [ | |||||
| "Fbeta", | "Fbeta", | ||||
| "F1", | "F1", | ||||
| "Dice", | "Dice", | ||||
| "ROC", | |||||
| "auc", | |||||
| "TopKCategoricalAccuracy", | "TopKCategoricalAccuracy", | ||||
| "Top1CategoricalAccuracy", | "Top1CategoricalAccuracy", | ||||
| "Top5CategoricalAccuracy", | "Top5CategoricalAccuracy", | ||||
| @@ -53,6 +57,8 @@ __factory__ = { | |||||
| 'recall': Recall, | 'recall': Recall, | ||||
| 'F1': F1, | 'F1': F1, | ||||
| 'dice': Dice, | 'dice': Dice, | ||||
| 'roc': ROC, | |||||
| 'auc': auc, | |||||
| 'topk': TopKCategoricalAccuracy, | 'topk': TopKCategoricalAccuracy, | ||||
| 'hausdorff_distance': HausdorffDistance, | 'hausdorff_distance': HausdorffDistance, | ||||
| 'top_1_accuracy': Top1CategoricalAccuracy, | 'top_1_accuracy': Top1CategoricalAccuracy, | ||||
| @@ -0,0 +1,118 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """auc""" | |||||
| import numpy as np | |||||
| def auc(x, y, reorder=False): | |||||
| """ | |||||
| Compute the Area Under the Curve (AUC) using the trapezoidal rule. This is a general function, given points on a | |||||
| curve. For computing the area under the ROC-curve. | |||||
| Args: | |||||
| x (Union[np.array, list]): From the ROC curve(fpr), np.array with false positive rates. If multiclass, | |||||
| this is a list of such np.array, one for each class. The shape :math:`(N)`. | |||||
| y (Union[np.array, list]): From the ROC curve(tpr), np.array with true positive rates. If multiclass, | |||||
| this is a list of such np.array, one for each class. The shape :math:`(N)`. | |||||
| reorder (boolean): If True, assume that the curve is ascending in the case of ties, as for an ROC curve. | |||||
| If the curve is non-ascending, the result will be wrong. Default: False. | |||||
| Returns: | |||||
| area (float): Compute result. | |||||
| Examples: | |||||
| >>> y_pred = np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]]) | |||||
| >>> y = np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]]) | |||||
| >>> metric = ROC(pos_label=2) | |||||
| >>> metric.clear() | |||||
| >>> metric.update(y_pred, y) | |||||
| >>> fpr, tpr, thre = metric.eval() | |||||
| >>> output = auc(fpr, tpr) | |||||
| 0.5357142857142857 | |||||
| """ | |||||
| if not isinstance(x, np.ndarray) or not isinstance(y, np.ndarray): | |||||
| raise TypeError('The inputs must be np.ndarray, but got {}, {}'.format(type(x), type(y))) | |||||
| _check_consistent_length(x, y) | |||||
| x = _column_or_1d(x) | |||||
| y = _column_or_1d(y) | |||||
| if x.shape[0] < 2: | |||||
| raise ValueError('At least 2 points are needed to compute the AUC, but x.shape = {}.'.format(x.shape)) | |||||
| direction = 1 | |||||
| if reorder: | |||||
| order = np.lexsort((y, x)) | |||||
| x, y = x[order], y[order] | |||||
| else: | |||||
| dx = np.diff(x) | |||||
| if np.any(dx < 0): | |||||
| if np.all(dx <= 0): | |||||
| direction = -1 | |||||
| else: | |||||
| raise ValueError("Reordering is not turned on, and the x array is not increasing:{}".format(x)) | |||||
| area = direction * np.trapz(y, x) | |||||
| if isinstance(area, np.memmap): | |||||
| area = area.dtype.type(area) | |||||
| return area | |||||
| def _column_or_1d(y): | |||||
| """ | |||||
| Ravel column or 1d numpy array, otherwise raise an error. | |||||
| """ | |||||
| shape = np.shape(y) | |||||
| if len(shape) == 1: | |||||
| return np.ravel(y) | |||||
| if len(shape) == 2 and shape[1] == 1: | |||||
| return np.ravel(y) | |||||
| raise ValueError("Bad input shape {0}.".format(shape)) | |||||
| def _num_samples(x): | |||||
| """Return the number of samples in array-like x.""" | |||||
| if hasattr(x, 'fit') and callable(x.fit): | |||||
| raise TypeError('Expected sequence or array-like, got estimator {}.'.format(x)) | |||||
| if not hasattr(x, '__len__') and not hasattr(x, 'shape'): | |||||
| if hasattr(x, '__array__'): | |||||
| x = np.asarray(x) | |||||
| else: | |||||
| raise TypeError("Expected sequence or array-like, got {}." .format(type(x))) | |||||
| if hasattr(x, 'shape'): | |||||
| if x.ndim == 0: | |||||
| raise TypeError("Singleton array {} cannot be considered as a valid collection.".format(x)) | |||||
| res = x.shape[0] | |||||
| else: | |||||
| res = x.size | |||||
| return res | |||||
| def _check_consistent_length(*arrays): | |||||
| r""" | |||||
| Check that all arrays have consistent first dimensions. Check whether all objects in arrays have the same shape | |||||
| or length. | |||||
| Args: | |||||
| - **(*arrays)** - (Union[tuple, list]): list or tuple of input objects. Objects that will be checked for | |||||
| consistent length. | |||||
| """ | |||||
| lengths = [_num_samples(array) for array in arrays if array is not None] | |||||
| uniques = np.unique(lengths) | |||||
| if len(uniques) > 1: | |||||
| raise ValueError("Found input variables with inconsistent numbers of samples: {}." | |||||
| .format([int(length) for length in lengths])) | |||||
| @@ -65,6 +65,35 @@ class Metric(metaclass=ABCMeta): | |||||
| return True | return True | ||||
| return False | return False | ||||
| def _binary_clf_curve(self, preds, target, sample_weights=None, pos_label=1): | |||||
| """Calculate True Positives and False Positives per binary classification threshold.""" | |||||
| if sample_weights is not None and not isinstance(sample_weights, np.ndarray): | |||||
| sample_weights = np.array(sample_weights) | |||||
| if preds.ndim > target.ndim: | |||||
| preds = preds[:, 0] | |||||
| desc_score_indices = np.argsort(-preds) | |||||
| preds = preds[desc_score_indices] | |||||
| target = target[desc_score_indices] | |||||
| if sample_weights is not None: | |||||
| weight = sample_weights[desc_score_indices] | |||||
| else: | |||||
| weight = 1. | |||||
| distinct_value_indices = np.where(preds[1:] - preds[:-1])[0] | |||||
| threshold_idxs = np.pad(distinct_value_indices, (0, 1), constant_values=target.shape[0] - 1) | |||||
| target = np.array(target == pos_label).astype(np.int64) | |||||
| tps = np.cumsum(target * weight, axis=0)[threshold_idxs] | |||||
| if sample_weights is not None: | |||||
| fps = np.cumsum((1 - target) * weight, axis=0)[threshold_idxs] | |||||
| else: | |||||
| fps = 1 + threshold_idxs - tps | |||||
| return fps, tps, preds[threshold_idxs] | |||||
| def __call__(self, *inputs): | def __call__(self, *inputs): | ||||
| """ | """ | ||||
| Evaluate input data once. | Evaluate input data once. | ||||
| @@ -0,0 +1,177 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ROC""" | |||||
| import numpy as np | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from .metric import Metric | |||||
| class ROC(Metric): | |||||
| """ | |||||
| Calculate the ROC curve. It is suitable for solving binary classification and multi classification problems. | |||||
| In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach. | |||||
| Args: | |||||
| class_num (int): Integer with the number of classes. For the problem of binary classification, it is not | |||||
| necessary to provide this argument. Default: None. | |||||
| pos_label (int): Determine the integer of positive class. Default: None. For binary problems, it is translated | |||||
| to 1. For multiclass problems, this argument should not be set, as it is iteratively changed | |||||
| in the range [0,num_classes-1]. Default: None. | |||||
| Examples: | |||||
| >>> 1) binary classification example | |||||
| >>> x = Tensor(np.array([3, 1, 4, 2])) | |||||
| >>> y = Tensor(np.array([0, 1, 2, 3])) | |||||
| >>> metric = ROC(pos_label=2) | |||||
| >>> metric.clear() | |||||
| >>> metric.update(x, y) | |||||
| >>> fpr, tpr, thresholds = metric.eval() | |||||
| [0., 0., 0.33333333, 0.6666667, 1.] | |||||
| [0., 1, 1., 1., 1.] | |||||
| [5, 4, 3, 2, 1] | |||||
| >>> | |||||
| >>> 2) multiclass classification example | |||||
| >>> x = Tensor(np.array([[0.28, 0.55, 0.15, 0.05], [0.10, 0.20, 0.05, 0.05], [0.20, 0.05, 0.15, 0.05], | |||||
| ... [0.05, 0.05, 0.05, 0.75]])) | |||||
| >>> y = Tensor(np.array([0, 1, 2, 3])) | |||||
| >>> metric = ROC(class_num=4) | |||||
| >>> metric.clear() | |||||
| >>> metric.update(x, y) | |||||
| >>> fpr, tpr, thresholds = metric.eval() | |||||
| [array([0., 0., 0.33333333, 0.66666667, 1.]), array([0., 0.33333333, 0.33333333, 1.]), | |||||
| array([0., 0.33333333, 1.]), array([0., 0., 1.])] | |||||
| [array([0., 1., 1., 1., 1.]), array([0., 0., 1., 1.]), array([0., 1., 1.]), array([0., 1., 1.])] | |||||
| [array([1.28, 0.28, 0.2, 0.1, 0.05]), array([1.55, 0.55, 0.2, 0.05]), array([1.15, 0.15, 0.05]), | |||||
| array([1.75, 0.75, 0.05])] | |||||
| """ | |||||
| def __init__(self, class_num=None, pos_label=None): | |||||
| super().__init__() | |||||
| self.class_num = class_num if class_num is None else validator.check_value_type("class_num", class_num, [int]) | |||||
| self.pos_label = pos_label if pos_label is None else validator.check_value_type("pos_label", pos_label, [int]) | |||||
| self.clear() | |||||
| def clear(self): | |||||
| """Clear the internal evaluation result.""" | |||||
| self.y_pred = 0 | |||||
| self.y = 0 | |||||
| self.sample_weights = None | |||||
| self._is_update = False | |||||
| def _precision_recall_curve_update(self, y_pred, y, class_num, pos_label): | |||||
| """update curve""" | |||||
| if not (len(y_pred.shape) == len(y.shape) or len(y_pred.shape) == len(y.shape) + 1): | |||||
| raise ValueError("y_pred and y must have the same number of dimensions, or one additional dimension for" | |||||
| " y_pred.") | |||||
| # single class evaluation | |||||
| if len(y_pred.shape) == len(y.shape): | |||||
| if class_num is not None and class_num != 1: | |||||
| raise ValueError('y_pred and y should have the same shape, but number of classes is different from 1.') | |||||
| class_num = 1 | |||||
| if pos_label is None: | |||||
| pos_label = 1 | |||||
| y_pred = y_pred.flatten() | |||||
| y = y.flatten() | |||||
| # multi class evaluation | |||||
| elif len(y_pred.shape) == len(y.shape) + 1: | |||||
| if pos_label is not None: | |||||
| raise ValueError('Argument `pos_label` should be `None` when running multiclass precision recall ' | |||||
| 'curve, but got {}.'.format(pos_label)) | |||||
| if class_num != y_pred.shape[1]: | |||||
| raise ValueError('Argument `class_num` was set to {}, but detected {} number of classes from ' | |||||
| 'predictions.'.format(class_num, y_pred.shape[1])) | |||||
| y_pred = y_pred.transpose(0, 1).reshape(class_num, -1).transpose(0, 1) | |||||
| y = y.flatten() | |||||
| return y_pred, y, class_num, pos_label | |||||
| def update(self, *inputs): | |||||
| """ | |||||
| Update state with predictions and targets. | |||||
| Args: | |||||
| inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. | |||||
| In most cases (not strictly), y_pred is a list of floating numbers in range :math:`[0, 1]` | |||||
| and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C` | |||||
| is the number of categories. y contains values of integers. | |||||
| """ | |||||
| if len(inputs) != 2: | |||||
| raise ValueError('ROC need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) | |||||
| y_pred = self._convert_data(inputs[0]) | |||||
| y = self._convert_data(inputs[1]) | |||||
| y_pred, y, class_num, pos_label = self._precision_recall_curve_update(y_pred, y, self.class_num, self.pos_label) | |||||
| self.y_pred = y_pred | |||||
| self.y = y | |||||
| self.class_num = class_num | |||||
| self.pos_label = pos_label | |||||
| self._is_update = True | |||||
| def _roc_eval(self, y_pred, y, class_num, pos_label, sample_weights=None): | |||||
| """Computes the ROC curve.""" | |||||
| if class_num == 1: | |||||
| fps, tps, thresholds = self._binary_clf_curve(y_pred, y, sample_weights=sample_weights, | |||||
| pos_label=pos_label) | |||||
| tps = np.squeeze(np.hstack([np.zeros(1, dtype=tps.dtype), tps])) | |||||
| fps = np.squeeze(np.hstack([np.zeros(1, dtype=fps.dtype), fps])) | |||||
| thresholds = np.hstack([thresholds[0][None] + 1, thresholds]) | |||||
| if fps[-1] <= 0: | |||||
| raise ValueError("No negative samples in y, false positive value should be meaningless.") | |||||
| fpr = fps / fps[-1] | |||||
| if tps[-1] <= 0: | |||||
| raise ValueError("No positive samples in y, true positive value should be meaningless.") | |||||
| tpr = tps / tps[-1] | |||||
| return fpr, tpr, thresholds | |||||
| fpr, tpr, thresholds = [], [], [] | |||||
| for c in range(class_num): | |||||
| preds_c = y_pred[:, c] | |||||
| res = self.roc(preds_c, y, class_num=1, pos_label=c, sample_weights=sample_weights) | |||||
| fpr.append(res[0]) | |||||
| tpr.append(res[1]) | |||||
| thresholds.append(res[2]) | |||||
| return fpr, tpr, thresholds | |||||
| def roc(self, y_pred, y, class_num=None, pos_label=None, sample_weights=None): | |||||
| """roc""" | |||||
| y_pred, y, class_num, pos_label = self._precision_recall_curve_update(y_pred, y, class_num, pos_label) | |||||
| return self._roc_eval(y_pred, y, class_num, pos_label, sample_weights) | |||||
| def eval(self): | |||||
| """ | |||||
| Computes the ROC curve. | |||||
| Returns: | |||||
| A tuple, composed of `fpr`, `tpr`, and `thresholds`. | |||||
| - **fpr** (np.array) - np.array with false positive rates. If multiclass, this is a list of such np.array, | |||||
| one for each class. | |||||
| - **tps** (np.array) - np.array with true positive rates. If multiclass, this is a list of such np.array, | |||||
| one for each class. | |||||
| - **thresholds** (np.array) - thresholds used for computing false- and true postive rates. | |||||
| """ | |||||
| if self._is_update is False: | |||||
| raise RuntimeError('Call the update method before calling eval.') | |||||
| y_pred = np.squeeze(np.vstack(self.y_pred)) | |||||
| y = np.squeeze(np.vstack(self.y)) | |||||
| return self._roc_eval(y_pred, y, self.class_num, self.pos_label) | |||||
| @@ -0,0 +1,34 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| # """test_auc""" | |||||
| import math | |||||
| import numpy as np | |||||
| from mindspore import Tensor | |||||
| from mindspore.nn.metrics import ROC, auc | |||||
| def test_auc(): | |||||
| """test_auc""" | |||||
| x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]])) | |||||
| y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]])) | |||||
| metric = ROC(pos_label=1) | |||||
| metric.clear() | |||||
| metric.update(x, y) | |||||
| fpr, tpr, thre = metric.eval() | |||||
| output = auc(fpr, tpr) | |||||
| assert math.isclose(output, 0.45, abs_tol=0.001) | |||||
| assert np.equal(thre, np.array([4, 3, 2, 1, 0])).all() | |||||
| @@ -0,0 +1,92 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| # """test_roc""" | |||||
| import numpy as np | |||||
| import pytest | |||||
| from mindspore import Tensor | |||||
| from mindspore.nn.metrics import ROC | |||||
| def test_roc(): | |||||
| """test_roc_binary""" | |||||
| x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]])) | |||||
| y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]])) | |||||
| metric = ROC(pos_label=1) | |||||
| metric.clear() | |||||
| metric.update(x, y) | |||||
| fpr, tpr, thresholds = metric.eval() | |||||
| assert np.equal(fpr, np.array([0, 0.4, 0.4, 0.6, 1])).all() | |||||
| assert np.equal(tpr, np.array([0, 0, 0.25, 0.75, 1])).all() | |||||
| assert np.equal(thresholds, np.array([4, 3, 2, 1, 0])).all() | |||||
| def test_roc2(): | |||||
| """test_roc_multiclass""" | |||||
| x = Tensor(np.array([[0.28, 0.55, 0.15, 0.05], [0.10, 0.20, 0.05, 0.05], [0.20, 0.05, 0.15, 0.05], | |||||
| [0.05, 0.05, 0.05, 0.75]])) | |||||
| y = Tensor(np.array([0, 1, 2, 3])) | |||||
| metric = ROC(class_num=4) | |||||
| metric.clear() | |||||
| metric.update(x, y) | |||||
| fpr, tpr, thresholds = metric.eval() | |||||
| list1 = [np.array([0., 0., 0.33333333, 0.66666667, 1.]), np.array([0., 0.33333333, 0.33333333, 1.]), | |||||
| np.array([0., 0.33333333, 1.]), np.array([0., 0., 1.])] | |||||
| list2 = [np.array([0., 1., 1., 1., 1.]), np.array([0., 0., 1., 1.]), | |||||
| np.array([0., 1., 1.]), np.array([0., 1., 1.])] | |||||
| list3 = [np.array([1.28, 0.28, 0.2, 0.1, 0.05]), np.array([1.55, 0.55, 0.2, 0.05]), | |||||
| np.array([1.15, 0.15, 0.05]), np.array([1.75, 0.75, 0.05])] | |||||
| assert fpr[0].shape == list1[0].shape | |||||
| assert np.equal(tpr[1], list2[1]).all() | |||||
| assert np.equal(thresholds[2], list3[2]).all() | |||||
| def test_roc_update1(): | |||||
| x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]])) | |||||
| metric = ROC() | |||||
| metric.clear() | |||||
| with pytest.raises(ValueError): | |||||
| metric.update(x) | |||||
| def test_roc_update2(): | |||||
| x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]])) | |||||
| y = Tensor(np.array([1, 0])) | |||||
| metric = ROC() | |||||
| metric.clear() | |||||
| with pytest.raises(ValueError): | |||||
| metric.update(x, y) | |||||
| def test_roc_init1(): | |||||
| with pytest.raises(TypeError): | |||||
| ROC(pos_label=1.2) | |||||
| def test_roc_init2(): | |||||
| with pytest.raises(TypeError): | |||||
| ROC(class_num="class_num") | |||||
| def test_roc_runtime(): | |||||
| metric = ROC() | |||||
| metric.clear() | |||||
| with pytest.raises(RuntimeError): | |||||
| metric.eval() | |||||