Merge pull request !2163 from Xiaoda/3-changing-embeddinglookup-internaltags/v0.5.0-beta
| @@ -191,7 +191,7 @@ def get_bprop_tile(self): | |||||
| return bprop | return bprop | ||||
| @bprop_getters.register(P.EmbeddingLookup) | |||||
| @bprop_getters.register(inner.EmbeddingLookup) | |||||
| def get_bprop_embedding_lookup(self): | def get_bprop_embedding_lookup(self): | ||||
| """Generate bprop for EmbeddingLookup""" | """Generate bprop for EmbeddingLookup""" | ||||
| host_sub = P.Sub().add_prim_attr('primitive_target', 'CPU') | host_sub = P.Sub().add_prim_attr('primitive_target', 'CPU') | ||||
| @@ -26,7 +26,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||||
| Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, | Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, | ||||
| SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate, | SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate, | ||||
| ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, | ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, | ||||
| Shape, Size, Slice, Split, EmbeddingLookup, | |||||
| Shape, Size, Slice, Split, | |||||
| Squeeze, StridedSlice, Tile, TensorScatterUpdate, | Squeeze, StridedSlice, Tile, TensorScatterUpdate, | ||||
| Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, | Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, | ||||
| UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, | UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, | ||||
| @@ -138,7 +138,6 @@ __all__ = [ | |||||
| 'ReduceSum', | 'ReduceSum', | ||||
| 'ReduceMean', | 'ReduceMean', | ||||
| 'LayerNorm', | 'LayerNorm', | ||||
| 'EmbeddingLookup', | |||||
| 'Rank', | 'Rank', | ||||
| 'Less', | 'Less', | ||||
| 'LessEqual', | 'LessEqual', | ||||
| @@ -258,3 +258,73 @@ class AscendDequant(PrimitiveWithInfer): | |||||
| validator.check_type_name("x", x_type, [mstype.int32], self.name) | validator.check_type_name("x", x_type, [mstype.int32], self.name) | ||||
| validator.check_type_name("deq_scale", deq_scale_type, [mstype.float16, mstype.uint64], self.name) | validator.check_type_name("deq_scale", deq_scale_type, [mstype.float16, mstype.uint64], self.name) | ||||
| return mstype.float16 | return mstype.float16 | ||||
| class EmbeddingLookup(PrimitiveWithInfer): | |||||
| """ | |||||
| Returns a slice of input tensor based on the specified indices. | |||||
| This Primitive has the similar functionality as GatherV2 operating on `axis = 0`, but has three more inputs: | |||||
| `offset`, `reduce_scatter_flag` and `split_num`. This primitive runs on the host instead of devices. | |||||
| Inputs: | |||||
| - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. | |||||
| The Tensor slice, instead of the entire Tensor. | |||||
| - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | |||||
| Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`, | |||||
| and the exceeding part will be filled with 0 in the output. | |||||
| - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices | |||||
| are equal to `input_indices` minus `offset`. | |||||
| - **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not. | |||||
| Only constant value is allowed. | |||||
| - **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable | |||||
| is used only if `reduce_scatter_flag` is True. Only constant value is allowed. | |||||
| Outputs: | |||||
| Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. | |||||
| Examples: | |||||
| >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32) | |||||
| >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32) | |||||
| >>> offset = 4 | |||||
| >>> reduce_scatter_flag = False | |||||
| >>> split_num = 1 | |||||
| >>> out = P.EmbeddingLookup()(input_params, input_indices, offset, reduce_scatter_flag, split_num) | |||||
| [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """init index_select""" | |||||
| self.__setattr_flag__ = True | |||||
| self.init_prim_io_names(inputs=['params', 'indices', 'offset', 'reduce_scatter_flag', 'split_num'], | |||||
| outputs=['output']) | |||||
| self.add_prim_attr('primitive_target', 'CPU') | |||||
| def __infer__(self, params, indices, offset, reduce_scatter_flag=False, split_num=2): | |||||
| validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | |||||
| validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) | |||||
| validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) | |||||
| validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name) | |||||
| if split_num['value'] < 1: | |||||
| raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num) | |||||
| params_shp = params['shape'] | |||||
| out_shape = indices['shape'] + params_shp[1:] | |||||
| if reduce_scatter_flag is None: | |||||
| raise ValueError("The value of 'reduce_scatter_flag' is None.") | |||||
| reduce_scatter_flag_value = reduce_scatter_flag['value'] | |||||
| if split_num is None: | |||||
| raise ValueError("The value of 'split_num_value' is None.") | |||||
| split_num_value = split_num['value'] | |||||
| if reduce_scatter_flag_value is True: | |||||
| # Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by | |||||
| # (split_num * 8) | |||||
| if out_shape[0] % (split_num_value * 8) != 0: | |||||
| raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d." % | |||||
| (out_shape[0], (split_num_value * 8))) | |||||
| # After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8 | |||||
| out_shape[0] = out_shape[0] // 8 | |||||
| out = {'shape': out_shape, | |||||
| 'dtype': params['dtype'], | |||||
| 'value': None} | |||||
| return out | |||||
| @@ -558,76 +558,6 @@ class SparseGatherV2(GatherV2): | |||||
| """ | """ | ||||
| class EmbeddingLookup(PrimitiveWithInfer): | |||||
| """ | |||||
| Returns a slice of input tensor based on the specified indices. | |||||
| This Primitive has the similar functionality as GatherV2 operating on `axis = 0`, but has three more inputs: | |||||
| `offset`, `reduce_scatter_flag` and `split_num`. This primitive runs on the host instead of devices. | |||||
| Inputs: | |||||
| - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. | |||||
| The Tensor slice, instead of the entire Tensor. | |||||
| - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | |||||
| Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`, | |||||
| and the exceeding part will be filled with 0 in the output. | |||||
| - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices | |||||
| are equal to `input_indices` minus `offset`. | |||||
| - **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not. | |||||
| Only constant value is allowed. | |||||
| - **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable | |||||
| is used only if `reduce_scatter_flag` is True. Only constant value is allowed. | |||||
| Outputs: | |||||
| Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. | |||||
| Examples: | |||||
| >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32) | |||||
| >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32) | |||||
| >>> offset = 4 | |||||
| >>> reduce_scatter_flag = False | |||||
| >>> split_num = 1 | |||||
| >>> out = P.EmbeddingLookup()(input_params, input_indices, offset, reduce_scatter_flag, split_num) | |||||
| [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """init index_select""" | |||||
| self.__setattr_flag__ = True | |||||
| self.init_prim_io_names(inputs=['params', 'indices', 'offset', 'reduce_scatter_flag', 'split_num'], | |||||
| outputs=['output']) | |||||
| self.add_prim_attr('primitive_target', 'CPU') | |||||
| def __infer__(self, params, indices, offset, reduce_scatter_flag=False, split_num=2): | |||||
| validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | |||||
| validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) | |||||
| validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) | |||||
| validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name) | |||||
| if split_num['value'] < 1: | |||||
| raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num) | |||||
| params_shp = params['shape'] | |||||
| out_shape = indices['shape'] + params_shp[1:] | |||||
| if reduce_scatter_flag is None: | |||||
| raise ValueError("The value of 'reduce_scatter_flag' is None.") | |||||
| reduce_scatter_flag_value = reduce_scatter_flag['value'] | |||||
| if split_num is None: | |||||
| raise ValueError("The value of 'split_num_value' is None.") | |||||
| split_num_value = split_num['value'] | |||||
| if reduce_scatter_flag_value is True: | |||||
| # Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by | |||||
| # (split_num * 8) | |||||
| if out_shape[0] % (split_num_value * 8) != 0: | |||||
| raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d." % | |||||
| (out_shape[0], (split_num_value * 8))) | |||||
| # After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8 | |||||
| out_shape[0] = out_shape[0] // 8 | |||||
| out = {'shape': out_shape, | |||||
| 'dtype': params['dtype'], | |||||
| 'value': None} | |||||
| return out | |||||
| class Split(PrimitiveWithInfer): | class Split(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Splits input tensor into output_num of tensors along the given axis and output numbers. | Splits input tensor into output_num of tensors along the given axis and output numbers. | ||||
| @@ -19,6 +19,7 @@ import mindspore.nn as nn | |||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.common.api import _executor | from mindspore.common.api import _executor | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops.operations import _inner_ops as inner | |||||
| from tests.ut.python.ops.test_math_ops import VirtualLoss | from tests.ut.python.ops.test_math_ops import VirtualLoss | ||||
| @@ -39,7 +40,7 @@ class Net(nn.Cell): | |||||
| self.offset = offset | self.offset = offset | ||||
| self.reduce_scatter_flag = reduce_scatter_flag | self.reduce_scatter_flag = reduce_scatter_flag | ||||
| self.split_num = split_num | self.split_num = split_num | ||||
| self.elu = P.EmbeddingLookup() | |||||
| self.elu = inner.EmbeddingLookup() | |||||
| self.mm = P.BatchMatMul() | self.mm = P.BatchMatMul() | ||||
| def construct(self, x, y): | def construct(self, x, y): | ||||