From 850c637794c71ff3739eaad59f917a03cf1a0143 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=98=89=E7=90=AA?= Date: Mon, 17 Aug 2020 17:27:27 +0800 Subject: [PATCH] add comments --- mindspore/nn/wrap/cell_wrapper.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index 9a3539373e..4e989a56b2 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -157,8 +157,23 @@ class TrainOneStepCell(Cell): >>> net = Net() >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits() >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + >>> #1) Using the WithLossCell provided by MindSpore >>> loss_net = nn.WithLossCell(net, loss_fn) >>> train_net = nn.TrainOneStepCell(loss_net, optim) + >>> + >>> #2) Using a user-defined WithLossCell + >>> class MyWithLossCell(nn.Cell): + >>> def __init__(self, backbone, loss_fn): + >>> super(MyWithLossCell, self).__init__(auto_prefix=False) + >>> self._backbone = backbone + >>> self._loss_fn = loss_fn + >>> + >>> def construct(self, x, y, label): + >>> out = self._backbone(x, y) + >>> return self._loss_fn(out, label) + >>> + >>> loss_net = MyWithLossCell(net, loss_fn) + >>> train_net = nn.TrainOneStepCell(loss_net, optim) """ def __init__(self, network, optimizer, sens=1.0): super(TrainOneStepCell, self).__init__(auto_prefix=False)