From: @lijiaqi0612 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -28,9 +28,12 @@ from .fbeta import Fbeta, F1 | |||
| from .dice import Dice | |||
| from .topk import TopKCategoricalAccuracy, Top1CategoricalAccuracy, Top5CategoricalAccuracy | |||
| from .loss import Loss | |||
| from .mean_surface_distance import MeanSurfaceDistance | |||
| from .root_mean_square_surface_distance import RootMeanSquareDistance | |||
| __all__ = [ | |||
| "names", "get_metric_fn", | |||
| "names", | |||
| "get_metric_fn", | |||
| "Accuracy", | |||
| "MAE", "MSE", | |||
| "Metric", | |||
| @@ -44,6 +47,8 @@ __all__ = [ | |||
| "Top1CategoricalAccuracy", | |||
| "Top5CategoricalAccuracy", | |||
| "Loss", | |||
| "MeanSurfaceDistance", | |||
| "RootMeanSquareDistance", | |||
| ] | |||
| __factory__ = { | |||
| @@ -60,6 +65,8 @@ __factory__ = { | |||
| 'mae': MAE, | |||
| 'mse': MSE, | |||
| 'loss': Loss, | |||
| 'mean_surface_distance': MeanSurfaceDistance, | |||
| 'root_mean_square_distance': RootMeanSquareDistance, | |||
| } | |||
| @@ -0,0 +1,137 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """MeanSurfaceDistance.""" | |||
| from scipy.ndimage import morphology | |||
| import numpy as np | |||
| from mindspore._checkparam import Validator as validator | |||
| from .metric import Metric | |||
| class MeanSurfaceDistance(Metric): | |||
| """ | |||
| This function is used to compute the Average Surface Distance from `y_pred` to `y` under the default setting. | |||
| Mean Surface Distance(MSD), the mean of the vector is taken. This tell us how much, on average, the surface varies | |||
| between the segmentation and the GT. | |||
| Args: | |||
| distance_metric (string): The parameter of calculating Hausdorff distance supports three measurement methods, | |||
| "euclidean", "chessboard" or "taxicab". Default: "euclidean". | |||
| symmetric (bool): if calculate the symmetric average surface distance between `y_pred` and `y`. In addition, | |||
| if sets ``symmetric = True``, the average symmetric surface distance between these two inputs | |||
| will be returned. Defaults: False. | |||
| Examples: | |||
| >>> x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]])) | |||
| >>> y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]])) | |||
| >>> metric = nn.MeanSurfaceDistance(symmetric=False, distance_metric="euclidean") | |||
| >>> metric.clear() | |||
| >>> metric.update(x, y, 0) | |||
| >>> mean_average_distance = metric.eval() | |||
| >>> print(mean_average_distance) | |||
| 0.8047378541243649 | |||
| """ | |||
| def __init__(self, symmetric=False, distance_metric="euclidean"): | |||
| super(MeanSurfaceDistance, self).__init__() | |||
| self.distance_metric_list = ["euclidean", "chessboard", "taxicab"] | |||
| distance_metric = validator.check_value_type("distance_metric", distance_metric, [str]) | |||
| self.distance_metric = validator.check_string(distance_metric, self.distance_metric_list, "distance_metric") | |||
| self.symmetric = validator.check_value_type("symmetric", symmetric, [bool]) | |||
| self.clear() | |||
| def clear(self): | |||
| """Clears the internal evaluation result.""" | |||
| self._y_pred_edges = 0 | |||
| self._y_edges = 0 | |||
| self._is_update = False | |||
| def _get_surface_distance(self, y_pred_edges, y_edges): | |||
| """ | |||
| Calculate the surface distances from `y_pred_edges` to `y_edges`. | |||
| Args: | |||
| y_pred_edges (np.ndarray): the edge of the predictions. | |||
| y_edges (np.ndarray): the edge of the ground truth. | |||
| """ | |||
| if not np.any(y_pred_edges): | |||
| return np.array([]) | |||
| if not np.any(y_edges): | |||
| dis = np.full(y_edges.shape, np.inf) | |||
| else: | |||
| if self.distance_metric == "euclidean": | |||
| dis = morphology.distance_transform_edt(~y_edges) | |||
| elif self.distance_metric in self.distance_metric_list[-2:]: | |||
| dis = morphology.distance_transform_cdt(~y_edges, metric=self.distance_metric) | |||
| surface_distance = dis[y_pred_edges] | |||
| return surface_distance | |||
| def update(self, *inputs): | |||
| """ | |||
| Updates the internal evaluation result 'y_pred', 'y' and 'label_idx'. | |||
| Args: | |||
| inputs: Input 'y_pred', 'y' and 'label_idx'. 'y_pred' and 'y' are Tensor or numpy.ndarray. 'y_pred' is the | |||
| predicted binary image. 'y' is the actual binary image. 'label_idx', the data type of `label_idx` | |||
| is int. | |||
| Raises: | |||
| ValueError: If the number of the inputs is not 3. | |||
| """ | |||
| if len(inputs) != 3: | |||
| raise ValueError('MeanSurfaceDistance need 3 inputs (y_pred, y, label), but got {}.'.format(len(inputs))) | |||
| y_pred = self._convert_data(inputs[0]) | |||
| y = self._convert_data(inputs[1]) | |||
| label_idx = inputs[2] | |||
| if y_pred.size == 0 or y_pred.shape != y.shape: | |||
| raise ValueError("y_pred and y should have same shape, but got {}, {}.".format(y_pred.shape, y.shape)) | |||
| if y_pred.dtype != bool: | |||
| y_pred = y_pred == label_idx | |||
| if y.dtype != bool: | |||
| y = y == label_idx | |||
| self._y_pred_edges = morphology.binary_erosion(y_pred) ^ y_pred | |||
| self._y_edges = morphology.binary_erosion(y) ^ y | |||
| self._is_update = True | |||
| def eval(self): | |||
| """ | |||
| Calculate mean surface distance. | |||
| """ | |||
| if self._is_update is False: | |||
| raise RuntimeError('Call the update method before calling eval.') | |||
| mean_surface_distance = self._get_surface_distance(self._y_pred_edges, self._y_edges) | |||
| if mean_surface_distance.shape == (0,): | |||
| return np.inf | |||
| avg_surface_distance = mean_surface_distance.mean() | |||
| if not self.symmetric: | |||
| return avg_surface_distance | |||
| contrary_mean_surface_distance = self._get_surface_distance(self._y_edges, self._y_pred_edges) | |||
| if contrary_mean_surface_distance.shape == (0,): | |||
| return np.inf | |||
| contrary_avg_surface_distance = contrary_mean_surface_distance.mean() | |||
| return np.mean((avg_surface_distance, contrary_avg_surface_distance)) | |||
| @@ -0,0 +1,140 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """RootMeanSquareSurfaceDistance.""" | |||
| from scipy.ndimage import morphology | |||
| import numpy as np | |||
| from mindspore._checkparam import Validator as validator | |||
| from .metric import Metric | |||
| class RootMeanSquareDistance(Metric): | |||
| """ | |||
| This function is used to compute the Residual Mean Square Distance from `y_pred` to `y` under the default | |||
| setting. Residual Mean Square Distance(RMS), the mean is taken from each of the points in the vector, these | |||
| residuals are squared (to remove negative signs), summed, weighted by the mean and then the square-root is taken. | |||
| Measured in mm. | |||
| Args: | |||
| distance_metric (string): The parameter of calculating Hausdorff distance supports three measurement methods, | |||
| "euclidean", "chessboard" or "taxicab". Default: "euclidean". | |||
| symmetric (bool): if calculate the symmetric average surface distance between `y_pred` and `y`. In addition, | |||
| if sets ``symmetric = True``, the average symmetric surface distance between these two inputs | |||
| will be returned. Defaults: False. | |||
| Examples: | |||
| >>> x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]])) | |||
| >>> y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]])) | |||
| >>> metric = nn.RootMeanSquareDistance(symmetric=False, distance_metric="euclidean") | |||
| >>> metric.clear() | |||
| >>> metric.update(x, y, 0) | |||
| >>> root_mean_square_distance = metric.eval() | |||
| >>> print(root_mean_square_distance) | |||
| 1.0000000000000002 | |||
| """ | |||
| def __init__(self, symmetric=False, distance_metric="euclidean"): | |||
| super(RootMeanSquareDistance, self).__init__() | |||
| self.distance_metric_list = ["euclidean", "chessboard", "taxicab"] | |||
| distance_metric = validator.check_value_type("distance_metric", distance_metric, [str]) | |||
| self.distance_metric = validator.check_string(distance_metric, self.distance_metric_list, "distance_metric") | |||
| self.symmetric = validator.check_value_type("symmetric", symmetric, [bool]) | |||
| self.clear() | |||
| def clear(self): | |||
| """Clears the internal evaluation result.""" | |||
| self._y_pred_edges = 0 | |||
| self._y_edges = 0 | |||
| self._is_update = False | |||
| def _get_surface_distance(self, y_pred_edges, y_edges): | |||
| """ | |||
| Calculate the surface distances from `y_pred_edges` to `y_edges`. | |||
| Args: | |||
| y_pred_edges (np.ndarray): the edge of the predictions. | |||
| y_edges (np.ndarray): the edge of the ground truth. | |||
| """ | |||
| if not np.any(y_pred_edges): | |||
| return np.array([]) | |||
| if not np.any(y_edges): | |||
| dis = np.full(y_edges.shape, np.inf) | |||
| else: | |||
| if self.distance_metric == "euclidean": | |||
| dis = morphology.distance_transform_edt(~y_edges) | |||
| elif self.distance_metric in self.distance_metric_list[-2:]: | |||
| dis = morphology.distance_transform_cdt(~y_edges, metric=self.distance_metric) | |||
| surface_distance = dis[y_pred_edges] | |||
| return surface_distance | |||
| def update(self, *inputs): | |||
| """ | |||
| Updates the internal evaluation result 'y_pred', 'y' and 'label_idx'. | |||
| Args: | |||
| inputs: Input 'y_pred', 'y' and 'label_idx'. 'y_pred' and 'y' are Tensor or numpy.ndarray. 'y_pred' is the | |||
| predicted binary image. 'y' is the actual binary image. 'label_idx', the data type of `label_idx` | |||
| is int. | |||
| Raises: | |||
| ValueError: If the number of the inputs is not 3. | |||
| """ | |||
| if len(inputs) != 3: | |||
| raise ValueError('MeanSurfaceDistance need 3 inputs (y_pred, y, label), but got {}.'.format(len(inputs))) | |||
| y_pred = self._convert_data(inputs[0]) | |||
| y = self._convert_data(inputs[1]) | |||
| label_idx = inputs[2] | |||
| if y_pred.size == 0 or y_pred.shape != y.shape: | |||
| raise ValueError("y_pred and y should have same shape, but got {}, {}.".format(y_pred.shape, y.shape)) | |||
| if y_pred.dtype != bool: | |||
| y_pred = y_pred == label_idx | |||
| if y.dtype != bool: | |||
| y = y == label_idx | |||
| self._y_pred_edges = morphology.binary_erosion(y_pred) ^ y_pred | |||
| self._y_edges = morphology.binary_erosion(y) ^ y | |||
| self._is_update = True | |||
| def eval(self): | |||
| """ | |||
| Calculate residual mean square surface distance. | |||
| """ | |||
| if self._is_update is False: | |||
| raise RuntimeError('Call the update method before calling eval.') | |||
| residual_mean_square_distance = self._get_surface_distance(self._y_pred_edges, self._y_edges) | |||
| if residual_mean_square_distance.shape == (0,): | |||
| return np.inf | |||
| rms_surface_distance = (residual_mean_square_distance**2).mean() | |||
| if not self.symmetric: | |||
| return rms_surface_distance | |||
| contrary_residual_mean_square_distance = self._get_surface_distance(self._y_edges, self._y_pred_edges) | |||
| if contrary_residual_mean_square_distance.shape == (0,): | |||
| return np.inf | |||
| contrary_rms_surface_distance = (contrary_residual_mean_square_distance**2).mean() | |||
| rms_distance = np.sqrt(np.mean((rms_surface_distance, contrary_rms_surface_distance))) | |||
| return rms_distance | |||
| @@ -0,0 +1,70 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| # """test_mean_surface_distance""" | |||
| import math | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore import Tensor | |||
| from mindspore.nn.metrics import get_metric_fn, MeanSurfaceDistance | |||
| def test_mean_surface_distance(): | |||
| """test_mean_surface_distance""" | |||
| x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]])) | |||
| y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]])) | |||
| metric = get_metric_fn('mean_surface_distance') | |||
| metric.clear() | |||
| metric.update(x, y, 0) | |||
| distance = metric.eval() | |||
| assert math.isclose(distance, 0.8047378541243649, abs_tol=0.001) | |||
| def test_mean_surface_distance_update1(): | |||
| x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]])) | |||
| metric = MeanSurfaceDistance() | |||
| metric.clear() | |||
| with pytest.raises(ValueError): | |||
| metric.update(x) | |||
| def test_mean_surface_distance_update2(): | |||
| x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]])) | |||
| y = Tensor(np.array([1, 0])) | |||
| metric = MeanSurfaceDistance() | |||
| metric.clear() | |||
| with pytest.raises(ValueError): | |||
| metric.update(x, y) | |||
| def test_mean_surface_distance_init(): | |||
| with pytest.raises(ValueError): | |||
| MeanSurfaceDistance(symmetric=False, distance_metric="eucli") | |||
| def test_mean_surface_distance_init2(): | |||
| with pytest.raises(TypeError): | |||
| MeanSurfaceDistance(symmetric=1) | |||
| def test_mean_surface_distance_runtime(): | |||
| metric = MeanSurfaceDistance() | |||
| metric.clear() | |||
| with pytest.raises(RuntimeError): | |||
| metric.eval() | |||
| @@ -0,0 +1,70 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| # """test_mean_surface_distance""" | |||
| import math | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore import Tensor | |||
| from mindspore.nn.metrics import get_metric_fn, RootMeanSquareDistance | |||
| def test_root_mean_square_distance(): | |||
| """test_root_mean_square_distance""" | |||
| x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]])) | |||
| y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]])) | |||
| metric = get_metric_fn('root_mean_square_distance') | |||
| metric.clear() | |||
| metric.update(x, y, 0) | |||
| distance = metric.eval() | |||
| assert math.isclose(distance, 1.0000000000000002, abs_tol=0.001) | |||
| def test_root_mean_square_distance_update1(): | |||
| x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]])) | |||
| metric = RootMeanSquareDistance() | |||
| metric.clear() | |||
| with pytest.raises(ValueError): | |||
| metric.update(x) | |||
| def test_root_mean_square_distance_update2(): | |||
| x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]])) | |||
| y = Tensor(np.array([1, 0])) | |||
| metric = RootMeanSquareDistance() | |||
| metric.clear() | |||
| with pytest.raises(ValueError): | |||
| metric.update(x, y) | |||
| def test_root_mean_square_distance_init(): | |||
| with pytest.raises(ValueError): | |||
| RootMeanSquareDistance(symmetric=False, distance_metric="eucli") | |||
| def test_root_mean_square_distance_init2(): | |||
| with pytest.raises(TypeError): | |||
| RootMeanSquareDistance(symmetric=1) | |||
| def test_root_mean_square_distance_runtime(): | |||
| metric = RootMeanSquareDistance() | |||
| metric.clear() | |||
| with pytest.raises(RuntimeError): | |||
| metric.eval() | |||