Browse Source

!9441 modify example

From: @lijiaqi0612
Reviewed-by: @c_34,@zh_qh
Signed-off-by: @c_34
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
bfc6cca4f8
1 changed files with 6 additions and 2 deletions
  1. +6
    -2
      mindspore/nn/wrap/cell_wrapper.py

+ 6
- 2
mindspore/nn/wrap/cell_wrapper.py View File

@@ -205,9 +205,9 @@ class TrainOneStepCell(Cell):
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
>>>
>>> #2) Using user-defined WithLossCell
>>>class MyWithLossCell(nn.cell):
>>>class MyWithLossCell(nn.Cell):
>>> def __init__(self, backbone, loss_fn):
>>> super(WithLossCell, self).__init__(auto_prefix=False)
>>> super(MyWithLossCell, self).__init__(auto_prefix=False)
>>> self._backbone = backbone
>>> self._loss_fn = loss_fn
>>>
@@ -215,6 +215,10 @@ class TrainOneStepCell(Cell):
>>> out = self._backbone(x, y)
>>> return self._loss_fn(out, label)
>>>
>>> @property
>>> def backbone_network(self):
>>> return self._backbone
>>>
>>> loss_net = MyWithLossCell(net, loss_fn)
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
"""


Loading…
Cancel
Save