|
|
|
@@ -436,7 +436,7 @@ class DiceLoss(_Loss): |
|
|
|
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32) |
|
|
|
>>> output = loss(y_pred, y) |
|
|
|
>>> print(output) |
|
|
|
0.38596618 |
|
|
|
[0.38596618] |
|
|
|
""" |
|
|
|
def __init__(self, smooth=1e-5): |
|
|
|
super(DiceLoss, self).__init__() |
|
|
|
@@ -1027,6 +1027,12 @@ def _check_channel_and_shape(predict, target): |
|
|
|
f"inferred from 'predict': C={predict}.") |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _check_input_dtype(targets_dtype, cls_name): |
|
|
|
validator.check_type_name("targets", targets_dtype, [mstype.int32, mstype.int64, mstype.float16, |
|
|
|
mstype.float32], cls_name) |
|
|
|
|
|
|
|
|
|
|
|
class FocalLoss(_Loss): |
|
|
|
r""" |
|
|
|
The loss function proposed by Kaiming team in their paper ``Focal Loss for Dense Object Detection`` improves the |
|
|
|
@@ -1089,11 +1095,14 @@ class FocalLoss(_Loss): |
|
|
|
self.squeeze = P.Squeeze(axis=1) |
|
|
|
self.tile = P.Tile() |
|
|
|
self.cast = P.Cast() |
|
|
|
self.dtype = P.DType() |
|
|
|
self.logsoftmax = nn.LogSoftmax(1) |
|
|
|
|
|
|
|
def construct(self, predict, target): |
|
|
|
targets = target |
|
|
|
_check_ndim(predict.ndim, targets.ndim) |
|
|
|
_check_channel_and_shape(predict.shape[1], targets.shape[1]) |
|
|
|
_check_input_dtype(self.dtype(targets), self.cls_name) |
|
|
|
|
|
|
|
if predict.ndim > 2: |
|
|
|
predict = predict.view(predict.shape[0], predict.shape[1], -1) |
|
|
|
@@ -1102,7 +1111,7 @@ class FocalLoss(_Loss): |
|
|
|
predict = self.expand_dims(predict, 2) |
|
|
|
targets = self.expand_dims(targets, 2) |
|
|
|
|
|
|
|
log_probability = nn.LogSoftmax(1)(predict) |
|
|
|
log_probability = self.logsoftmax(predict) |
|
|
|
|
|
|
|
if target.shape[1] == 1: |
|
|
|
log_probability = self.gather_d(log_probability, 1, self.cast(targets, mindspore.int32)) |
|
|
|
@@ -1116,7 +1125,7 @@ class FocalLoss(_Loss): |
|
|
|
if target.shape[1] == 1: |
|
|
|
convert_weight = self.gather_d(convert_weight, 1, self.cast(targets, mindspore.int32)) |
|
|
|
convert_weight = self.squeeze(convert_weight) |
|
|
|
probability = log_probability * convert_weight |
|
|
|
log_probability = log_probability * convert_weight |
|
|
|
|
|
|
|
weight = F.pows(-probability + 1.0, self.gamma) |
|
|
|
if target.shape[1] == 1: |
|
|
|
|