From 006879fbb8034767c6a7776673748decb58632a1 Mon Sep 17 00:00:00 2001 From: Jiaqi Date: Thu, 3 Dec 2020 17:23:13 +0800 Subject: [PATCH] modify example --- mindspore/nn/wrap/cell_wrapper.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index f283dd834d..437d77c02d 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -162,9 +162,9 @@ class TrainOneStepCell(Cell): >>> train_net = nn.TrainOneStepCell(loss_net, optim) >>> >>> #2) Using user-defined WithLossCell - >>>class MyWithLossCell(nn.cell): + >>>class MyWithLossCell(nn.Cell): >>> def __init__(self, backbone, loss_fn): - >>> super(WithLossCell, self).__init__(auto_prefix=False) + >>> super(MyWithLossCell, self).__init__(auto_prefix=False) >>> self._backbone = backbone >>> self._loss_fn = loss_fn >>> @@ -172,6 +172,10 @@ class TrainOneStepCell(Cell): >>> out = self._backbone(x, y) >>> return self._loss_fn(out, label) >>> + >>> @property + >>> def backbone_network(self): + >>> return self._backbone + >>> >>> loss_net = MyWithLossCell(net, loss_fn) >>> train_net = nn.TrainOneStepCell(loss_net, optim) """