diff --git a/mindspore/nn/metrics/__init__.py b/mindspore/nn/metrics/__init__.py index 411e8c497e..a99342a019 100755 --- a/mindspore/nn/metrics/__init__.py +++ b/mindspore/nn/metrics/__init__.py @@ -28,9 +28,12 @@ from .fbeta import Fbeta, F1 from .dice import Dice from .topk import TopKCategoricalAccuracy, Top1CategoricalAccuracy, Top5CategoricalAccuracy from .loss import Loss +from .mean_surface_distance import MeanSurfaceDistance +from .root_mean_square_surface_distance import RootMeanSquareDistance __all__ = [ - "names", "get_metric_fn", + "names", + "get_metric_fn", "Accuracy", "MAE", "MSE", "Metric", @@ -44,6 +47,8 @@ __all__ = [ "Top1CategoricalAccuracy", "Top5CategoricalAccuracy", "Loss", + "MeanSurfaceDistance", + "RootMeanSquareDistance", ] __factory__ = { @@ -60,6 +65,8 @@ __factory__ = { 'mae': MAE, 'mse': MSE, 'loss': Loss, + 'mean_surface_distance': MeanSurfaceDistance, + 'root_mean_square_distance': RootMeanSquareDistance, } diff --git a/mindspore/nn/metrics/mean_surface_distance.py b/mindspore/nn/metrics/mean_surface_distance.py new file mode 100644 index 0000000000..ddb2caba63 --- /dev/null +++ b/mindspore/nn/metrics/mean_surface_distance.py @@ -0,0 +1,137 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MeanSurfaceDistance.""" +from scipy.ndimage import morphology +import numpy as np +from mindspore._checkparam import Validator as validator +from .metric import Metric + + +class MeanSurfaceDistance(Metric): + """ + This function is used to compute the Average Surface Distance from `y_pred` to `y` under the default setting. + Mean Surface Distance(MSD), the mean of the vector is taken. This tell us how much, on average, the surface varies + between the segmentation and the GT. + + Args: + distance_metric (string): The parameter of calculating Hausdorff distance supports three measurement methods, + "euclidean", "chessboard" or "taxicab". Default: "euclidean". + symmetric (bool): if calculate the symmetric average surface distance between `y_pred` and `y`. In addition, + if sets ``symmetric = True``, the average symmetric surface distance between these two inputs + will be returned. Defaults: False. + + Examples: + >>> x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]])) + >>> y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]])) + >>> metric = nn.MeanSurfaceDistance(symmetric=False, distance_metric="euclidean") + >>> metric.clear() + >>> metric.update(x, y, 0) + >>> mean_average_distance = metric.eval() + >>> print(mean_average_distance) + 0.8047378541243649 + + """ + + def __init__(self, symmetric=False, distance_metric="euclidean"): + super(MeanSurfaceDistance, self).__init__() + self.distance_metric_list = ["euclidean", "chessboard", "taxicab"] + distance_metric = validator.check_value_type("distance_metric", distance_metric, [str]) + self.distance_metric = validator.check_string(distance_metric, self.distance_metric_list, "distance_metric") + self.symmetric = validator.check_value_type("symmetric", symmetric, [bool]) + self.clear() + + def clear(self): + """Clears the internal evaluation result.""" + self._y_pred_edges = 0 + self._y_edges = 0 + self._is_update = False + + def _get_surface_distance(self, y_pred_edges, y_edges): + """ + Calculate the surface distances from `y_pred_edges` to `y_edges`. + + Args: + y_pred_edges (np.ndarray): the edge of the predictions. + y_edges (np.ndarray): the edge of the ground truth. + """ + + if not np.any(y_pred_edges): + return np.array([]) + + if not np.any(y_edges): + dis = np.full(y_edges.shape, np.inf) + else: + if self.distance_metric == "euclidean": + dis = morphology.distance_transform_edt(~y_edges) + elif self.distance_metric in self.distance_metric_list[-2:]: + dis = morphology.distance_transform_cdt(~y_edges, metric=self.distance_metric) + + surface_distance = dis[y_pred_edges] + + return surface_distance + + def update(self, *inputs): + """ + Updates the internal evaluation result 'y_pred', 'y' and 'label_idx'. + + Args: + inputs: Input 'y_pred', 'y' and 'label_idx'. 'y_pred' and 'y' are Tensor or numpy.ndarray. 'y_pred' is the + predicted binary image. 'y' is the actual binary image. 'label_idx', the data type of `label_idx` + is int. + + Raises: + ValueError: If the number of the inputs is not 3. + """ + if len(inputs) != 3: + raise ValueError('MeanSurfaceDistance need 3 inputs (y_pred, y, label), but got {}.'.format(len(inputs))) + y_pred = self._convert_data(inputs[0]) + y = self._convert_data(inputs[1]) + label_idx = inputs[2] + + if y_pred.size == 0 or y_pred.shape != y.shape: + raise ValueError("y_pred and y should have same shape, but got {}, {}.".format(y_pred.shape, y.shape)) + + if y_pred.dtype != bool: + y_pred = y_pred == label_idx + if y.dtype != bool: + y = y == label_idx + + self._y_pred_edges = morphology.binary_erosion(y_pred) ^ y_pred + self._y_edges = morphology.binary_erosion(y) ^ y + self._is_update = True + + def eval(self): + """ + Calculate mean surface distance. + """ + if self._is_update is False: + raise RuntimeError('Call the update method before calling eval.') + + mean_surface_distance = self._get_surface_distance(self._y_pred_edges, self._y_edges) + + if mean_surface_distance.shape == (0,): + return np.inf + + avg_surface_distance = mean_surface_distance.mean() + + if not self.symmetric: + return avg_surface_distance + + contrary_mean_surface_distance = self._get_surface_distance(self._y_edges, self._y_pred_edges) + if contrary_mean_surface_distance.shape == (0,): + return np.inf + + contrary_avg_surface_distance = contrary_mean_surface_distance.mean() + return np.mean((avg_surface_distance, contrary_avg_surface_distance)) diff --git a/mindspore/nn/metrics/root_mean_square_surface_distance.py b/mindspore/nn/metrics/root_mean_square_surface_distance.py new file mode 100644 index 0000000000..14069032a5 --- /dev/null +++ b/mindspore/nn/metrics/root_mean_square_surface_distance.py @@ -0,0 +1,140 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""RootMeanSquareSurfaceDistance.""" +from scipy.ndimage import morphology +import numpy as np +from mindspore._checkparam import Validator as validator +from .metric import Metric + + +class RootMeanSquareDistance(Metric): + """ + This function is used to compute the Residual Mean Square Distance from `y_pred` to `y` under the default + setting. Residual Mean Square Distance(RMS), the mean is taken from each of the points in the vector, these + residuals are squared (to remove negative signs), summed, weighted by the mean and then the square-root is taken. + Measured in mm. + + Args: + distance_metric (string): The parameter of calculating Hausdorff distance supports three measurement methods, + "euclidean", "chessboard" or "taxicab". Default: "euclidean". + symmetric (bool): if calculate the symmetric average surface distance between `y_pred` and `y`. In addition, + if sets ``symmetric = True``, the average symmetric surface distance between these two inputs + will be returned. Defaults: False. + + Examples: + >>> x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]])) + >>> y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]])) + >>> metric = nn.RootMeanSquareDistance(symmetric=False, distance_metric="euclidean") + >>> metric.clear() + >>> metric.update(x, y, 0) + >>> root_mean_square_distance = metric.eval() + >>> print(root_mean_square_distance) + 1.0000000000000002 + + """ + + def __init__(self, symmetric=False, distance_metric="euclidean"): + super(RootMeanSquareDistance, self).__init__() + self.distance_metric_list = ["euclidean", "chessboard", "taxicab"] + distance_metric = validator.check_value_type("distance_metric", distance_metric, [str]) + self.distance_metric = validator.check_string(distance_metric, self.distance_metric_list, "distance_metric") + self.symmetric = validator.check_value_type("symmetric", symmetric, [bool]) + self.clear() + + def clear(self): + """Clears the internal evaluation result.""" + self._y_pred_edges = 0 + self._y_edges = 0 + self._is_update = False + + def _get_surface_distance(self, y_pred_edges, y_edges): + """ + Calculate the surface distances from `y_pred_edges` to `y_edges`. + + Args: + y_pred_edges (np.ndarray): the edge of the predictions. + y_edges (np.ndarray): the edge of the ground truth. + """ + + if not np.any(y_pred_edges): + return np.array([]) + + if not np.any(y_edges): + dis = np.full(y_edges.shape, np.inf) + else: + if self.distance_metric == "euclidean": + dis = morphology.distance_transform_edt(~y_edges) + elif self.distance_metric in self.distance_metric_list[-2:]: + dis = morphology.distance_transform_cdt(~y_edges, metric=self.distance_metric) + + surface_distance = dis[y_pred_edges] + + return surface_distance + + def update(self, *inputs): + """ + Updates the internal evaluation result 'y_pred', 'y' and 'label_idx'. + + Args: + inputs: Input 'y_pred', 'y' and 'label_idx'. 'y_pred' and 'y' are Tensor or numpy.ndarray. 'y_pred' is the + predicted binary image. 'y' is the actual binary image. 'label_idx', the data type of `label_idx` + is int. + + Raises: + ValueError: If the number of the inputs is not 3. + """ + if len(inputs) != 3: + raise ValueError('MeanSurfaceDistance need 3 inputs (y_pred, y, label), but got {}.'.format(len(inputs))) + y_pred = self._convert_data(inputs[0]) + y = self._convert_data(inputs[1]) + label_idx = inputs[2] + + if y_pred.size == 0 or y_pred.shape != y.shape: + raise ValueError("y_pred and y should have same shape, but got {}, {}.".format(y_pred.shape, y.shape)) + + if y_pred.dtype != bool: + y_pred = y_pred == label_idx + if y.dtype != bool: + y = y == label_idx + + self._y_pred_edges = morphology.binary_erosion(y_pred) ^ y_pred + self._y_edges = morphology.binary_erosion(y) ^ y + self._is_update = True + + def eval(self): + """ + Calculate residual mean square surface distance. + """ + if self._is_update is False: + raise RuntimeError('Call the update method before calling eval.') + + residual_mean_square_distance = self._get_surface_distance(self._y_pred_edges, self._y_edges) + + if residual_mean_square_distance.shape == (0,): + return np.inf + + rms_surface_distance = (residual_mean_square_distance**2).mean() + + if not self.symmetric: + return rms_surface_distance + + contrary_residual_mean_square_distance = self._get_surface_distance(self._y_edges, self._y_pred_edges) + if contrary_residual_mean_square_distance.shape == (0,): + return np.inf + + contrary_rms_surface_distance = (contrary_residual_mean_square_distance**2).mean() + + rms_distance = np.sqrt(np.mean((rms_surface_distance, contrary_rms_surface_distance))) + return rms_distance diff --git a/tests/ut/python/metrics/test_mean_surface_distance.py b/tests/ut/python/metrics/test_mean_surface_distance.py new file mode 100644 index 0000000000..83c3b4ddeb --- /dev/null +++ b/tests/ut/python/metrics/test_mean_surface_distance.py @@ -0,0 +1,70 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# """test_mean_surface_distance""" + +import math +import numpy as np +import pytest +from mindspore import Tensor +from mindspore.nn.metrics import get_metric_fn, MeanSurfaceDistance + + +def test_mean_surface_distance(): + """test_mean_surface_distance""" + x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]])) + y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]])) + metric = get_metric_fn('mean_surface_distance') + metric.clear() + metric.update(x, y, 0) + distance = metric.eval() + + assert math.isclose(distance, 0.8047378541243649, abs_tol=0.001) + + +def test_mean_surface_distance_update1(): + x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]])) + metric = MeanSurfaceDistance() + metric.clear() + + with pytest.raises(ValueError): + metric.update(x) + + +def test_mean_surface_distance_update2(): + x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]])) + y = Tensor(np.array([1, 0])) + metric = MeanSurfaceDistance() + metric.clear() + + with pytest.raises(ValueError): + metric.update(x, y) + + +def test_mean_surface_distance_init(): + with pytest.raises(ValueError): + MeanSurfaceDistance(symmetric=False, distance_metric="eucli") + + +def test_mean_surface_distance_init2(): + with pytest.raises(TypeError): + MeanSurfaceDistance(symmetric=1) + + +def test_mean_surface_distance_runtime(): + metric = MeanSurfaceDistance() + metric.clear() + + with pytest.raises(RuntimeError): + metric.eval() diff --git a/tests/ut/python/metrics/test_root_mean_square_distance.py b/tests/ut/python/metrics/test_root_mean_square_distance.py new file mode 100644 index 0000000000..9da4692ee2 --- /dev/null +++ b/tests/ut/python/metrics/test_root_mean_square_distance.py @@ -0,0 +1,70 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# """test_mean_surface_distance""" + +import math +import numpy as np +import pytest +from mindspore import Tensor +from mindspore.nn.metrics import get_metric_fn, RootMeanSquareDistance + + +def test_root_mean_square_distance(): + """test_root_mean_square_distance""" + x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]])) + y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]])) + metric = get_metric_fn('root_mean_square_distance') + metric.clear() + metric.update(x, y, 0) + distance = metric.eval() + + assert math.isclose(distance, 1.0000000000000002, abs_tol=0.001) + + +def test_root_mean_square_distance_update1(): + x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]])) + metric = RootMeanSquareDistance() + metric.clear() + + with pytest.raises(ValueError): + metric.update(x) + + +def test_root_mean_square_distance_update2(): + x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]])) + y = Tensor(np.array([1, 0])) + metric = RootMeanSquareDistance() + metric.clear() + + with pytest.raises(ValueError): + metric.update(x, y) + + +def test_root_mean_square_distance_init(): + with pytest.raises(ValueError): + RootMeanSquareDistance(symmetric=False, distance_metric="eucli") + + +def test_root_mean_square_distance_init2(): + with pytest.raises(TypeError): + RootMeanSquareDistance(symmetric=1) + + +def test_root_mean_square_distance_runtime(): + metric = RootMeanSquareDistance() + metric.clear() + + with pytest.raises(RuntimeError): + metric.eval()