|
|
|
@@ -13,11 +13,13 @@ |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================ |
|
|
|
"""loss""" |
|
|
|
import mindspore |
|
|
|
import mindspore.common.dtype as mstype |
|
|
|
from mindspore.common.tensor import Tensor |
|
|
|
from mindspore.common.parameter import Parameter |
|
|
|
from mindspore.ops import operations as P |
|
|
|
from mindspore.ops import functional as F |
|
|
|
from mindspore import nn |
|
|
|
from mindspore.ops.primitive import constexpr |
|
|
|
from mindspore.ops import _selected_ops |
|
|
|
from mindspore.nn.cell import Cell |
|
|
|
@@ -896,3 +898,110 @@ class BCEWithLogitsLoss(_Loss): |
|
|
|
pos_weight = ones_input |
|
|
|
loss = self.bce_with_logits_loss(predict, target, weight, pos_weight) |
|
|
|
return loss |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _check_ndim(predict_ndim, target_ndim):
|
|
|
validator.check_int(predict_ndim, target_ndim, Rel.EQ, 'predict_ndim', 'target_ndim')
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _check_channel_and_shape(target, predict): |
|
|
|
if target not in (predict, 1): |
|
|
|
raise ValueError("The target must have a channel or the same shape as predict.") |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _check_predict_channel(predict): |
|
|
|
if predict == 1: |
|
|
|
raise NotImplementedError("Single channel prediction is not supported.") |
|
|
|
|
|
|
|
|
|
|
|
class FocalLoss(_Loss): |
|
|
|
r""" |
|
|
|
The loss function proposed by Kaiming He's team in the paper ``Focal Loss for Dense Object Detection`` improves
|
|
|
the performance of image object detection. It is designed to address the imbalance between categories and the
|
|
|
varying difficulty of classifying individual samples.
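With the focusing parameter :math:`\gamma` (``gamma``) and the probability :math:`p_t` that the model assigns
to the target class, the unweighted form of the loss is

.. math::
    FL(p_t) = -(1 - p_t)^{\gamma} \log(p_t)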
|
|
|
|
|
|
|
Args: |
|
|
|
gamma (float): The focusing parameter, used to adjust the steepness of the weight curve in the focal loss. Default: 2.0.
|
|
|
weight (Union[Tensor, None]): A per-class rescaling weight applied to the loss; its length must be equal to the
|
|
|
number of classes C. If None, no weights are applied. Default: None.
|
|
|
reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and "none". |
|
|
|
If "none", do not perform reduction. Default: "mean". |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **predict** (Tensor) - Input logits. Tensor of shape BCH[WD], where C is the number of classes.
|
|
|
C must be greater than 1.
|
|
|
- **target** (Tensor) - Tensor of shape B1H[WD] or BCH[WD]. If the target shape is B1H[WD], the
|
|
|
expected target of this loss should be the class index within the range of [0, C-1], |
|
|
|
where C is the number of classes. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor, the computed focal loss, with the per-example losses reduced according to ``reduction``.
|
|
|
|
|
|
|
Raises: |
|
|
|
TypeError: If the data type of ``gamma`` is not float.
|
|
|
TypeError: If ``weight`` is not a Tensor.
|
|
|
ValueError: If the dimension of ``target`` is different from that of ``predict``.
|
|
|
ValueError: If the channel of ``target`` is not 1 and the shape of ``target`` is different from that of ``predict``.
|
|
|
ValueError: If ``reduction`` is not one of 'none', 'mean', 'sum'. |
|
|
|
|
|
|
|
Examples:
|
|
|
>>> predict = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32) |
|
|
|
>>> target = Tensor([[1], [1], [0]], mstype.int32) |
|
|
|
>>> focalloss = nn.FocalLoss(weight=Tensor([1, 2]), gamma=2.0, reduction='mean') |
|
|
|
>>> output = focalloss(predict, target)
|
|
|
>>> print(output) |
|
|
|
0.33365273 |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, weight=None, gamma=2.0, reduction='mean'): |
|
|
|
super(FocalLoss, self).__init__(reduction=reduction) |
|
|
|
|
|
|
|
self.gamma = validator.check_value_type("gamma", gamma, [float]) |
|
|
|
if weight is not None and not isinstance(weight, Tensor): |
|
|
|
raise TypeError("The type of weight should be Tensor, but got {}.".format(type(weight))) |
|
|
|
self.weight = weight |
|
|
|
self.expand_dims = P.ExpandDims() |
|
|
|
self.gather_d = P.GatherD() |
|
|
|
self.squeeze = P.Squeeze(axis=1) |
|
|
|
self.tile = P.Tile() |
|
|
|
self.cast = P.Cast()
self.logsoftmax = nn.LogSoftmax(1)
|
|
|
|
|
|
|
def construct(self, predict, target): |
|
|
|
targets = target |
|
|
|
_check_ndim(predict.ndim, targets.ndim) |
|
|
|
_check_channel_and_shape(targets.shape[1], predict.shape[1]) |
|
|
|
_check_predict_channel(predict.shape[1]) |
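# Flatten any spatial dimensions so that predict becomes (B, C, N) and targets becomes (B, C, N) or (B, 1, N);
# for 2-D inputs a trailing axis of size 1 is added instead.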
|
|
|
|
|
|
|
if predict.ndim > 2: |
|
|
|
predict = predict.view(predict.shape[0], predict.shape[1], -1) |
|
|
|
targets = targets.view(targets.shape[0], targets.shape[1], -1) |
|
|
|
else: |
|
|
|
predict = self.expand_dims(predict, 2) |
|
|
|
targets = self.expand_dims(targets, 2) |
|
|
|
|
|
|
|
log_probability = self.logsoftmax(predict)
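# When target holds class indices (a single channel), gather the log-probability of the labelled class at each position.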
|
|
|
|
|
|
|
if target.shape[1] == 1: |
|
|
|
log_probability = self.gather_d(log_probability, 1, self.cast(targets, mindspore.int32)) |
|
|
|
log_probability = self.squeeze(log_probability) |
|
|
|
|
|
|
|
probability = F.exp(log_probability) |
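# Broadcast the per-class rescaling weight to the flattened target layout before applying it to the log-probability.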
|
|
|
|
|
|
|
if self.weight is not None: |
|
|
|
convert_weight = self.weight[None, :, None] |
|
|
|
convert_weight = self.tile(convert_weight, (targets.shape[0], 1, targets.shape[2])) |
|
|
|
if target.shape[1] == 1: |
|
|
|
convert_weight = self.gather_d(convert_weight, 1, self.cast(targets, mindspore.int32)) |
|
|
|
convert_weight = self.squeeze(convert_weight) |
|
|
|
log_probability = log_probability * convert_weight
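# Focal modulating factor (1 - p) ** gamma down-weights well-classified examples.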
|
|
|
|
|
|
|
weight = F.pows(-probability + 1.0, self.gamma) |
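# Average over positions; get_loss then applies the configured reduction ('mean', 'sum' or 'none').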
|
|
|
if target.shape[1] == 1: |
|
|
|
loss = (-weight * log_probability).mean(axis=1) |
|
|
|
else: |
|
|
|
loss = (-weight * targets * log_probability).mean(axis=-1) |
|
|
|
|
|
|
|
return self.get_loss(loss) |