You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.TrainOneStepWithLossScaleCell.rst 6.2 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. mindspore.nn.TrainOneStepWithLossScaleCell
  2. ==========================================
  3. .. py:class:: mindspore.nn.TrainOneStepWithLossScaleCell(network, optimizer, scale_sense)
  4. 使用梯度放大功能(loss scale)的训练网络。
  5. 实现了包含梯度放大功能的单次训练。它使用网络、优化器和用于更新梯度放大系数的Cell(或一个Tensor)作为参数。可在host侧或device侧更新梯度放大系数。
  6. 如果需要在host侧更新,使用Tensor作为 `scale_sense` ,否则,使用可更新梯度放大系数的Cell实例作为 `scale_sense` 。
  7. **参数:**
  8. - **network** (Cell) - 训练网络。仅支持单输出网络。
  9. - **optimizer** (Cell) - 用于更新网络参数的优化器。
  10. - **scale_sense** (Union[Tensor, Cell]) - 如果此值为Cell类型,`TrainOneStepWithLossScaleCell` 会调用它来更新梯度放大系数。如果此值为Tensor类型,可调用 `set_sense_scale` 来更新梯度放大系数,shape为 :math:`()` 或 :math:`(1,)` 。
  11. **输入:**
  12. **(*inputs)** (Tuple(Tensor))- shape为 :math:`(N, \ldots)` 的Tensor组成的元组。
  13. **输出:**
  14. Tuple,包含三个Tensor,分别为损失函数值、溢出状态和当前梯度放大系数。
  15. - **loss** (Tensor) - shape为 :math:`()` 的Tensor。
  16. - **overflow** (Tensor)- shape为 :math:`()` 的Tensor,类型为bool。
  17. - **loss scale** (Tensor)- shape为 :math:`()` 的Tensor。
  18. **异常:**
  19. - **TypeError** - `scale_sense` 既不是Cell,也不是Tensor。
  20. - **ValueError** - `scale_sense` 的shape既不是(1,)也不是()。
  21. **支持平台:**
  22. ``Ascend`` ``GPU``
  23. **样例:**
  24. >>> import numpy as np
  25. >>> from mindspore import Tensor, Parameter, nn, ops
  26. >>> from mindspore import dtype as mstype
  27. >>>
  28. >>> class Net(nn.Cell):
  29. ... def __init__(self, in_features, out_features):
  30. ... super(Net, self).__init__()
  31. ... self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
  32. ... name='weight')
  33. ... self.matmul = ops.MatMul()
  34. ...
  35. ... def construct(self, x):
  36. ... output = self.matmul(x, self.weight)
  37. ... return output
  38. ...
  39. >>> size, in_features, out_features = 16, 16, 10
  40. >>> #1)scale_sense类型为Cell时:
  41. >>> net = Net(in_features, out_features)
  42. >>> loss = nn.MSELoss()
  43. >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  44. >>> net_with_loss = nn.WithLossCell(net, loss)
  45. >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
  46. >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
  47. >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
  48. >>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
  49. >>> output = train_network(input, labels)
  50. >>>
  51. >>>> #2)当scale_sense类型为Tensor时:
  52. >>> net = Net(in_features, out_features)
  53. >>> loss = nn.MSELoss()
  54. >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  55. >>> net_with_loss = nn.WithLossCell(net, loss)
  56. >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
  57. >>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
  58. >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
  59. >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
  60. >>> output = train_network(inputs, label)
  61. .. py:method:: get_overflow_status(status, compute_output)
  62. 获取浮点溢出状态。
  63. 溢出检测的目标过程执行完成后,获取溢出结果。继承该类自定义训练网络时,可复用该接口。
  64. **输入:**
  65. - **status** (object) - 用于检测溢出的状态实例。
  66. - **compute_output** - 对特定计算过程进行溢出检测时,将 `compute_output` 设置为该计算过程的输出,以确保在执行计算之前获取了 `status`。
  67. **输出:**
  68. bool,是否发生溢出。
  69. .. py:method:: process_loss_scale(overflow)
  70. 根据溢出状态计算梯度放大系数。继承该类自定义训练网络时,可复用该接口。
  71. **输入:**
  72. - **overflow** (bool) - 是否发生溢出。
  73. **输出:**
  74. bool,溢出状态,即输入。
  75. .. py:method:: set_sense_scale(sens)
  76. 如果使用了Tensor类型的 `scale_sense` ,可调用此函数修改它的值。
  77. **输入:**
  78. - **sens** (Tensor)- 新的梯度放大系数,其shape和类型需要与原始 `scale_sense` 相同。
  79. .. py:method:: start_overflow_check(pre_cond, compute_input)
  80. 启动浮点溢出检测。创建并清除溢出检测状态。
  81. 指定参数 `pre_cond` 和 `compute_input` ,以确保在正确的时间清除溢出状态。以当前接口为例,我们需要在损失函数计算后进行清除状态,在梯度计算过程中检测溢出。在这种情况下,`pre_cond` 应为损失函数的输出,而 `compute_input` 应为梯度计算函数的输入。继承该类自定义训练网络时,可复用该接口。
  82. **输入:**
  83. - **pre_cond** (Tensor) -启动溢出检测的先决条件。它决定溢出状态清除和先前处理的执行顺序。它确保函数 `start_overflow` 在执行完先决条件后清除状态。
  84. - **compute_input** (object) - 后续运算的输入。需要对特定的计算过程进行溢出检测。将 `compute_input` 设置这一计算过程的输入,以确保在执行该计算之前清除了溢出状态。
  85. **输出:**
  86. - **Tuple** [object, object],GPU后端的第一个值为False,而其他后端的第一个值是NPUAllocFloatStatus的实例。该值用于在 `get_overflow_status` 期间检测溢出。第二个值与 `compute_input` 的输入相同,用于控制执行序。