You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.TrainOneStepCell.rst 898 B

4 years ago
4 years ago
123456789101112131415161718192021222324252627
  1. mindspore.nn.TrainOneStepCell
  2. =============================
  3. .. py:class:: mindspore.nn.TrainOneStepCell(network, optimizer, sens=1.0)
  4. 训练网络封装类。
  5. 封装 `network` 和 `optimizer` ,构建一个输入'\*inputs'的用于训练的Cell。
  6. 执行函数 `construct` 中会构建反向图以更新网络参数。支持不同的并行训练模式。
  7. **参数:**
  8. - **network** (Cell) - 训练网络。只支持单输出网络。
  9. - **optimizer** (Union[Cell]) - 用于更新网络参数的优化器。
  10. - **sens** (numbers.Number) - 反向传播的输入,缩放系数。默认值为1.0。
  11. **输入:**
  12. **(\*inputs)** (Tuple(Tensor)) - shape为 :math:`(N, \ldots)` 的Tensor组成的元组。
  13. **输出:**
  14. Tensor,损失函数值,其shape通常为 :math:`()` 。
  15. **异常:**
  16. **TypeError**:`sens` 不是numbers.Number。