Browse Source

!21856 Fix comments in LossBase.get_loss()

Merge pull request !21856 from chenhaozhe/code_docs_change_lossBase_getloss_1.4
r1.4
i-robot Gitee 4 years ago
parent
commit
f40c33ff18
1 changed files with 5 additions and 5 deletions
  1. +5
    -5
      mindspore/nn/loss/loss.py

+ 5
- 5
mindspore/nn/loss/loss.py View File

@@ -76,8 +76,8 @@ class LossBase(Cell):

Args:
weights (Union[float, Tensor]): Optional `Tensor` whose rank is either 0, or the same rank as inputs,
and must be broadcastable to inputs (i.e., all dimensions must be either `1`,
or the same as the corresponding inputs dimension).
and must be broadcastable to inputs (i.e., all dimensions must be either `1`,
or the same as the corresponding inputs dimension).
"""
input_dtype = x.dtype
x = self.cast(x, mstype.float32)
@@ -1282,10 +1282,10 @@ class FocalLoss(LossBase):
convert_weight = self.squeeze(convert_weight)
log_probability = log_probability * convert_weight

weight = F.pows(-probability + 1.0, self.gamma)
weight = F.pows(-1 * probability + 1.0, self.gamma)
if target.shape[1] == 1:
loss = (-weight * log_probability).mean(axis=1)
loss = (-1 * weight * log_probability).mean(axis=1)
else:
loss = (-weight * targets * log_probability).mean(axis=-1)
loss = (-1 * weight * targets * log_probability).mean(axis=-1)

return self.get_loss(loss)

Loading…
Cancel
Save