diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py index 03406f6cc4..cd6742b675 100644 --- a/mindspore/nn/loss/loss.py +++ b/mindspore/nn/loss/loss.py @@ -310,11 +310,10 @@ class DiceLoss(_Loss): Args: smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0. Default: 1e-5. - threshold (float): A threshold, which is used to compare with the input tensor. Default: 0.5. Inputs: - - **y_pred** (Tensor) - Tensor of shape (N, C). - - **y** (Tensor) - Tensor of shape (N, C). + - **y_pred** (Tensor) - Tensor of shape (N, ...). + - **y** (Tensor) - Tensor of shape (N, ...). Outputs: Tensor, a tensor of shape with the per-example sampled Dice losses. @@ -323,32 +322,30 @@ class DiceLoss(_Loss): ``Ascend`` Examples: - >>> loss = nn.Diceloss(smooth=1e-5, threshold=0.5) + >>> loss = nn.DiceLoss(smooth=1e-5) >>> y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32) >>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32) >>> output = loss(y_pred, y) >>> print(output) - [0.77777076] + [0.7953220862819745] + + Raises: + ValueError: If the dimensions are different. + TypeError: If the type of inputs are not Tensor. """ - def __init__(self, smooth=1e-5, threshold=0.5): + def __init__(self, smooth=1e-5): super(DiceLoss, self).__init__() self.smooth = validator.check_positive_float(smooth, "smooth") - self.threshold = validator.check_value_type("threshold", threshold, [float]) self.reshape = P.Reshape() def construct(self, logits, label): _check_shape(logits.shape, label.shape) - logits = self.cast((logits > self.threshold), mstype.float32) - label = self.cast(label, mstype.float32) - dim = label.shape - pred_flat = self.reshape(logits, (dim[0], -1)) - true_flat = self.reshape(label, (dim[0], -1)) - - intersection = self.reduce_sum((pred_flat * true_flat), 1) - unionset = self.reduce_sum(pred_flat, 1) + self.reduce_sum(true_flat, 1) + intersection = self.reduce_sum(self.mul(logits.view(-1), label.view(-1))) + unionset = self.reduce_sum(self.mul(logits.view(-1), logits.view(-1))) + \ + self.reduce_sum(self.mul(label.view(-1), label.view(-1))) - dice = (2 * intersection + self.smooth) / (unionset + self.smooth) - dice_loss = 1 - self.reduce_sum(dice) / dim[0] + single_dice_coeff = (2 * intersection) / (unionset + self.smooth) + dice_loss = 1 - single_dice_coeff / label.shape[0] return dice_loss diff --git a/mindspore/nn/metrics/dice.py b/mindspore/nn/metrics/dice.py index ad5e8f26af..1b1e661441 100644 --- a/mindspore/nn/metrics/dice.py +++ b/mindspore/nn/metrics/dice.py @@ -26,35 +26,35 @@ class Dice(Metric): The function is shown as follows: .. math:: - dice = \frac{2 * (pred \bigcap true)}{pred \bigcup true} + dice = \frac{2 * (pred \bigcap true)}{pred \bigcup true} Args: smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0. Default: 1e-5. - threshold (float): A threshold, which is used to compare with the input tensor. Default: 0.5. Examples: >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]])) >>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]])) - >>> metric = Dice(smooth=1e-5, threshold=0.5) + >>> metric = Dice(smooth=1e-5) >>> metric.clear() >>> metric.update(x, y) >>> dice = metric.eval() - 0.22222926 + >>> print(dice) + 0.20467791371802546 """ - def __init__(self, smooth=1e-5, threshold=0.5): + def __init__(self, smooth=1e-5): super(Dice, self).__init__() self.smooth = validator.check_positive_float(smooth, "smooth") - self.threshold = validator.check_value_type("threshold", threshold, [float]) + self._dice_coeff_sum = 0 + self._samples_num = 0 self.clear() def clear(self): """Clears the internal evaluation result.""" - self._dim = 0 - self.intersection = 0 - self.unionset = 0 + self._dice_coeff_sum = 0 + self._samples_num = 0 def update(self, *inputs): """ @@ -62,7 +62,7 @@ class Dice(Metric): Args: inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. `y_pred` is the - predicted value, `y` is the true value. The shape of `y_pred` and `y` are both :math:`(N, C)`. + predicted value, `y` is the true value. The shape of `y_pred` and `y` are both :math:`(N, ...)`. Raises: ValueError: If the number of the inputs is not 2. @@ -72,17 +72,17 @@ class Dice(Metric): y_pred = self._convert_data(inputs[0]) y = self._convert_data(inputs[1]) + self._samples_num += y.shape[0] if y_pred.shape != y.shape: raise RuntimeError('y_pred and y should have same the dimension, but the shape of y_pred is{}, ' 'the shape of y is {}.'.format(y_pred.shape, y.shape)) - y_pred = (y_pred > self.threshold).astype(int) - self._dim = y.shape - pred_flat = np.reshape(y_pred, (self._dim[0], -1)) - true_flat = np.reshape(y, (self._dim[0], -1)) - self.intersection = np.sum((pred_flat * true_flat), axis=1) - self.unionset = np.sum(pred_flat, axis=1) + np.sum(true_flat, axis=1) + intersection = np.dot(y_pred.flatten(), y.flatten()) + unionset = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten()) + + single_dice_coeff = 2 * float(intersection) / float(unionset + self.smooth) + self._dice_coeff_sum += single_dice_coeff def eval(self): r""" @@ -92,11 +92,9 @@ class Dice(Metric): Float, the computed result. Raises: - RuntimeError: If the sample size is 0. + RuntimeError: If the total samples num is 0. """ - if self._dim[0] == 0: - raise RuntimeError('Dice can not be calculated, because the number of samples is 0.') - - dice = (2 * self.intersection + self.smooth) / (self.unionset + self.smooth) + if self._samples_num == 0: + raise RuntimeError('Total samples num must not be 0.') - return np.sum(dice) / self._dim[0] + return self._dice_coeff_sum / float(self._samples_num) diff --git a/tests/ut/python/metrics/test_dice.py b/tests/ut/python/metrics/test_dice.py index 4a9fbb7ba5..dfbd13eedb 100644 --- a/tests/ut/python/metrics/test_dice.py +++ b/tests/ut/python/metrics/test_dice.py @@ -29,12 +29,12 @@ def test_classification_dice(): metric.update(x, y) dice = metric.eval() - assert math.isclose(dice, 0.22222926, abs_tol=0.001) + assert math.isclose(dice, 0.20467791371802546, abs_tol=0.001) def test_dice_update1(): x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]])) - metric = Dice(1e-5, 0.5) + metric = Dice(1e-5) metric.clear() with pytest.raises(ValueError): @@ -42,8 +42,8 @@ def test_dice_update1(): def test_dice_runtime(): - metric = Dice(1e-5, 0.8) + metric = Dice(1e-5) metric.clear() - with pytest.raises(TypeError): + with pytest.raises(RuntimeError): metric.eval() diff --git a/tests/ut/python/nn/test_loss.py b/tests/ut/python/nn/test_loss.py index a7c6e52422..0408aabff8 100644 --- a/tests/ut/python/nn/test_loss.py +++ b/tests/ut/python/nn/test_loss.py @@ -97,8 +97,7 @@ def test_dice_loss(): y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32) y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32) # Pass the test if no error is reported - loss(y_pred, y).asnumpy() - + loss(y_pred, y) def test_dice_loss_check_shape():