From ef3f6d62db0b6caa304baf79c7dd8c870f99d7a4 Mon Sep 17 00:00:00 2001 From: Jiaqi Date: Wed, 23 Dec 2020 17:49:39 +0800 Subject: [PATCH] auc_roc --- mindspore/nn/metrics/__init__.py | 6 + mindspore/nn/metrics/auc.py | 118 +++++++++++++++++++ mindspore/nn/metrics/metric.py | 29 +++++ mindspore/nn/metrics/roc.py | 177 ++++++++++++++++++++++++++++ tests/ut/python/metrics/test_auc.py | 34 ++++++ tests/ut/python/metrics/test_roc.py | 92 +++++++++++++++ 6 files changed, 456 insertions(+) create mode 100644 mindspore/nn/metrics/auc.py create mode 100644 mindspore/nn/metrics/roc.py create mode 100644 tests/ut/python/metrics/test_auc.py create mode 100644 tests/ut/python/metrics/test_roc.py diff --git a/mindspore/nn/metrics/__init__.py b/mindspore/nn/metrics/__init__.py index 411e8c497e..dec311ae33 100755 --- a/mindspore/nn/metrics/__init__.py +++ b/mindspore/nn/metrics/__init__.py @@ -26,6 +26,8 @@ from .precision import Precision from .recall import Recall from .fbeta import Fbeta, F1 from .dice import Dice +from .roc import ROC +from .auc import auc from .topk import TopKCategoricalAccuracy, Top1CategoricalAccuracy, Top5CategoricalAccuracy from .loss import Loss @@ -40,6 +42,8 @@ __all__ = [ "Fbeta", "F1", "Dice", + "ROC", + "auc", "TopKCategoricalAccuracy", "Top1CategoricalAccuracy", "Top5CategoricalAccuracy", @@ -53,6 +57,8 @@ __factory__ = { 'recall': Recall, 'F1': F1, 'dice': Dice, + 'roc': ROC, + 'auc': auc, 'topk': TopKCategoricalAccuracy, 'hausdorff_distance': HausdorffDistance, 'top_1_accuracy': Top1CategoricalAccuracy, diff --git a/mindspore/nn/metrics/auc.py b/mindspore/nn/metrics/auc.py new file mode 100644 index 0000000000..cd0155a271 --- /dev/null +++ b/mindspore/nn/metrics/auc.py @@ -0,0 +1,118 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""auc""" +import numpy as np + + +def auc(x, y, reorder=False): + """ + Compute the Area Under the Curve (AUC) using the trapezoidal rule. This is a general function, given points on a + curve. For computing the area under the ROC-curve. + + Args: + x (Union[np.array, list]): From the ROC curve(fpr), np.array with false positive rates. If multiclass, + this is a list of such np.array, one for each class. The shape :math:`(N)`. + y (Union[np.array, list]): From the ROC curve(tpr), np.array with true positive rates. If multiclass, + this is a list of such np.array, one for each class. The shape :math:`(N)`. + reorder (boolean): If True, assume that the curve is ascending in the case of ties, as for an ROC curve. + If the curve is non-ascending, the result will be wrong. Default: False. + + Returns: + area (float): Compute result. + + Examples: + >>> y_pred = np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]]) + >>> y = np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]]) + >>> metric = ROC(pos_label=2) + >>> metric.clear() + >>> metric.update(y_pred, y) + >>> fpr, tpr, thre = metric.eval() + >>> output = auc(fpr, tpr) + 0.5357142857142857 + """ + if not isinstance(x, np.ndarray) or not isinstance(y, np.ndarray): + raise TypeError('The inputs must be np.ndarray, but got {}, {}'.format(type(x), type(y))) + _check_consistent_length(x, y) + x = _column_or_1d(x) + y = _column_or_1d(y) + + if x.shape[0] < 2: + raise ValueError('At least 2 points are needed to compute the AUC, but x.shape = {}.'.format(x.shape)) + + direction = 1 + if reorder: + order = np.lexsort((y, x)) + x, y = x[order], y[order] + else: + dx = np.diff(x) + if np.any(dx < 0): + if np.all(dx <= 0): + direction = -1 + else: + raise ValueError("Reordering is not turned on, and the x array is not increasing:{}".format(x)) + + area = direction * np.trapz(y, x) + if isinstance(area, np.memmap): + area = area.dtype.type(area) + return area + + +def _column_or_1d(y): + """ + Ravel column or 1d numpy array, otherwise raise an error. + """ + shape = np.shape(y) + if len(shape) == 1: + return np.ravel(y) + if len(shape) == 2 and shape[1] == 1: + return np.ravel(y) + + raise ValueError("Bad input shape {0}.".format(shape)) + + +def _num_samples(x): + """Return the number of samples in array-like x.""" + if hasattr(x, 'fit') and callable(x.fit): + raise TypeError('Expected sequence or array-like, got estimator {}.'.format(x)) + if not hasattr(x, '__len__') and not hasattr(x, 'shape'): + if hasattr(x, '__array__'): + x = np.asarray(x) + else: + raise TypeError("Expected sequence or array-like, got {}." .format(type(x))) + if hasattr(x, 'shape'): + if x.ndim == 0: + raise TypeError("Singleton array {} cannot be considered as a valid collection.".format(x)) + res = x.shape[0] + else: + res = x.size + + return res + + +def _check_consistent_length(*arrays): + r""" + Check that all arrays have consistent first dimensions. Check whether all objects in arrays have the same shape + or length. + + Args: + - **(*arrays)** - (Union[tuple, list]): list or tuple of input objects. Objects that will be checked for + consistent length. + """ + + lengths = [_num_samples(array) for array in arrays if array is not None] + uniques = np.unique(lengths) + if len(uniques) > 1: + raise ValueError("Found input variables with inconsistent numbers of samples: {}." + .format([int(length) for length in lengths])) diff --git a/mindspore/nn/metrics/metric.py b/mindspore/nn/metrics/metric.py index 19c06d2759..13e1775e53 100644 --- a/mindspore/nn/metrics/metric.py +++ b/mindspore/nn/metrics/metric.py @@ -65,6 +65,35 @@ class Metric(metaclass=ABCMeta): return True return False + def _binary_clf_curve(self, preds, target, sample_weights=None, pos_label=1): + """Calculate True Positives and False Positives per binary classification threshold.""" + if sample_weights is not None and not isinstance(sample_weights, np.ndarray): + sample_weights = np.array(sample_weights) + + if preds.ndim > target.ndim: + preds = preds[:, 0] + desc_score_indices = np.argsort(-preds) + + preds = preds[desc_score_indices] + target = target[desc_score_indices] + + if sample_weights is not None: + weight = sample_weights[desc_score_indices] + else: + weight = 1. + + distinct_value_indices = np.where(preds[1:] - preds[:-1])[0] + threshold_idxs = np.pad(distinct_value_indices, (0, 1), constant_values=target.shape[0] - 1) + target = np.array(target == pos_label).astype(np.int64) + tps = np.cumsum(target * weight, axis=0)[threshold_idxs] + + if sample_weights is not None: + fps = np.cumsum((1 - target) * weight, axis=0)[threshold_idxs] + else: + fps = 1 + threshold_idxs - tps + + return fps, tps, preds[threshold_idxs] + def __call__(self, *inputs): """ Evaluate input data once. diff --git a/mindspore/nn/metrics/roc.py b/mindspore/nn/metrics/roc.py new file mode 100644 index 0000000000..89860401f2 --- /dev/null +++ b/mindspore/nn/metrics/roc.py @@ -0,0 +1,177 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""ROC""" +import numpy as np +from mindspore._checkparam import Validator as validator +from .metric import Metric + + +class ROC(Metric): + """ + Calculate the ROC curve. It is suitable for solving binary classification and multi classification problems. + In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach. + + Args: + class_num (int): Integer with the number of classes. For the problem of binary classification, it is not + necessary to provide this argument. Default: None. + pos_label (int): Determine the integer of positive class. Default: None. For binary problems, it is translated + to 1. For multiclass problems, this argument should not be set, as it is iteratively changed + in the range [0,num_classes-1]. Default: None. + + Examples: + >>> 1) binary classification example + >>> x = Tensor(np.array([3, 1, 4, 2])) + >>> y = Tensor(np.array([0, 1, 2, 3])) + >>> metric = ROC(pos_label=2) + >>> metric.clear() + >>> metric.update(x, y) + >>> fpr, tpr, thresholds = metric.eval() + [0., 0., 0.33333333, 0.6666667, 1.] + [0., 1, 1., 1., 1.] + [5, 4, 3, 2, 1] + >>> + >>> 2) multiclass classification example + >>> x = Tensor(np.array([[0.28, 0.55, 0.15, 0.05], [0.10, 0.20, 0.05, 0.05], [0.20, 0.05, 0.15, 0.05], + ... [0.05, 0.05, 0.05, 0.75]])) + >>> y = Tensor(np.array([0, 1, 2, 3])) + >>> metric = ROC(class_num=4) + >>> metric.clear() + >>> metric.update(x, y) + >>> fpr, tpr, thresholds = metric.eval() + [array([0., 0., 0.33333333, 0.66666667, 1.]), array([0., 0.33333333, 0.33333333, 1.]), + array([0., 0.33333333, 1.]), array([0., 0., 1.])] + [array([0., 1., 1., 1., 1.]), array([0., 0., 1., 1.]), array([0., 1., 1.]), array([0., 1., 1.])] + [array([1.28, 0.28, 0.2, 0.1, 0.05]), array([1.55, 0.55, 0.2, 0.05]), array([1.15, 0.15, 0.05]), + array([1.75, 0.75, 0.05])] + """ + def __init__(self, class_num=None, pos_label=None): + super().__init__() + self.class_num = class_num if class_num is None else validator.check_value_type("class_num", class_num, [int]) + self.pos_label = pos_label if pos_label is None else validator.check_value_type("pos_label", pos_label, [int]) + self.clear() + + def clear(self): + """Clear the internal evaluation result.""" + self.y_pred = 0 + self.y = 0 + self.sample_weights = None + self._is_update = False + + def _precision_recall_curve_update(self, y_pred, y, class_num, pos_label): + """update curve""" + if not (len(y_pred.shape) == len(y.shape) or len(y_pred.shape) == len(y.shape) + 1): + raise ValueError("y_pred and y must have the same number of dimensions, or one additional dimension for" + " y_pred.") + + # single class evaluation + if len(y_pred.shape) == len(y.shape): + if class_num is not None and class_num != 1: + raise ValueError('y_pred and y should have the same shape, but number of classes is different from 1.') + class_num = 1 + if pos_label is None: + pos_label = 1 + y_pred = y_pred.flatten() + y = y.flatten() + + # multi class evaluation + elif len(y_pred.shape) == len(y.shape) + 1: + if pos_label is not None: + raise ValueError('Argument `pos_label` should be `None` when running multiclass precision recall ' + 'curve, but got {}.'.format(pos_label)) + if class_num != y_pred.shape[1]: + raise ValueError('Argument `class_num` was set to {}, but detected {} number of classes from ' + 'predictions.'.format(class_num, y_pred.shape[1])) + y_pred = y_pred.transpose(0, 1).reshape(class_num, -1).transpose(0, 1) + y = y.flatten() + + return y_pred, y, class_num, pos_label + + def update(self, *inputs): + """ + Update state with predictions and targets. + Args: + inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. + In most cases (not strictly), y_pred is a list of floating numbers in range :math:`[0, 1]` + and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C` + is the number of categories. y contains values of integers. + """ + if len(inputs) != 2: + raise ValueError('ROC need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) + y_pred = self._convert_data(inputs[0]) + y = self._convert_data(inputs[1]) + + y_pred, y, class_num, pos_label = self._precision_recall_curve_update(y_pred, y, self.class_num, self.pos_label) + + self.y_pred = y_pred + self.y = y + self.class_num = class_num + self.pos_label = pos_label + self._is_update = True + + def _roc_eval(self, y_pred, y, class_num, pos_label, sample_weights=None): + """Computes the ROC curve.""" + if class_num == 1: + fps, tps, thresholds = self._binary_clf_curve(y_pred, y, sample_weights=sample_weights, + pos_label=pos_label) + tps = np.squeeze(np.hstack([np.zeros(1, dtype=tps.dtype), tps])) + fps = np.squeeze(np.hstack([np.zeros(1, dtype=fps.dtype), fps])) + thresholds = np.hstack([thresholds[0][None] + 1, thresholds]) + + if fps[-1] <= 0: + raise ValueError("No negative samples in y, false positive value should be meaningless.") + fpr = fps / fps[-1] + + if tps[-1] <= 0: + raise ValueError("No positive samples in y, true positive value should be meaningless.") + tpr = tps / tps[-1] + + return fpr, tpr, thresholds + + fpr, tpr, thresholds = [], [], [] + for c in range(class_num): + preds_c = y_pred[:, c] + res = self.roc(preds_c, y, class_num=1, pos_label=c, sample_weights=sample_weights) + fpr.append(res[0]) + tpr.append(res[1]) + thresholds.append(res[2]) + + return fpr, tpr, thresholds + + def roc(self, y_pred, y, class_num=None, pos_label=None, sample_weights=None): + """roc""" + y_pred, y, class_num, pos_label = self._precision_recall_curve_update(y_pred, y, class_num, pos_label) + + return self._roc_eval(y_pred, y, class_num, pos_label, sample_weights) + + def eval(self): + """ + Computes the ROC curve. + + Returns: + A tuple, composed of `fpr`, `tpr`, and `thresholds`. + + - **fpr** (np.array) - np.array with false positive rates. If multiclass, this is a list of such np.array, + one for each class. + - **tps** (np.array) - np.array with true positive rates. If multiclass, this is a list of such np.array, + one for each class. + - **thresholds** (np.array) - thresholds used for computing false- and true postive rates. + """ + if self._is_update is False: + raise RuntimeError('Call the update method before calling eval.') + + y_pred = np.squeeze(np.vstack(self.y_pred)) + y = np.squeeze(np.vstack(self.y)) + + return self._roc_eval(y_pred, y, self.class_num, self.pos_label) diff --git a/tests/ut/python/metrics/test_auc.py b/tests/ut/python/metrics/test_auc.py new file mode 100644 index 0000000000..818fbc3cf7 --- /dev/null +++ b/tests/ut/python/metrics/test_auc.py @@ -0,0 +1,34 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# """test_auc""" + +import math +import numpy as np +from mindspore import Tensor +from mindspore.nn.metrics import ROC, auc + + +def test_auc(): + """test_auc""" + x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]])) + y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]])) + metric = ROC(pos_label=1) + metric.clear() + metric.update(x, y) + fpr, tpr, thre = metric.eval() + output = auc(fpr, tpr) + + assert math.isclose(output, 0.45, abs_tol=0.001) + assert np.equal(thre, np.array([4, 3, 2, 1, 0])).all() diff --git a/tests/ut/python/metrics/test_roc.py b/tests/ut/python/metrics/test_roc.py new file mode 100644 index 0000000000..e973b297ad --- /dev/null +++ b/tests/ut/python/metrics/test_roc.py @@ -0,0 +1,92 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# """test_roc""" + +import numpy as np +import pytest +from mindspore import Tensor +from mindspore.nn.metrics import ROC + + +def test_roc(): + """test_roc_binary""" + x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]])) + y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]])) + metric = ROC(pos_label=1) + metric.clear() + metric.update(x, y) + fpr, tpr, thresholds = metric.eval() + + assert np.equal(fpr, np.array([0, 0.4, 0.4, 0.6, 1])).all() + assert np.equal(tpr, np.array([0, 0, 0.25, 0.75, 1])).all() + assert np.equal(thresholds, np.array([4, 3, 2, 1, 0])).all() + + +def test_roc2(): + """test_roc_multiclass""" + x = Tensor(np.array([[0.28, 0.55, 0.15, 0.05], [0.10, 0.20, 0.05, 0.05], [0.20, 0.05, 0.15, 0.05], + [0.05, 0.05, 0.05, 0.75]])) + y = Tensor(np.array([0, 1, 2, 3])) + metric = ROC(class_num=4) + metric.clear() + metric.update(x, y) + fpr, tpr, thresholds = metric.eval() + list1 = [np.array([0., 0., 0.33333333, 0.66666667, 1.]), np.array([0., 0.33333333, 0.33333333, 1.]), + np.array([0., 0.33333333, 1.]), np.array([0., 0., 1.])] + list2 = [np.array([0., 1., 1., 1., 1.]), np.array([0., 0., 1., 1.]), + np.array([0., 1., 1.]), np.array([0., 1., 1.])] + list3 = [np.array([1.28, 0.28, 0.2, 0.1, 0.05]), np.array([1.55, 0.55, 0.2, 0.05]), + np.array([1.15, 0.15, 0.05]), np.array([1.75, 0.75, 0.05])] + + assert fpr[0].shape == list1[0].shape + assert np.equal(tpr[1], list2[1]).all() + assert np.equal(thresholds[2], list3[2]).all() + + +def test_roc_update1(): + x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]])) + metric = ROC() + metric.clear() + + with pytest.raises(ValueError): + metric.update(x) + + +def test_roc_update2(): + x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]])) + y = Tensor(np.array([1, 0])) + metric = ROC() + metric.clear() + + with pytest.raises(ValueError): + metric.update(x, y) + + +def test_roc_init1(): + with pytest.raises(TypeError): + ROC(pos_label=1.2) + + +def test_roc_init2(): + with pytest.raises(TypeError): + ROC(class_num="class_num") + + +def test_roc_runtime(): + metric = ROC() + metric.clear() + + with pytest.raises(RuntimeError): + metric.eval()