You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.GRUCell.rst 1.6 kB

12345678910111213141516171819202122232425262728293031323334353637
  1. mindspore.nn.GRUCell
  2. =====================
  3. .. py:class:: mindspore.nn.GRUCell(input_size: int, hidden_size: int, has_bias: bool = True)
  4. GRU(Gate Recurrent Unit)称为门控循环单元。
  5. .. math::
  6. \begin{array}{ll}
  7. r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
  8. z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
  9. n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
  10. h' = (1 - z) * n + z * h
  11. \end{array}
  12. 这里 :math:`\sigma` 是sigmoid激活函数, :math:`*` 是乘积。 :math:`W,b` 是公式中输出和输入之间的可学习权重。例如, :math:`W_{ir}, b_{ir}` 是用于将输入 :math:`x` 转换为 :math:`r` 的权重和偏置。详见论文 `Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation <https://aclanthology.org/D14-1179.pdf>`_ 。
  13. **参数:**
  14. - **input_size** (int) - 输入的大小。
  15. - **hidden_size** (int) - 隐藏状态大小。
  16. - **has_bias** (bool) - cell是否有偏置项 `b_ih` 和 `b_hh` 。默认值:True。
  17. **输入:**
  18. - **x** (Tensor) - shape为(batch_size, `input_size` )的Tensor。
  19. - **hx** (Tensor) - 数据类型为mindspore.float32、shape为(batch_size, `hidden_size` )的Tensor。 `hx` 的数据类型必须与 `x` 相同。
  20. **输出:**
  21. - **hx** (Tensor) - shape为(batch_size, `hidden_size`)的Tensor。
  22. **异常:**
  23. - **TypeError** - `input_size` 、 `hidden_size` 不是int。
  24. - **TypeError** - `has_bias` 不是bool值。