You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.RNNCell.rst 1.5 kB

1234567891011121314151617181920212223242526272829303132333435
  1. mindspore.nn.RNNCell
  2. =====================
  3. .. py:class:: mindspore.nn.RNNCell(input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = 'tanh')
  4. 循环神经网络单元,激活函数是tanh或relu。
  5. .. math::
  6. h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})
  7. 其中 :math:`h_t` 是在 `t` 时刻的隐藏状态, :math:`x_t` 是在 `t` 时刻的输入, :math:`h_{(t-1)}` 是在 :math:`t-1` 时刻的隐藏状态,或初始隐藏状态。
  8. 如果 `nonlinearity` 是'relu',则使用'relu'而不是'tanh'。
  9. **参数:**
  10. - **input_size** (int) - 输入层输入的特征向量维度。
  11. - **hidden_size** (int) - 隐藏层输出的特征向量维度。
  12. - **has_bias** (bool) - Cell是否有偏置项 `b_ih` 和 `b_hh` 。默认值:True。
  13. - **nonlinearity** (str) - 用于选择非线性激活函数。取值可以是'tanh'或'relu'。默认值:'tanh'。
  14. **输入:**
  15. - **x** (Tensor) - 输入Tensor,其shape为 :math:`(batch\_size, input\_size)` 。
  16. - **hx** (Tensor) - 输入Tensor,其数据类型为mindspore.float32及shape为 :math:`(batch\_size, hidden\_size)` 。 `hx` 的数据类型与 `x` 相同。
  17. **输出:**
  18. - **hx'** (Tensor) - shape为 :math:`(batch\_size, hidden\_size)` 的Tensor。
  19. **异常:**
  20. - **TypeError** - `input_size` 或 `hidden_size` 不是int或不大于0。
  21. - **TypeError** - `has_bias` 不是bool。
  22. - **ValueError** - `nonlinearity` 不在['tanh', 'relu']中。