You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.ops.TensorScatterUpdate.rst 1.8 kB

1234567891011121314151617181920212223242526
  1. mindspore.ops.TensorScatterUpdate
  2. ==================================
  3. .. py:class:: mindspore.ops.TensorScatterUpdate
  4. 根据指定的更新值和输入索引,通过更新操作更新输入Tensor的值。此操作几乎等同于使用 :class:`mindspore.ops.ScatterNd` ,只是更新操作应用到 `input_x` Tensor而不是0。
  5. `indices` 的rank至少要为2,最后一个轴表示每个索引向量的深度。对于每个索引向量, `update` 中必须有相应的值。如果每个索引Tensor的深度与 `input_x` 的rank匹配,则每个索引向量对应于 `input_x` 中的Scalar,并且每次更新都会更新一个Scalar。如果每个索引Tensor的深度小于 `input_x` 的rank,则每个索引向量对应于 `input_x` 中的切片,并且每次更新都会更新一个切片。
  6. 更新的顺序是不确定的,这意味着如果 `indices` 中有多个索引向量对应于同一位置,则输出中该位置值是不确定的。
  7. **输入:**
  8. - **input_x** (Tensor) - TensorScatterUpdate的输入,任意维度的Tensor。其数据类型为数值型。 `input_x` 的维度必须不小于indices.shape[-1]。
  9. - **indices** (Tensor) - 输入Tensor的索引,数据类型为int32或int64。其rank必须至少为2。
  10. - **update** (Tensor) - 指定与 `input_x` 做更新操作的Tensor,其数据类型与输入相同。update.shape应等于indices.shape[:-1] + input_x.shape[indices.shape[-1]:]。
  11. **输出:**
  12. Tensor,shape和数据类型与输入 `input_x` 相同。
  13. **异常:**
  14. - **TypeError** - `indices` 的数据类型既不是int32,也不是int64。
  15. - **ValueError** - `input_x` 的shape长度小于 `indices` 的shape的最后一个维度。
  16. - **ValueError** - `input_x` 的值与输入 `indices` 不匹配。