You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.TrainOneStepCell.rst 1.9 kB

4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. mindspore.nn.TrainOneStepCell
  2. =============================
  3. .. py:class:: mindspore.nn.TrainOneStepCell(network, optimizer, sens=1.0)
  4. 训练网络封装类。
  5. 封装 `network` 和 `optimizer` ,构建一个输入'\*inputs'的用于训练的Cell。
  6. 执行函数 `construct` 中会构建反向图以更新网络参数。支持不同的并行训练模式。
  7. **参数:**
  8. - **network** (Cell) - 训练网络。只支持单输出网络。
  9. - **optimizer** (Union[Cell]) - 用于更新网络参数的优化器。
  10. - **sens** (numbers.Number) - 反向传播的输入,缩放系数。默认值为1.0。
  11. **输入:**
  12. **(\*inputs)** (Tuple(Tensor)) - shape为 :math:`(N, \ldots)` 的Tensor组成的元组。
  13. **输出:**
  14. Tensor,损失函数值,其shape通常为 :math:`()` 。
  15. **异常:**
  16. **TypeError**:`sens` 不是numbers.Number。
  17. **支持平台:**
  18. ``Ascend`` ``GPU`` ``CPU``
  19. **样例:**
  20. >>> net = Net()
  21. >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
  22. >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  23. >>> # 1)使用MindSpore提供的WithLossCell
  24. >>> loss_net = nn.WithLossCell(net, loss_fn)
  25. >>> train_net = nn.TrainOneStepCell(loss_net, optim)
  26. >>>
  27. >>> # 2)用户自定义的WithLossCell
  28. >>> class MyWithLossCell(Cell):
  29. ... def __init__(self, backbone, loss_fn):
  30. ... super(MyWithLossCell, self).__init__(auto_prefix=False)
  31. ... self._backbone = backbone
  32. ... self._loss_fn = loss_fn
  33. ...
  34. ... def construct(self, x, y, label):
  35. ... out = self._backbone(x, y)
  36. ... return self._loss_fn(out, label)
  37. ...
  38. ... @property
  39. ... def backbone_network(self):
  40. ... return self._backbone
  41. ...
  42. >>> loss_net = MyWithLossCell(net, loss_fn)
  43. >>> train_net = nn.TrainOneStepCell(loss_net, optim)