diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index c1fb36635c..87b4037e8d 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -205,9 +205,9 @@ class TrainOneStepCell(Cell): >>> train_net = nn.TrainOneStepCell(loss_net, optim) >>> >>> #2) Using user-defined WithLossCell - >>>class MyWithLossCell(nn.cell): + >>>class MyWithLossCell(nn.Cell): >>> def __init__(self, backbone, loss_fn): - >>> super(WithLossCell, self).__init__(auto_prefix=False) + >>> super(MyWithLossCell, self).__init__(auto_prefix=False) >>> self._backbone = backbone >>> self._loss_fn = loss_fn >>> @@ -215,6 +215,10 @@ class TrainOneStepCell(Cell): >>> out = self._backbone(x, y) >>> return self._loss_fn(out, label) >>> + >>> @property + >>> def backbone_network(self): + >>> return self._backbone + >>> >>> loss_net = MyWithLossCell(net, loss_fn) >>> train_net = nn.TrainOneStepCell(loss_net, optim) """