|
|
|
@@ -46,8 +46,8 @@ class WithBNNLossCell(Cell): |
|
|
|
>>> net_with_criterion = WithBNNLossCell(net, loss_fn) |
|
|
|
>>> |
|
|
|
>>> batch_size = 2 |
|
|
|
>>> data = Tensor(np.ones([batch_size, 3, 64, 64]).astype(np.float32) * 0.01) |
|
|
|
>>> label = Tensor(np.ones([batch_size, 1, 1, 1]).astype(np.int32)) |
|
|
|
>>> data = Tensor(np.ones([batch_size, 16]).astype(np.float32) * 0.01) |
|
|
|
>>> label = Tensor(np.ones([batch_size, 1]).astype(np.float32)) |
|
|
|
>>> |
|
|
|
>>> net_with_criterion(data, label) |
|
|
|
""" |
|
|
|
|