You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.VocabEmbedding.rst 2.0 kB

4 years ago
12345678910111213141516171819202122232425262728293031
  1. .. py:class:: mindspore.nn.transformer.VocabEmbedding(vocab_size, embedding_size, parallel_config=default_embedding_parallel_config, param_init="normal")
  2. 根据输入的索引查找参数表中的行作为返回值。当设置并行模式为 `AUTO_PARALLEL_MODE` 时,如果parallel_config.vocab_emb_dp为True时,那么embedding lookup表采用数据并行的方式,数据并行度为 `parallel_config.data_parallel` ,否则按 `parallel_config.model_parallel` 对embedding表中的第0维度进行切分。
  3. .. note::
  4. 启用 `AUTO_PARALLEL` / `SEMI_AUTO_PARALLEL` 模式时,此层仅支持2维度的输入,因为策略是为2D输入而配置的。
  5. **参数:**
  6. - **vocab_size** (int) - 表示查找表的大小。
  7. - **embedding_size** (int)- 表示查找表中每个嵌入向量的大小。
  8. - **param_init** (Union[Tensor, str, Initializer, numbers.Number])- 表示embedding_table的Initializer。当指定字符串时,请参见 `initializer` 类了解字符串的值。默认值:'normal'。
  9. - **parallel_config** (EmbeddingOpParallelConfig) - 表示网络的并行配置。默认值为 `default_embedding_parallel_config` ,表示带有默认参数的 `EmbeddingOpParallelConfig` 实例。
  10. **输入:**
  11. **input_ids** (Tensor) - shape为(batch_size, seq_length)的输入,其数据类型为int32。
  12. **输出:**
  13. Tuple,表示一个包含(`output`, `embedding_table`)的元组。
  14. - **output** (Tensor) - shape为(batch_size, seq_length, embedding_size)嵌入向量查找结果。
  15. - **weight** (Tensor) - shape为(vocab_size, embedding_size)的嵌入表。
  16. **异常:**
  17. - **ValueError** - parallel_config.vocab_emb_dp为True时,词典的大小不是parallel_config.model_parallel的倍数。
  18. - **ValueError** - `vocab_size` 不是正值。
  19. - **ValueError** - `embedding_size` 不是正值。
  20. - **TypeError** - `parallel_config` 不是OpParallelConfig的子类。