You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.OneHot.rst 2.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. mindspore.nn.OneHot
  2. ====================
  3. .. py:class:: mindspore.nn.OneHot(axis=-1, depth=1, on_value=1.0, off_value=0.0, dtype=mstype.float32)
  4. 对输入进行one-hot编码并返回。
  5. 输入的 `indices` 表示的位置取值为on_value,其他所有位置取值为off_value。
  6. .. note::
  7. 如果indices是n阶Tensor,那么返回的one-hot Tensor则为n+1阶Tensor。
  8. 如果 `indices` 是Scalar,则输出shape将是长度为 `depth` 的向量。
  9. 如果 `indices` 是长度为 `features` 的向量,则输出shape为:
  10. .. code-block::
  11. features * depth if axis == -1
  12. depth * features if axis == 0
  13. 如果 `indices` 是shape为 `[batch, features]` 的矩阵,则输出shape为:
  14. .. code-block::
  15. batch * features * depth if axis == -1
  16. batch * depth * features if axis == 1
  17. depth * batch * features if axis == 0
  18. **参数:**
  19. - **axis** (int) - 指定第几阶为 `depth` 维one-hot向量,如果轴为-1,则 `features * depth` ,如果轴为0,则 `depth * features` 。默认值:-1。
  20. - **depth** (int) - 定义one-hot向量的深度。默认值:1。
  21. - **on_value** (float) - one-hot值,当 `indices[j] = i` 时,填充output[i][j]的取值。默认值:1.0。
  22. - **off_value** (float) - 非one-hot值,当 `indices[j] != i` 时,填充output[i][j]的取值。默认值:0.0。
  23. - **dtype** (:class:`mindspore.dtype`) - 是'on_value'和'off_value'的数据类型,而不是输入的数据类型。默认值:mindspore.float32。
  24. **输入:**
  25. **indices** (Tensor) - 输入索引,任意维度的Tensor,数据类型为int32或int64。
  26. **输出:**
  27. Tensor,输出Tensor,数据类型 `dtype` 的one-hot Tensor,维度为 `axis` 扩展到 `depth`,并填充on_value和off_value。`Outputs` 的维度等于 `indices` 的维度加1。
  28. **异常:**
  29. - **TypeError** - `axis` 或 `depth` 不是int。
  30. - **TypeError** - `indices` 的dtype既不是int32,也不是int64。
  31. - **ValueError** - 如果 `axis` 不在范围[-1, len(indices_shape)]内。
  32. - **ValueError** - `depth` 小于0。