You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.ops.GatherD.rst 1.3 kB

123456789101112131415161718192021222324252627282930313233
  1. mindspore.ops.GatherD
  2. =======================
  3. .. py:class:: mindspore.ops.GatherD
  4. 获取指定轴的元素。
  5. 对于三维Tensor,输出为:
  6. .. code-block::
  7. output[i][j][k] = x[index[i][j][k]][j][k] # if dim == 0
  8. output[i][j][k] = x[i][index[i][j][k]][k] # if dim == 1
  9. output[i][j][k] = x[i][j][index[i][j][k]] # if dim == 2
  10. 如果 `x` 是shape为 :math:`(z_0, z_1, ..., z_i, ..., z_{n-1})` ,维度 `dim` 为i的n维Tensor,则 `index` 必须是shape为 :math:`(z_0, z_1, ..., y, ..., z_{n-1})` 的n维Tensor,其中 `y` 大于等于1,输出的shape与 `index` 相同。
  11. **输入:**
  12. - **x** (Tensor) - GatherD的输入,任意维度的Tensor。
  13. - **dim** (int) - 获取元素的轴。数据类型为int32或int64。只能是常量值。
  14. - **index** (Tensor) - 获取收集元素的索引。支持的数据类型包括:int32,int64。每个索引元素的取值范围为[-x_rank[dim], x_rank[dim])。
  15. **输出:**
  16. Tensor,shape为 :math:`(z_1, z_2, ..., z_N)` 的Tensor,数据类型与 `x` 相同。
  17. **异常:**
  18. - **TypeError** - `dim` 或 `index` 的数据类型既不是int32,也不是int64。
  19. - **ValueError** - `x` 的shape长度不等于 `index` 的shape长度。