|
|
|
@@ -205,9 +205,9 @@ class TrainOneStepCell(Cell): |
|
|
|
>>> train_net = nn.TrainOneStepCell(loss_net, optim) |
|
|
|
>>> |
|
|
|
>>> #2) Using user-defined WithLossCell |
|
|
|
>>>class MyWithLossCell(nn.cell): |
|
|
|
>>>class MyWithLossCell(nn.Cell): |
|
|
|
>>> def __init__(self, backbone, loss_fn): |
|
|
|
>>> super(WithLossCell, self).__init__(auto_prefix=False) |
|
|
|
>>> super(MyWithLossCell, self).__init__(auto_prefix=False) |
|
|
|
>>> self._backbone = backbone |
|
|
|
>>> self._loss_fn = loss_fn |
|
|
|
>>> |
|
|
|
@@ -215,6 +215,10 @@ class TrainOneStepCell(Cell): |
|
|
|
>>> out = self._backbone(x, y) |
|
|
|
>>> return self._loss_fn(out, label) |
|
|
|
>>> |
|
|
|
>>> @property |
|
|
|
>>> def backbone_network(self): |
|
|
|
>>> return self._backbone |
|
|
|
>>> |
|
|
|
>>> loss_net = MyWithLossCell(net, loss_fn) |
|
|
|
>>> train_net = nn.TrainOneStepCell(loss_net, optim) |
|
|
|
""" |
|
|
|
|