
mindspore.nn.FixedLossScaleUpdateCell.rst 2.1 kB

mindspore.nn.FixedLossScaleUpdateCell
=======================================

.. py:class:: mindspore.nn.FixedLossScaleUpdateCell(loss_scale_value)

    Update cell with a fixed loss scale (gradient scaling coefficient).

    This class is the return value of the `get_update_cell` method of :class:`mindspore.nn.FixedLossScaleManager`. During training, this Cell is invoked by :class:`mindspore.nn.TrainOneStepWithLossScaleCell`.

    **Parameters:**

    - **loss_scale_value** (float) - The initial loss scale value.

    **Inputs:**

    - **loss_scale** (Tensor) - The loss scale value during training, with shape :math:`()`. This value is ignored by this class.
    - **overflow** (bool) - Whether overflow occurred.

    **Outputs:**

    Bool, the input `overflow`, returned unchanged.

    **Supported Platforms:**

    ``Ascend`` ``GPU``

    **Examples:**

    >>> import numpy as np
    >>> import mindspore
    >>> from mindspore import Tensor, Parameter, nn, ops
    >>>
    >>> class Net(nn.Cell):
    ...     def __init__(self, in_features, out_features):
    ...         super(Net, self).__init__()
    ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
    ...                                 name='weight')
    ...         self.matmul = ops.MatMul()
    ...
    ...     def construct(self, x):
    ...         output = self.matmul(x, self.weight)
    ...         return output
    ...
    >>> in_features, out_features = 16, 10
    >>> net = Net(in_features, out_features)
    >>> loss = nn.MSELoss()
    >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    >>> net_with_loss = nn.WithLossCell(net, loss)
    >>> manager = nn.FixedLossScaleUpdateCell(loss_scale_value=2**12)
    >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
    >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
    >>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
    >>> output = train_network(input, labels)

    .. py:method:: get_loss_scale()

        Returns the current loss scale value.
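The behavior documented above (the incoming `loss_scale` is ignored, the `overflow` flag is passed through unchanged, and `get_loss_scale` always returns the initial value) can be sketched in plain Python. This is an illustrative stand-in under those documented semantics, not the MindSpore implementation; the class and method names below mirror the API only for readability:

```python
class FixedLossScaleUpdateSketch:
    """Illustrative stand-in for a fixed-loss-scale update cell:
    the scale never changes, and the update step just forwards
    the overflow flag to the caller."""

    def __init__(self, loss_scale_value):
        # The initial loss scale value; with a fixed manager it is never updated.
        self.loss_scale_value = float(loss_scale_value)

    def __call__(self, loss_scale, overflow):
        # The incoming loss_scale is ignored, per the Inputs section above;
        # the output is the input `overflow`, returned unchanged.
        return overflow

    def get_loss_scale(self):
        # Always returns the (fixed) current loss scale value.
        return self.loss_scale_value


update = FixedLossScaleUpdateSketch(loss_scale_value=2**12)
scale = update.get_loss_scale()      # 4096.0, regardless of training steps
passed_through = update(scale, False)  # False: overflow flag is unchanged
```

Because the scale is constant, callers such as a train-one-step wrapper only use the returned `overflow` flag to decide whether to skip the optimizer update for that step.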