You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.WithEvalCell.rst 1.1 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536
  1. mindspore.nn.WithEvalCell
  2. =========================
  3. .. py:class:: mindspore.nn.WithEvalCell(network, loss_fn, add_cast_fp32=False)
  4. 封装前向网络和损失函数,返回用于计算评估指标的损失函数值、前向输出和标签。
  5. **参数:**
  6. - **network** (Cell) - 前向网络。
  7. - **loss_fn** (Cell) - 损失函数。
  8. - **add_cast_fp32** (bool):是否将数据类型调整为float32。默认值:False。
  9. **输入:**
  10. - **data** (Tensor) - shape为 :math:`(N, \ldots)` 的Tensor。
  11. - **label** (Tensor) - shape为 :math:`(N, \ldots)` 的Tensor。
  12. **输出:**
  13. Tuple(Tensor),包括标量损失函数、shape为 :math:`(N, \ldots)` 的网络输出和shape为 :math:`(N, \ldots)` 的标签。
  14. **异常:**
  15. **TypeError**: `add_cast_fp32` 不是bool。
  16. **支持平台:**
  17. ``Ascend`` ``GPU`` ``CPU``
  18. **样例:**
  19. >>> # 未包含损失函数的前向网络
  20. >>> net = Net()
  21. >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
  22. >>> eval_net = nn.WithEvalCell(net, loss_fn)