| @@ -191,13 +191,12 @@ def get_bprop_tile(self): | |||
| return bprop | |||
| @bprop_getters.register(inner.EmbeddingLookup) | |||
| @bprop_getters.register(P.EmbeddingLookup) | |||
| def get_bprop_embedding_lookup(self): | |||
| """Generate bprop for EmbeddingLookup""" | |||
| sub_op = P.Sub() | |||
| reshape_op = P.Reshape() | |||
| host_reshape = P.Reshape().add_prim_attr('primitive_target', 'CPU') | |||
| def bprop_sparse(x, indices, offset, reduce_scatter_flag, split_num, out, dout): | |||
| def bprop_sparse(x, indices, offset, out, dout): | |||
| x_shp = shape_op(x) | |||
| new_indices = sub_op(indices, offset) | |||
| # Reshape the 'new_indices' | |||
| @@ -205,17 +204,9 @@ def get_bprop_embedding_lookup(self): | |||
| new_indices = reshape_op(new_indices, new_indices_shape_changed) | |||
| x_shp_tail = x_shp[1:] | |||
| actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail | |||
| if reduce_scatter_flag is True: | |||
| # On host | |||
| elu_grad = G.EmbeddingLookupCommGrad() | |||
| actual_dout = elu_grad(dout, split_num) | |||
| # Reshape the 'actual_dout' on host | |||
| actual_dout = host_reshape(actual_dout, actual_dout_shape_changed) | |||
| else: | |||
| # Reshape the 'actual_dout' on device | |||
| actual_dout = reshape_op(dout, actual_dout_shape_changed) | |||
| return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset), \ | |||
| zeros_like(reduce_scatter_flag), zeros_like(split_num) | |||
| # Reshape the 'actual_dout' on device | |||
| actual_dout = reshape_op(dout, actual_dout_shape_changed) | |||
| return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset) | |||
| return bprop_sparse | |||
| @@ -32,7 +32,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||
| Squeeze, StridedSlice, Tile, TensorScatterUpdate, | |||
| Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd, | |||
| UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, | |||
| SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence) | |||
| SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup) | |||
| from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast, | |||
| _MirrorOperator, ReduceOp, _VirtualDataset, | |||
| _VirtualDiv, _GetTensorSlice, | |||
| @@ -333,6 +333,7 @@ __all__ = [ | |||
| "Mod", | |||
| "PopulationCount", | |||
| "ParallelConcat", | |||
| "EmbeddingLookup" | |||
| ] | |||
| __all__.sort() | |||
| @@ -263,76 +263,6 @@ class AscendDequant(PrimitiveWithInfer): | |||
| return mstype.float16 | |||
| class EmbeddingLookup(PrimitiveWithInfer): | |||
| """ | |||
| Returns a slice of input tensor based on the specified indices. | |||
| This Primitive has the similar functionality as GatherV2 operating on `axis = 0`, but has three more inputs: | |||
| `offset`, `reduce_scatter_flag` and `split_num`. This primitive runs on the host instead of devices. | |||
| Inputs: | |||
| - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. | |||
| The Tensor slice, instead of the entire Tensor. | |||
| - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | |||
| Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`, | |||
| and the exceeding part will be filled with 0 in the output. | |||
| - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices | |||
| are equal to `input_indices` minus `offset`. | |||
| - **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not. | |||
| Only constant value is allowed. | |||
| - **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable | |||
| is used only if `reduce_scatter_flag` is True. Only constant value is allowed. | |||
| Outputs: | |||
| Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. | |||
| Examples: | |||
| >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32) | |||
| >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32) | |||
| >>> offset = 4 | |||
| >>> reduce_scatter_flag = False | |||
| >>> split_num = 1 | |||
| >>> out = P.EmbeddingLookup()(input_params, input_indices, offset, reduce_scatter_flag, split_num) | |||
| [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init index_select""" | |||
| self.__setattr_flag__ = True | |||
| self.init_prim_io_names(inputs=['params', 'indices', 'offset', 'reduce_scatter_flag', 'split_num'], | |||
| outputs=['output']) | |||
| self.add_prim_attr('primitive_target', 'CPU') | |||
| def __infer__(self, params, indices, offset, reduce_scatter_flag=False, split_num=2): | |||
| validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | |||
| validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) | |||
| validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) | |||
| validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name) | |||
| if split_num['value'] < 1: | |||
| raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num) | |||
| params_shp = params['shape'] | |||
| out_shape = indices['shape'] + params_shp[1:] | |||
| if reduce_scatter_flag is None: | |||
| raise ValueError("The value of 'reduce_scatter_flag' is None.") | |||
| reduce_scatter_flag_value = reduce_scatter_flag['value'] | |||
| if split_num is None: | |||
| raise ValueError("The value of 'split_num_value' is None.") | |||
| split_num_value = split_num['value'] | |||
| if reduce_scatter_flag_value is True: | |||
| # Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by | |||
| # (split_num * 8) | |||
| if out_shape[0] % (split_num_value * 8) != 0: | |||
| raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d." % | |||
| (out_shape[0], (split_num_value * 8))) | |||
| # After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8 | |||
| out_shape[0] = out_shape[0] // 8 | |||
| out = {'shape': out_shape, | |||
| 'dtype': params['dtype'], | |||
| 'value': None} | |||
| return out | |||
| class SparseApplyFtrlNoReturn(PrimitiveWithInfer): | |||
| """ | |||
| Update relevant entries according to the FTRL-proximal scheme. | |||
| @@ -3236,3 +3236,50 @@ class TransShape(PrimitiveWithInfer): | |||
| return {'shape': shp, | |||
| 'dtype': dtype, | |||
| 'value': None} | |||
| class EmbeddingLookup(PrimitiveWithInfer): | |||
| """ | |||
| Returns a slice of input tensor based on the specified indices. | |||
| This Primitive has the similar functionality as GatherV2 operating on `axis = 0`, but has one more inputs: | |||
| `offset`. | |||
| Inputs: | |||
| - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. | |||
| The Tensor slice, instead of the entire Tensor. | |||
| - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | |||
| Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`, | |||
| and the exceeding part will be filled with 0 in the output. | |||
| - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices | |||
| are equal to `input_indices` minus `offset`. | |||
| Outputs: | |||
| Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. | |||
| Examples: | |||
| >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32) | |||
| >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32) | |||
| >>> offset = 4 | |||
| >>> out = P.EmbeddingLookup()(input_params, input_indices, offset) | |||
| [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init index_select""" | |||
| self.__setattr_flag__ = True | |||
| self.init_prim_io_names(inputs=['params', 'indices', 'offset'], | |||
| outputs=['output']) | |||
| def __infer__(self, params, indices, offset): | |||
| validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | |||
| validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) | |||
| validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) | |||
| params_shp = params['shape'] | |||
| if len(params_shp) != 2: | |||
| raise ValueError("The dimension of 'params' in EmbeddingLookup must be 2, but got %d." % len(params_shp)) | |||
| out_shape = indices['shape'] + params_shp[1:] | |||
| out = {'shape': out_shape, | |||
| 'dtype': params['dtype'], | |||
| 'value': None} | |||
| return out | |||
| @@ -19,7 +19,6 @@ import mindspore.nn as nn | |||
| from mindspore.common.api import _executor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| from mindspore import Tensor, context | |||
| from tests.ut.python.ops.test_math_ops import VirtualLoss | |||
| @@ -42,17 +41,15 @@ class NetWithLoss(nn.Cell): | |||
| return self.loss(predict) | |||
| class Net(nn.Cell): | |||
| def __init__(self, shape, offset, reduce_scatter_flag, split_num): | |||
| def __init__(self, shape, offset): | |||
| super().__init__() | |||
| self.index = Tensor(np.ones(shape), dtype=ms.int32) | |||
| self.offset = offset | |||
| self.reduce_scatter_flag = reduce_scatter_flag | |||
| self.split_num = split_num | |||
| self.elu = inner.EmbeddingLookup() | |||
| self.elu = P.EmbeddingLookup() | |||
| self.mm = P.BatchMatMul() | |||
| def construct(self, x, y): | |||
| out = self.elu(x, self.index, self.offset, self.reduce_scatter_flag, self.split_num) | |||
| out = self.elu(x, self.index, self.offset) | |||
| out = self.mm(out, y) | |||
| return out | |||
| @@ -60,9 +57,7 @@ class Net(nn.Cell): | |||
| def test_embeddinglookup_reducescatter_false(): | |||
| shape = [8, 8] | |||
| offset = 8 | |||
| reduce_scatter_flag = False | |||
| split_num = 1 | |||
| net = NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num)) | |||
| net = NetWithLoss(Net(shape, offset)) | |||
| net.set_auto_parallel() | |||
| x = Tensor(np.ones([64, 32]), dtype=ms.float32) | |||
| @@ -71,11 +66,9 @@ def test_embeddinglookup_reducescatter_false(): | |||
| def test_embeddinglookup_reducescatter_true(): | |||
| shape = [64, 8] | |||
| shape = [8, 8] | |||
| offset = 8 | |||
| reduce_scatter_flag = True | |||
| split_num = 8 | |||
| net = NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num)) | |||
| net = NetWithLoss(Net(shape, offset)) | |||
| net.set_auto_parallel() | |||
| x = Tensor(np.ones([64, 32]), dtype=ms.float32) | |||
| @@ -86,9 +79,7 @@ def test_embeddinglookup_reducescatter_true(): | |||
| def test_embeddinglookup_reducescatter_false_grad(): | |||
| shape = [8, 8] | |||
| offset = 8 | |||
| reduce_scatter_flag = False | |||
| split_num = 1 | |||
| net = GradWrap(NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num))) | |||
| net = GradWrap(NetWithLoss(Net(shape, offset))) | |||
| net.set_auto_parallel() | |||
| x = Tensor(np.ones([64, 32]), dtype=ms.float32) | |||
| @@ -98,11 +89,9 @@ def test_embeddinglookup_reducescatter_false_grad(): | |||
| def test_embeddinglookup_reducescatter_true_grad(): | |||
| context.set_context(save_graphs=True) | |||
| shape = [64, 8] | |||
| shape = [8, 8] | |||
| offset = 8 | |||
| reduce_scatter_flag = True | |||
| split_num = 8 | |||
| net = GradWrap(NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num))) | |||
| net = GradWrap(NetWithLoss(Net(shape, offset))) | |||
| net.set_auto_parallel() | |||
| x = Tensor(np.ones([64, 32]), dtype=ms.float32) | |||
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| @@ -184,6 +185,7 @@ def test_gatherv2_auto1(): | |||
| _executor.compile(net, x, y) | |||
| @pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen") | |||
| def test_gatherv2_cpu0(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| strategy1 = ((8, 1), (1, 1)) | |||
| @@ -196,6 +198,7 @@ def test_gatherv2_cpu0(): | |||
| _executor.compile(net, x, y) | |||
| @pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen") | |||
| def test_gatherv2_cpu1(): | |||
| context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| strategy1 = ((16, 1), (1, 1)) | |||
| @@ -208,6 +211,7 @@ def test_gatherv2_cpu1(): | |||
| _executor.compile(net, x, y) | |||
| @pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen") | |||
| def test_gatherv2_cpu2(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| strategy1 = ((1, 8), (1, 1)) | |||
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| @@ -184,6 +185,7 @@ def test_gatherv2_auto1(): | |||
| _executor.compile(net, x, y) | |||
| @pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen") | |||
| def test_gatherv2_cpu0(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| strategy1 = ((8, 1), (1, 1)) | |||
| @@ -196,6 +198,7 @@ def test_gatherv2_cpu0(): | |||
| _executor.compile(net, x, y) | |||
| @pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen") | |||
| def test_gatherv2_cpu1(): | |||
| context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| strategy1 = ((16, 1), (1, 1)) | |||
| @@ -208,6 +211,7 @@ def test_gatherv2_cpu1(): | |||
| _executor.compile(net, x, y) | |||
| @pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen") | |||
| def test_gatherv2_cpu2(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| strategy1 = ((1, 8), (1, 1)) | |||