
mindspore.nn.MSELoss.txt

Class mindspore.nn.MSELoss(reduction='mean')
MSELoss measures the mean squared error between each element of :math:`x` and :math:`y`, where :math:`x` is the input Tensor and :math:`y` is the label Tensor.
Assume :math:`x` and :math:`y` are one-dimensional Tensors of length :math:`N`. The unreduced loss of :math:`x` and :math:`y` (i.e., with the reduction argument set to "none") is computed as:
.. math::
    \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad \text{with} \quad l_n = (x_n - y_n)^2.
where :math:`N` is the batch size. If `reduction` is not "none", then:
.. math::
    \ell(x, y) =
    \begin{cases}
        \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
        \operatorname{sum}(L),  & \text{if reduction} = \text{`sum'.}
    \end{cases}
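The two formulas above can be checked with a plain NumPy sketch (NumPy is used here instead of MindSpore so the arithmetic is explicit; the function name `mse_loss` is illustrative, not part of the API):

```python
import numpy as np

def mse_loss(x, y, reduction="mean"):
    # Element-wise squared error: l_n = (x_n - y_n)^2
    l = (x - y) ** 2
    if reduction == "none":
        return l           # unreduced loss L
    if reduction == "mean":
        return l.mean()    # mean over all elements of L
    if reduction == "sum":
        return l.sum()     # sum over all elements of L
    raise ValueError("reduction must be 'mean', 'sum', or 'none'")

x = np.array([1.0, 2.0, 3.0])
y = np.array([1.0, 1.0, 1.0])
print(mse_loss(x, y, reduction="none"))  # element-wise: [0. 1. 4.]
print(mse_loss(x, y, reduction="mean"))  # mean of [0, 1, 4], about 1.6667
print(mse_loss(x, y, reduction="sum"))   # sum of [0, 1, 4]
```

Because `x - y` follows NumPy broadcasting, the same sketch also reproduces the broadcast behavior described for `logits` and `labels` below.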
Parameters:
    reduction (str): Type of reduction applied to the loss. Must be "mean", "sum", or "none".
    Default: "mean".
Inputs:
    - **logits** (Tensor): Tensor of shape :math:`(N, *)`, where :math:`*` means any number of additional dimensions.
    - **labels** (Tensor): Tensor of shape :math:`(N, *)`, which in common cases has the same shape as `logits`.
      If the shapes of `logits` and `labels` differ, they must be broadcastable to each other.
Outputs:
    Tensor, the loss as a float Tensor. If `reduction` is "mean" or "sum", the output is a scalar (zero-dimensional) Tensor; if `reduction` is "none", the output has the shape obtained by broadcasting the input Tensors against each other.
Raises:
    ValueError: If `reduction` is not "mean", "sum", or "none".
Supported Platforms:
    ``Ascend`` ``GPU`` ``CPU``
Examples:
    >>> import numpy as np
    >>> import mindspore
    >>> import mindspore.nn as nn
    >>> from mindspore import Tensor
    >>> # Case 1: logits.shape = labels.shape = (3,)
    >>> loss = nn.MSELoss()
    >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
    >>> labels = Tensor(np.array([1, 1, 1]), mindspore.float32)
    >>> output = loss(logits, labels)
    >>> print(output)
    1.6666667
    >>> # Case 2: logits.shape = (3,), labels.shape = (2, 3)
    >>> loss = nn.MSELoss(reduction='none')
    >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
    >>> labels = Tensor(np.array([[1, 1, 1], [1, 2, 2]]), mindspore.float32)
    >>> output = loss(logits, labels)
    >>> print(output)
    [[0. 1. 4.]
     [0. 0. 1.]]