Browse Source

add comments

tags/v0.7.0-beta
李嘉琪 5 years ago
parent
commit
850c637794
1 changed files with 15 additions and 0 deletions
  1. +15
    -0
      mindspore/nn/wrap/cell_wrapper.py

+ 15
- 0
mindspore/nn/wrap/cell_wrapper.py View File

@@ -157,8 +157,23 @@ class TrainOneStepCell(Cell):
>>> net = Net() >>> net = Net()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits() >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> #1) Using the WithLossCell existing provide
>>> loss_net = nn.WithLossCell(net, loss_fn) >>> loss_net = nn.WithLossCell(net, loss_fn)
>>> train_net = nn.TrainOneStepCell(loss_net, optim) >>> train_net = nn.TrainOneStepCell(loss_net, optim)
>>>
>>> #2) Using user-defined WithLossCell
>>>class MyWithLossCell(nn.cell):
>>> def __init__(self, backbone, loss_fn):
>>> super(WithLossCell, self).__init__(auto_prefix=False)
>>> self._backbone = backbone
>>> self._loss_fn = loss_fn
>>>
>>> def construct(self, x, y, label):
>>> out = self._backbone(x, y)
>>> return self._loss_fn(out, label)
>>>
>>> loss_net = MyWithLossCell(net, loss_fn)
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
""" """
def __init__(self, network, optimizer, sens=1.0): def __init__(self, network, optimizer, sens=1.0):
super(TrainOneStepCell, self).__init__(auto_prefix=False) super(TrainOneStepCell, self).__init__(auto_prefix=False)


Loading…
Cancel
Save