| @@ -157,8 +157,23 @@ class TrainOneStepCell(Cell): | |||||
| >>> net = Net() | >>> net = Net() | ||||
| >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits() | >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits() | ||||
| >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | ||||
| >>> #1) Using the WithLossCell existing provide | |||||
| >>> loss_net = nn.WithLossCell(net, loss_fn) | >>> loss_net = nn.WithLossCell(net, loss_fn) | ||||
| >>> train_net = nn.TrainOneStepCell(loss_net, optim) | >>> train_net = nn.TrainOneStepCell(loss_net, optim) | ||||
| >>> | |||||
| >>> #2) Using user-defined WithLossCell | |||||
| >>>class MyWithLossCell(nn.cell): | |||||
| >>> def __init__(self, backbone, loss_fn): | |||||
| >>> super(WithLossCell, self).__init__(auto_prefix=False) | |||||
| >>> self._backbone = backbone | |||||
| >>> self._loss_fn = loss_fn | |||||
| >>> | |||||
| >>> def construct(self, x, y, label): | |||||
| >>> out = self._backbone(x, y) | |||||
| >>> return self._loss_fn(out, label) | |||||
| >>> | |||||
| >>> loss_net = MyWithLossCell(net, loss_fn) | |||||
| >>> train_net = nn.TrainOneStepCell(loss_net, optim) | |||||
| """ | """ | ||||
| def __init__(self, network, optimizer, sens=1.0): | def __init__(self, network, optimizer, sens=1.0): | ||||
| super(TrainOneStepCell, self).__init__(auto_prefix=False) | super(TrainOneStepCell, self).__init__(auto_prefix=False) | ||||