|
|
|
@@ -76,8 +76,8 @@ class LossBase(Cell): |
|
|
|
|
|
|
|
Args: |
|
|
|
weights (Union[float, Tensor]): Optional `Tensor` whose rank is either 0, or the same rank as inputs, |
|
|
|
and must be broadcastable to inputs (i.e., all dimensions must be either `1`, |
|
|
|
or the same as the corresponding inputs dimension). |
|
|
|
and must be broadcastable to inputs (i.e., all dimensions must be either `1`, |
|
|
|
or the same as the corresponding inputs dimension). |
|
|
|
""" |
|
|
|
input_dtype = x.dtype |
|
|
|
x = self.cast(x, mstype.float32) |
|
|
|
@@ -1282,10 +1282,10 @@ class FocalLoss(LossBase): |
|
|
|
convert_weight = self.squeeze(convert_weight) |
|
|
|
log_probability = log_probability * convert_weight |
|
|
|
|
|
|
|
weight = F.pows(-probability + 1.0, self.gamma) |
|
|
|
weight = F.pows(-1 * probability + 1.0, self.gamma) |
|
|
|
if target.shape[1] == 1: |
|
|
|
loss = (-weight * log_probability).mean(axis=1) |
|
|
|
loss = (-1 * weight * log_probability).mean(axis=1) |
|
|
|
else: |
|
|
|
loss = (-weight * targets * log_probability).mean(axis=-1) |
|
|
|
loss = (-1 * weight * targets * log_probability).mean(axis=-1) |
|
|
|
|
|
|
|
return self.get_loss(loss) |