diff --git a/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py b/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py index 06f6a01370..073bd80985 100644 --- a/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +++ b/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py @@ -43,8 +43,7 @@ class WithBNNLossCell(Cell): Examples: >>> net = Net() >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False) - >>> net_with_criterion_object = WithBNNLossCell(net, loss_fn) - >>> net_with_criterion = net_with_criterion_object() + >>> net_with_criterion = WithBNNLossCell(net, loss_fn) >>> >>> batch_size = 2 >>> data = Tensor(np.ones([batch_size, 3, 64, 64]).astype(np.float32) * 0.01) diff --git a/mindspore/nn/probability/transforms/transform_bnn.py b/mindspore/nn/probability/transforms/transform_bnn.py index 0d1a08de08..eca6aec1b2 100644 --- a/mindspore/nn/probability/transforms/transform_bnn.py +++ b/mindspore/nn/probability/transforms/transform_bnn.py @@ -58,7 +58,7 @@ class TransformToBNN: >>> net = Net() >>> criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True) >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) - >>> net_with_loss = WithLossCell(network, criterion) + >>> net_with_loss = WithLossCell(net, criterion) >>> train_network = TrainOneStepCell(net_with_loss, optim) >>> bnn_transformer = TransformToBNN(train_network, 60000, 0.0001) """ @@ -115,7 +115,7 @@ class TransformToBNN: >>> net = Net() >>> criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True) >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) - >>> net_with_loss = WithLossCell(network, criterion) + >>> net_with_loss = WithLossCell(net, criterion) >>> train_network = TrainOneStepCell(net_with_loss, optim) >>> bnn_transformer = TransformToBNN(train_network, 60000, 0.1) >>> train_bnn_network = bnn_transformer.transform_to_bnn_model() @@ -160,7 +160,7 @@ class TransformToBNN: >>> net = Net() >>> criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True) >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) - >>> net_with_loss = WithLossCell(network, criterion) + >>> net_with_loss = WithLossCell(net, criterion) >>> train_network = TrainOneStepCell(net_with_loss, optim) >>> bnn_transformer = TransformToBNN(train_network, 60000, 0.1) >>> train_bnn_network = bnn_transformer.transform_to_bnn_layer(Dense, DenseReparam)