Browse Source

!7197 modify loss_scale example

Merge pull request !7197 from lijiaqi/modify_eg
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
94b8277d2f
1 changed files with 6 additions and 1 deletions
  1. +6
    -1
      mindspore/nn/wrap/loss_scale.py

+ 6
- 1
mindspore/nn/wrap/loss_scale.py View File

@@ -193,16 +193,21 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
- **loss scaling value** (Tensor) - Tensor with shape :math:`()` - **loss scaling value** (Tensor) - Tensor with shape :math:`()`


Examples: Examples:
>>> #1) when the type scale_sense is Cell:
>>> net_with_loss = Net() >>> net_with_loss = Net()
>>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9) >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000) >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager) >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
>>> train_network.set_train() >>> train_network.set_train()
>>> >>>
>>> #2) when the type scale_sense is Tensor:
>>> net_with_loss = Net()
>>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> inputs = Tensor(np.ones([16, 16]).astype(np.float32)) >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
>>> label = Tensor(np.zeros([16, 16]).astype(np.float32)) >>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
>>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32) >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
>>> output = train_network(inputs, label, scaling_sens)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
>>> output = train_network(inputs, label)
""" """
def __init__(self, network, optimizer, scale_sense): def __init__(self, network, optimizer, scale_sense):
super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None) super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)


Loading…
Cancel
Save