From 9aed617d688f60bb92581b2fac3e668ad1dc02f1 Mon Sep 17 00:00:00 2001
From: liuxiao93
Date: Wed, 10 Mar 2021 17:24:14 +0800
Subject: [PATCH] Fix bug of API.

---
 mindspore/nn/loss/loss.py          |  2 +-
 mindspore/ops/operations/nn_ops.py | 19 +++++++++++++++++--
 2 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py
index c1ccbccc0d..7c8913bf85 100644
--- a/mindspore/nn/loss/loss.py
+++ b/mindspore/nn/loss/loss.py
@@ -973,7 +973,7 @@ class BCEWithLogitsLoss(_Loss):
         >>> predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
         >>> target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32))
         >>> loss = nn.BCEWithLogitsLoss()
-        >>> output = loss(inputs, labels)
+        >>> output = loss(predict, target)
         >>> print(output)
         0.3463612
     """
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index c6a4ed40a9..f3ec6f3c1c 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1885,7 +1885,7 @@ class MaxPool3D(PrimitiveWithInfer):
         ``Ascend``
 
     Examples:
-        >>> input = Tensor(np.arange(1 * 3 * 3 * 4).reshape((1, 3, 3, 4)), mindspore.float32)
+        >>> input = Tensor(np.arange(1 * 2 * 2 * 2 * 3).reshape((1, 2, 2, 2, 3)), mindspore.float32)
         >>> max_pool3d = ops.MaxPool3D(kernel_size=2, strides=1, pad_mode="valid")
         >>> output = max_pool3d(input)
         >>> print(output)
@@ -7094,7 +7094,22 @@ class BasicLSTMCell(PrimitiveWithInfer):
 
 class DynamicRNN(PrimitiveWithInfer):
     r"""
-    DynamicRNN Operator.
+    Applies a recurrent neural network to the input.
+    Only long short-term memory (LSTM) currently supported.
+
+    .. math::
+        \begin{array}{ll} \\
+            i_t = \sigma(W_{ix} x_t + b_{ix} + W_{ih} h_{(t-1)} + b_{ih}) \\
+            f_t = \sigma(W_{fx} x_t + b_{fx} + W_{fh} h_{(t-1)} + b_{fh}) \\
+            \tilde{c}_t = \tanh(W_{cx} x_t + b_{cx} + W_{ch} h_{(t-1)} + b_{ch}) \\
+            o_t = \sigma(W_{ox} x_t + b_{ox} + W_{oh} h_{(t-1)} + b_{oh}) \\
+            c_t = f_t * c_{(t-1)} + i_t * \tilde{c}_t \\
+            h_t = o_t * \tanh(c_t) \\
+        \end{array}
+
+    Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b`
+    are learnable weights between the output and the input in the formula. For instance,
+    :math:`W_{ix}, b_{ix}` are the weight and bias used to transform from input :math:`x` to :math:`i`.
 
     Args:
         cell_type (str): A string identifying the cell type in the op. Default: 'LSTM'.