|
|
|
@@ -4649,6 +4649,17 @@ class GatherD(PrimitiveWithInfer): |
|
|
|
""" |
|
|
|
Gathers values along an axis specified by dim. |
|
|
|
|
|
|
|
For a 3-D tensor, the output is: |
|
|
|
output[i][j][k] = x[index[i][j][k]][j][k] # if dim == 0 |
|
|
|
|
|
|
|
output[i][j][k] = x[i][index[i][j][k]][k] # if dim == 1 |
|
|
|
|
|
|
|
output[i][j][k] = x[i][j][index[i][j][k]] # if dim == 2 |
|
|
|
|
|
|
|
If `x` is an n-D tensor with shape :math:`(z_0, z_1, ..., z_i, ..., z_{n-1})` and `dim` = i, |
|
|
|
the `index` must be an n-D tensor with shape :math:`(z_0, z_1, ..., y, ..., z_{n-1})` |
|
|
|
where `y`>=1 and the output will have the same shape as `index`. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - The source tensor. |
|
|
|
- **dim** (int) - The axis along which to index. It must be int32. Only constant value is allowed. |
|
|
|
|