|
|
@@ -576,19 +576,21 @@ class EmbeddingLookup(PrimitiveWithInfer): |
|
|
""" |
|
|
""" |
|
|
Returns a slice of input tensor based on the specified indices and axis. This Primitive has the similar |
|
|
Returns a slice of input tensor based on the specified indices and axis. This Primitive has the similar |
|
|
functionality as GatherV2, but has three more inputs: `offset`, `reduce_scatter_flag` and `split_num`. |
|
|
functionality as GatherV2, but has three more inputs: `offset`, `reduce_scatter_flag` and `split_num`. |
|
|
|
|
|
This primitive runs on the host instead of devices. |
|
|
|
|
|
|
|
|
Inputs: |
|
|
Inputs: |
|
|
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. |
|
|
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. |
|
|
The Tensor slice, instead of the entire Tensor. |
|
|
The Tensor slice, instead of the entire Tensor. |
|
|
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. |
|
|
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. |
|
|
Specifies the indices of elements of the original Tensor. Must be in the range |
|
|
|
|
|
`[0, input_param.shape()[axis])`. |
|
|
|
|
|
|
|
|
Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`, |
|
|
|
|
|
and the exceeding part will be filled with 0 in the output. |
|
|
- **axis** (int) - Specifies the dimension index to gather indices. |
|
|
- **axis** (int) - Specifies the dimension index to gather indices. |
|
|
- **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices |
|
|
- **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices |
|
|
are equal to `input_indices` minus `offset`. |
|
|
are equal to `input_indices` minus `offset`. |
|
|
- **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not. |
|
|
- **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not. |
|
|
|
|
|
Only constant value is allowed. |
|
|
- **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable |
|
|
- **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable |
|
|
is used only if `reduce_scatter_flag` is True. |
|
|
|
|
|
|
|
|
is used only if `reduce_scatter_flag` is True. Only constant value is allowed. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Outputs: |
|
|
Outputs: |
|
|
@@ -627,12 +629,20 @@ class EmbeddingLookup(PrimitiveWithInfer): |
|
|
if axis_v < 0: |
|
|
if axis_v < 0: |
|
|
axis_v += rank |
|
|
axis_v += rank |
|
|
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] |
|
|
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] |
|
|
if reduce_scatter_flag: |
|
|
|
|
|
# partition the tensor along the dimension 0. |
|
|
|
|
|
if out_shape[0] % split_num['value'] != 0: |
|
|
|
|
|
raise ValueError("The dimension 0 of the shape: %d, is not divisible by split_num: %d." % |
|
|
|
|
|
(out_shape[0], split_num['value'])) |
|
|
|
|
|
out_shape[0] = out_shape[0] // split_num['value'] |
|
|
|
|
|
|
|
|
if reduce_scatter_flag is None: |
|
|
|
|
|
raise ValueError("The value of 'reduce_scatter_flag' is None.") |
|
|
|
|
|
reduce_scatter_flag_value = reduce_scatter_flag['value'] |
|
|
|
|
|
if split_num is None: |
|
|
|
|
|
raise ValueError("The value of 'split_num_value' is None.") |
|
|
|
|
|
split_num_value = split_num['value'] |
|
|
|
|
|
if reduce_scatter_flag_value is True: |
|
|
|
|
|
# Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by |
|
|
|
|
|
# (split_num * 8) |
|
|
|
|
|
if out_shape[0] % (split_num_value * 8) != 0: |
|
|
|
|
|
raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d." % |
|
|
|
|
|
(out_shape[0], (split_num_value * 8))) |
|
|
|
|
|
# After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8 |
|
|
|
|
|
out_shape[0] = out_shape[0] // 8 |
|
|
out = {'shape': out_shape, |
|
|
out = {'shape': out_shape, |
|
|
'dtype': params['dtype'], |
|
|
'dtype': params['dtype'], |
|
|
'value': None} |
|
|
'value': None} |
|
|
|