|
|
@@ -21,6 +21,7 @@ from mindspore.ops import functional as F |
|
|
from mindspore.ops.primitive import constexpr |
|
|
from mindspore.ops.primitive import constexpr |
|
|
from mindspore.ops import _selected_ops |
|
|
from mindspore.ops import _selected_ops |
|
|
from mindspore.nn.cell import Cell |
|
|
from mindspore.nn.cell import Cell |
|
|
|
|
|
from mindspore.nn.layer.activation import get_activation |
|
|
from mindspore._checkparam import Validator as validator |
|
|
from mindspore._checkparam import Validator as validator |
|
|
from mindspore._checkparam import Rel |
|
|
from mindspore._checkparam import Rel |
|
|
from ... import context |
|
|
from ... import context |
|
|
@@ -329,14 +330,14 @@ class DiceLoss(_Loss): |
|
|
Default: 1e-5. |
|
|
Default: 1e-5. |
|
|
|
|
|
|
|
|
Inputs: |
|
|
Inputs: |
|
|
- **y_pred** (Tensor) - Tensor of shape (N, ...). |
|
|
|
|
|
- **y** (Tensor) - Tensor of shape (N, ...). |
|
|
|
|
|
|
|
|
- **y_pred** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32. |
|
|
|
|
|
- **y** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32. |
|
|
|
|
|
|
|
|
Outputs: |
|
|
Outputs: |
|
|
Tensor, a tensor of shape with the per-example sampled Dice losses. |
|
|
Tensor, a tensor of shape with the per-example sampled Dice losses. |
|
|
|
|
|
|
|
|
Supported Platforms: |
|
|
Supported Platforms: |
|
|
``Ascend`` |
|
|
|
|
|
|
|
|
``Ascend`` ``GPU`` ``CPU`` |
|
|
|
|
|
|
|
|
Examples: |
|
|
Examples: |
|
|
>>> loss = nn.DiceLoss(smooth=1e-5) |
|
|
>>> loss = nn.DiceLoss(smooth=1e-5) |
|
|
@@ -364,7 +365,7 @@ class DiceLoss(_Loss): |
|
|
single_dice_coeff = (2 * intersection) / (unionset + self.smooth) |
|
|
single_dice_coeff = (2 * intersection) / (unionset + self.smooth) |
|
|
dice_loss = 1 - single_dice_coeff / label.shape[0] |
|
|
dice_loss = 1 - single_dice_coeff / label.shape[0] |
|
|
|
|
|
|
|
|
return dice_loss |
|
|
|
|
|
|
|
|
return dice_loss.mean() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
@constexpr |
|
|
@@ -372,6 +373,79 @@ def _check_shape(logits_shape, label_shape): |
|
|
validator.check('logits_shape', logits_shape, 'label_shape', label_shape) |
|
|
validator.check('logits_shape', logits_shape, 'label_shape', label_shape) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
|
|
def _check_weights(weight, label): |
|
|
|
|
|
if weight.shape[0] != label.shape[1]: |
|
|
|
|
|
raise ValueError("The shape of weight should be equal to the shape of label, but the shape of weight is {}, " |
|
|
|
|
|
"and the shape of label is {}.".format(weight.shape, label.shape)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MultiClassDiceLoss(_Loss): |
|
|
|
|
|
r""" |
|
|
|
|
|
When there are multiple classifications, label is transformed into multiple binary classifications by one hot. |
|
|
|
|
|
For each channel section in the channel, it can be regarded as a binary classification problem, so it can be |
|
|
|
|
|
obtained through the binary loss of each category, and then the average value. |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
weights (Union[Tensor, None]): Tensor of shape `[num_classes, dim]`. |
|
|
|
|
|
ignore_indiex (Union[int, None]): Class index to ignore. |
|
|
|
|
|
activation (Union[str, Cell]): Activate function applied to the output of the fully connected layer, eg. 'ReLU'. |
|
|
|
|
|
Default: 'Softmax'. Choose from: |
|
|
|
|
|
['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'FastGelu', 'Sigmoid', |
|
|
|
|
|
'PReLU', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU', 'LogSigmoid'] |
|
|
|
|
|
|
|
|
|
|
|
Inputs: |
|
|
|
|
|
- **y_pred** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32. |
|
|
|
|
|
- **y** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32. |
|
|
|
|
|
|
|
|
|
|
|
Outputs: |
|
|
|
|
|
Tensor, a tensor of shape with the per-example sampled MultiClass Dice Losses. |
|
|
|
|
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
|
|
``Ascend`` ``GPU`` ``CPU`` |
|
|
|
|
|
|
|
|
|
|
|
Examples: |
|
|
|
|
|
>>> loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation="softmax") |
|
|
|
|
|
>>> y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32) |
|
|
|
|
|
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32) |
|
|
|
|
|
>>> output = loss(y_pred, y) |
|
|
|
|
|
>>> print(output) |
|
|
|
|
|
[0.7761003] |
|
|
|
|
|
|
|
|
|
|
|
Raises: |
|
|
|
|
|
ValueError: If the shapes are different. |
|
|
|
|
|
TypeError: If the type of inputs are not Tensor. |
|
|
|
|
|
""" |
|
|
|
|
|
def __init__(self, weights=None, ignore_indiex=None, activation="softmax"): |
|
|
|
|
|
super(MultiClassDiceLoss, self).__init__() |
|
|
|
|
|
|
|
|
|
|
|
self.binarydiceloss = DiceLoss(smooth=1e-5) |
|
|
|
|
|
self.weights = weights if weights is None else validator.check_value_type("weights", weights, [Tensor]) |
|
|
|
|
|
self.ignore_indiex = ignore_indiex if ignore_indiex is None else \ |
|
|
|
|
|
validator.check_value_type("ignore_indiex", ignore_indiex, [int]) |
|
|
|
|
|
self.activation = get_activation(activation) if isinstance(activation, str) else activation |
|
|
|
|
|
if self.activation is not None and not isinstance(self.activation, Cell): |
|
|
|
|
|
raise TypeError("The activation must be str or Cell, but got {}.".format(activation)) |
|
|
|
|
|
self.reshape = P.Reshape() |
|
|
|
|
|
|
|
|
|
|
|
def construct(self, logits, label): |
|
|
|
|
|
_check_shape(logits.shape, label.shape) |
|
|
|
|
|
total_loss = 0 |
|
|
|
|
|
|
|
|
|
|
|
if self.activation is not None: |
|
|
|
|
|
logits = self.activation(logits) |
|
|
|
|
|
|
|
|
|
|
|
for i in range(label.shape[1]): |
|
|
|
|
|
if i != self.ignore_indiex: |
|
|
|
|
|
dice_loss = self.binarydiceloss(logits[:, i], label[:, i]) |
|
|
|
|
|
if self.weights is not None: |
|
|
|
|
|
_check_weights(self.weights, label) |
|
|
|
|
|
dice_loss *= self.weights[i] |
|
|
|
|
|
total_loss += dice_loss |
|
|
|
|
|
|
|
|
|
|
|
return total_loss/label.shape[1] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SampledSoftmaxLoss(_Loss): |
|
|
class SampledSoftmaxLoss(_Loss): |
|
|
r""" |
|
|
r""" |
|
|
Computes the sampled softmax training loss. |
|
|
Computes the sampled softmax training loss. |
|
|
|