You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.AttentionMask.rst 1.0 kB

12345678910111213141516171819202122
  1. .. py:class:: mindspore.nn.AttentionMask(seq_length, parallel_config=default_dpmp_config)
  2. 从输入掩码中获取下三角矩阵。输入掩码是值为1或0的二维Tensor (batch_size, seq_length)。1表示当前位置是一个有效的标记,其他值则表示当前位置不是一个有效的标记。
  3. **参数:**
  4. - **seq_length** (int) - 表示输入Tensor的序列长度。
  5. - **parallel_config** (OpParallelConfig) - 表示并行配置。默认值为 `default_dpmp_config` ,表示一个带有默认参数的 `OpParallelConfig` 实例。
  6. **输入:**
  7. - **input_mask** (Tensor) - 掩码矩阵,shape为(batch_size, seq_length),表示每个位置是否为有效输入。
  8. **输出:**
  9. Tensor,表示shape为(batch_size, seq_length, seq_length)的注意力掩码矩阵。
  10. **异常:**
  11. - **TypeError** - `seq_length` 不是整数。
  12. - **ValueError** - `seq_length` 不是正数。
  13. - **TypeError** - `parallel_config` 不是OpParallelConfig的子类。