diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py
index 9389c021f9..490ccb3f90 100644
--- a/mindspore/nn/wrap/loss_scale.py
+++ b/mindspore/nn/wrap/loss_scale.py
@@ -193,16 +193,21 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
         - **loss scaling value** (Tensor) - Tensor with shape :math:`()`
 
     Examples:
+        >>> # 1) When the type of scale_sense is Cell:
         >>> net_with_loss = Net()
         >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
         >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
         >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
         >>> train_network.set_train()
         >>>
+        >>> # 2) When the type of scale_sense is Tensor:
+        >>> net_with_loss = Net()
+        >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
         >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
         >>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
         >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
-        >>> output = train_network(inputs, label, scaling_sens)
+        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
+        >>> output = train_network(inputs, label)
     """
     def __init__(self, network, optimizer, scale_sense):
         super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
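
For readers who want to try both modes end to end, the sketch below fills in the pieces the docstring leaves abstract. The one-layer Net, the nn.WithLossCell/nn.MSELoss wrapping, and the 16x16 shapes are illustrative assumptions, not part of the patch; only the TrainOneStepWithLossScaleCell calls mirror the revised docstring. Note the point of the change: scale_sense now goes to the constructor, so the per-step call no longer carries the scaling tensor.

# Runnable sketch of both scale_sense modes (toy Net and shapes are assumptions).
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor


class Net(nn.Cell):
    """Stand-in for the docstring's Net(): a single Dense layer (assumption)."""
    def __init__(self):
        super(Net, self).__init__()
        self.dense = nn.Dense(16, 16)

    def construct(self, x):
        return self.dense(x)


inputs = Tensor(np.ones([16, 16]).astype(np.float32))
label = Tensor(np.zeros([16, 16]).astype(np.float32))

# 1) scale_sense as a Cell: the update cell adjusts the loss scale dynamically.
net_with_loss = nn.WithLossCell(Net(), nn.MSELoss())
optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
train_network.set_train()
output = train_network(inputs, label)  # (loss, overflow flag, loss scaling value)

# 2) scale_sense as a Tensor: a fixed scale, supplied once at construction.
#    The scale value here mirrors the docstring's example.
net_with_loss = nn.WithLossCell(Net(), nn.MSELoss())
optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
scaling_sens = Tensor(np.full((1,), np.finfo(np.float32).max), dtype=mindspore.float32)
train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
train_network.set_train()
output = train_network(inputs, label)  # no per-step scaling_sens argument anymore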