|
|
|
@@ -276,7 +276,7 @@ class GetNextSingleOp(Cell): |
|
|
|
>>> relu = P.ReLU() |
|
|
|
>>> result = relu(data).asnumpy() |
|
|
|
>>> print(result.shape) |
|
|
|
>>> (32, 1, 32, 32) |
|
|
|
(32, 1, 32, 32) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, dataset_types, dataset_shapes, queue_name): |
|
|
|
@@ -356,6 +356,7 @@ class WithEvalCell(Cell): |
|
|
|
Args: |
|
|
|
network (Cell): The network Cell. |
|
|
|
loss_fn (Cell): The loss Cell. |
|
|
|
add_cast_fp32 (bool): Adjust the data type to float32. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`. |
|
|
|
@@ -410,7 +411,7 @@ class ParameterUpdate(Cell): |
|
|
|
>>> param = network.parameters_dict()['weight'] |
|
|
|
>>> update = nn.ParameterUpdate(param) |
|
|
|
>>> update.phase = "update_param" |
|
|
|
>>> weight = Tensor(np.arrange(12).reshape((4, 3)), mindspore.float32) |
|
|
|
>>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32) |
|
|
|
>>> network_updata = update(weight) |
|
|
|
""" |
|
|
|
|
|
|
|
|