| @@ -18,10 +18,14 @@ from mindspore.common.tensor import Tensor | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| from mindspore._checkparam import Validator | |||||
| from mindspore.communication.management import get_group_size | |||||
| from mindspore.train.parallel_utils import ParallelMode | |||||
| from mindspore.parallel._utils import _get_parallel_mode | |||||
| from ..cell import Cell | from ..cell import Cell | ||||
| from ..._checkparam import Validator as validator | |||||
| from ..._checkparam import Validator as validator, Rel | |||||
| __all__ = ['Embedding', 'EmbeddingLookup'] | |||||
| __all__ = ['Embedding', 'EmbeddingLookup', 'EmbeddingLookUpSplitMode'] | |||||
| class Embedding(Cell): | class Embedding(Cell): | ||||
| r""" | r""" | ||||
| @@ -114,29 +118,36 @@ class EmbeddingLookup(Cell): | |||||
| When 'target' is set to 'CPU', this module will use | When 'target' is set to 'CPU', this module will use | ||||
| P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which | P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which | ||||
| specified 'offset = 0' to lookup table. | specified 'offset = 0' to lookup table. | ||||
| when 'target' is set to 'DEVICE', this module will use P.GatherV2() which | |||||
| When 'target' is set to 'DEVICE', this module will use P.GatherV2() which | |||||
| specified 'axis = 0' to lookup table. | specified 'axis = 0' to lookup table. | ||||
| In field slice mode, the manual_shapes should be given. It is a tuple, where | |||||
| each element is (vocab[i], offset[i]): vocab[i] is the number of rows of the i-th | |||||
| part and offset[i] is the feature id offset of the i-th part. The feature ids in | |||||
| the i-th part have offset[i] subtracted from them so that the ids start from 0. | |||||
| Args: | Args: | ||||
| vocab_size (int): Size of the dictionary of embeddings. | |||||
| embedding_size (int): The size of each embedding vector. | |||||
| param_init (str): The initialization method of the embedding table. Default: 'normal'. | |||||
| target (str): Specify the target where the op is executed. Default: 'CPU'. | target (str): Specify the target where the op is executed. Default: 'CPU'. | ||||
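| slice_mode (str): The slicing mode used in semi auto parallel/auto parallel mode. The value should be | |||||
| one of the modes defined in nn.EmbeddingLookUpSplitMode. Default: 'batch_slice'. | |||||
| manual_shapes (tuple): The accompaniment array in field slice mode. Default: None. | |||||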
| Inputs: | Inputs: | ||||
| - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. | |||||
| The Tensor slice, instead of the entire Tensor. | |||||
| - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | ||||
| Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`, | |||||
| and the exceeding part will be filled with 0 in the output. | |||||
| Specifies the indices of the rows to look up in the embedding table. Values can be out of range of | |||||
| embedding_table, and the exceeding part will be filled with 0 in the output. input_indices should only | |||||
| be a 2-D tensor in this interface. | |||||
| Outputs: | Outputs: | ||||
| Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. | Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. | ||||
| Examples: | Examples: | ||||
| >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32) | |||||
| >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32) | >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32) | ||||
| >>> out = nn.EmbeddingLookup()(input_params, input_indices) | |||||
| [[[10, 11], [8 ,9]], [[14, 15], [12, 13]]] | |||||
| >>> out = nn.EmbeddingLookup(4,2)(input_indices) | |||||
| """ | """ | ||||
| def __init__(self, target='CPU'): | |||||
| def __init__(self, vocab_size, embedding_size, param_init='normal', | |||||
| target='CPU', slice_mode='batch_slice', manual_shapes=None): | |||||
| super(EmbeddingLookup, self).__init__() | super(EmbeddingLookup, self).__init__() | ||||
| self.target = target | self.target = target | ||||
| if target not in ('CPU', 'DEVICE'): | if target not in ('CPU', 'DEVICE'): | ||||
| @@ -144,10 +155,60 @@ class EmbeddingLookup(Cell): | |||||
| + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.') | + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.') | ||||
| self.gatherv2 = P.GatherV2() | self.gatherv2 = P.GatherV2() | ||||
| self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') | self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') | ||||
| self.embedding_table = Parameter(initializer(param_init, [vocab_size, embedding_size]), | |||||
| name='embedding_table') | |||||
| parallel_mode = _get_parallel_mode() | |||||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||||
| if slice_mode == EmbeddingLookUpSplitMode.FIELD_SLICE and is_auto_parallel: | |||||
| if not manual_shapes: | |||||
| raise ValueError("in slice field mode, the manual_shapes should not be none") | |||||
| if not isinstance(manual_shapes, tuple): | |||||
| raise TypeError("manual_shapes type must be tuple(int) cannot be {}!".format(type(manual_shapes))) | |||||
| for dim in manual_shapes: | |||||
| Validator.check_integer('manul shape dim', dim, 0, Rel.GT, self.cls_name) | |||||
| self.gatherv2.add_prim_attr("manual_split", manual_shapes) | |||||
| self.embeddinglookup.add_prim_attr("manual_split", manual_shapes) | |||||
| self.gatherv2.set_strategy(((get_group_size(), 1), (1, get_group_size()))) | |||||
| self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, get_group_size()))) | |||||
| elif slice_mode == EmbeddingLookUpSplitMode.TABLE_ROW_SLICE and is_auto_parallel: | |||||
| self.gatherv2.set_strategy(((get_group_size(), 1), (1, 1))) | |||||
| self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1))) | |||||
| elif slice_mode == EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE and is_auto_parallel: | |||||
| self.gatherv2.set_strategy(((1, get_group_size()), (1, 1))) | |||||
| self.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1))) | |||||
| elif slice_mode == EmbeddingLookUpSplitMode.BATCH_SLICE and is_auto_parallel: | |||||
| self.gatherv2.set_strategy(((1, 1), (get_group_size(), 1))) | |||||
| self.embeddinglookup.set_strategy(((1, 1), (get_group_size(), 1))) | |||||
| else: | |||||
| if is_auto_parallel: | |||||
| raise ValueError("slice_mode should support mode in nn.EmbeddingLookUpSplitMode, but get " | |||||
| + str(slice_mode)) | |||||
| def construct(self, params, indices): | |||||
| def construct(self, indices): | |||||
| if self.target == "CPU": | if self.target == "CPU": | ||||
| out = self.embeddinglookup(params, indices, 0) | |||||
| out = self.embeddinglookup(self.embedding_table, indices, 0) | |||||
| else: | else: | ||||
| out = self.gatherv2(params, indices, 0) | |||||
| out = self.gatherv2(self.embedding_table, indices, 0) | |||||
| return out | return out | ||||
| class EmbeddingLookUpSplitMode: | |||||
| """ | |||||
| EmbeddingLookUp slice options in auto parallel and semi auto parallel mode. | |||||
| There are four kinds of slice options, "BATCH_SLICE", "FIELD_SLICE", | |||||
| "TABLE_ROW_SLICE" and "TABLE_COLUMN_SLICE". Default: "BATCH_SLICE". | |||||
| - BATCH_SLICE: Slicing batch dimensions of indices. | |||||
| - FIELD_SLICE: Slicing field dimensions of indices. | |||||
| - TABLE_ROW_SLICE: Slicing row of table. | |||||
| - TABLE_COLUMN_SLICE: Slicing column of table. | |||||
| MODE_LIST: The list of all supported slice modes. | |||||
| """ | |||||
| BATCH_SLICE = "batch_slice" | |||||
| FIELD_SLICE = "field_slice" | |||||
| TABLE_ROW_SLICE = "table_row_slice" | |||||
| TABLE_COLUMN_SLICE = "table_column_slice" | |||||
| MODE_LIST = [BATCH_SLICE, FIELD_SLICE, TABLE_ROW_SLICE, TABLE_COLUMN_SLICE] | |||||
| @@ -209,19 +209,22 @@ class WideDeepModel(nn.Cell): | |||||
| if is_auto_parallel and host_device_mix: | if is_auto_parallel and host_device_mix: | ||||
| self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),)) | self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),)) | ||||
| self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1))) | self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1))) | ||||
| self.deep_embeddinglookup = nn.EmbeddingLookup() | |||||
| self.deep_embeddinglookup.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1))) | |||||
| self.wide_embeddinglookup = nn.EmbeddingLookup() | |||||
| self.wide_embeddinglookup.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1))) | |||||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, | |||||
| slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE) | |||||
| self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, | |||||
| slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_ROW_SLICE) | |||||
| self.deep_mul.set_strategy(((1, 1, get_group_size()), (1, 1, 1))) | self.deep_mul.set_strategy(((1, 1, get_group_size()), (1, 1, 1))) | ||||
| self.deep_reshape.add_prim_attr("skip_redistribution", True) | self.deep_reshape.add_prim_attr("skip_redistribution", True) | ||||
| self.reduce_sum.add_prim_attr("cross_batch", True) | self.reduce_sum.add_prim_attr("cross_batch", True) | ||||
| self.embedding_table = self.deep_embeddinglookup.embedding_table | |||||
| elif parameter_server: | elif parameter_server: | ||||
| self.deep_embeddinglookup = nn.EmbeddingLookup() | |||||
| self.wide_embeddinglookup = nn.EmbeddingLookup() | |||||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim) | |||||
| self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1) | |||||
| self.embedding_table = self.deep_embeddinglookup.embedding_table | |||||
| else: | else: | ||||
| self.deep_embeddinglookup = nn.EmbeddingLookup(target='DEVICE') | |||||
| self.wide_embeddinglookup = nn.EmbeddingLookup(target='DEVICE') | |||||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target='DEVICE') | |||||
| self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target='DEVICE') | |||||
| self.embedding_table = self.deep_embeddinglookup.embedding_table | |||||
| def construct(self, id_hldr, wt_hldr): | def construct(self, id_hldr, wt_hldr): | ||||
| """ | """ | ||||
| @@ -231,11 +234,11 @@ class WideDeepModel(nn.Cell): | |||||
| """ | """ | ||||
| mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1)) | mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1)) | ||||
| # Wide layer | # Wide layer | ||||
| wide_id_weight = self.wide_embeddinglookup(self.wide_w, id_hldr) | |||||
| wide_id_weight = self.wide_embeddinglookup(id_hldr) | |||||
| wx = self.wide_mul(wide_id_weight, mask) | wx = self.wide_mul(wide_id_weight, mask) | ||||
| wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1)) | wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1)) | ||||
| # Deep layer | # Deep layer | ||||
| deep_id_embs = self.deep_embeddinglookup(self.embedding_table, id_hldr) | |||||
| deep_id_embs = self.deep_embeddinglookup(id_hldr) | |||||
| vx = self.deep_mul(deep_id_embs, mask) | vx = self.deep_mul(deep_id_embs, mask) | ||||
| deep_in = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim)) | deep_in = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim)) | ||||
| deep_in = self.dense_layer_1(deep_in) | deep_in = self.dense_layer_1(deep_in) | ||||
| @@ -24,8 +24,7 @@ from mindspore.common import dtype as mstype | |||||
| from mindspore.nn import TrainOneStepCell, WithLossCell | from mindspore.nn import TrainOneStepCell, WithLossCell | ||||
| from mindspore.nn.optim import Adam | from mindspore.nn.optim import Adam | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.common.initializer import TruncatedNormal, initializer | |||||
| from mindspore import Parameter | |||||
| from mindspore.common.initializer import TruncatedNormal | |||||
| parser = argparse.ArgumentParser(description="test_sparse_embedding") | parser = argparse.ArgumentParser(description="test_sparse_embedding") | ||||
| parser.add_argument("--device_target", type=str, default="Ascend") | parser.add_argument("--device_target", type=str, default="Ascend") | ||||
| @@ -53,16 +52,13 @@ class LeNet5(nn.Cell): | |||||
| super(LeNet5, self).__init__() | super(LeNet5, self).__init__() | ||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| self.flatten = nn.Flatten() | self.flatten = nn.Flatten() | ||||
| self.embedding_table = Parameter( | |||||
| initializer("normal", (16, 4), mstype.float32), name="embedding_table" | |||||
| ) | |||||
| self.embedding = nn.EmbeddingLookup() | |||||
| self.embedding = nn.EmbeddingLookup(16, 4) | |||||
| self.relu = nn.ReLU() | self.relu = nn.ReLU() | ||||
| self.fc = fc_with_initialize(12, num_class) | self.fc = fc_with_initialize(12, num_class) | ||||
| def construct(self, x): | def construct(self, x): | ||||
| x = self.cast(x, mstype.int32) | x = self.cast(x, mstype.int32) | ||||
| x = self.embedding(self.embedding_table, x) | |||||
| x = self.embedding(x) | |||||
| x = self.flatten(x) | x = self.flatten(x) | ||||
| x = self.fc(x) | x = self.fc(x) | ||||
| return x | return x | ||||
| @@ -72,7 +68,7 @@ def do_sparse_embedding(ps=False): | |||||
| epoch = 10 | epoch = 10 | ||||
| net = LeNet5(10) | net = LeNet5(10) | ||||
| if ps: | if ps: | ||||
| net.embedding_table.set_param_ps() | |||||
| net.embedding.embedding_table.set_param_ps() | |||||
| optimizer = Adam(filter(lambda x: x.requires_grad, net.get_parameters())) | optimizer = Adam(filter(lambda x: x.requires_grad, net.get_parameters())) | ||||
| optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU") | optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU") | ||||
| @@ -421,17 +421,16 @@ def test_row_tensor_with_control_flow_if(): | |||||
| class EmbeddingLookUpBnNet(nn.Cell): | class EmbeddingLookUpBnNet(nn.Cell): | ||||
| def __init__(self, param_np, target='CPU'): | |||||
| def __init__(self, vocab_size, embedding_size, target='CPU'): | |||||
| super().__init__() | super().__init__() | ||||
| self.param = Parameter(Tensor(param_np), name="w1") | |||||
| self.embedding_lookup = nn.EmbeddingLookup(target=target) | |||||
| self.embedding_lookup = nn.EmbeddingLookup(vocab_size, embedding_size, param_init='ones', target=target) | |||||
| self.bn = nn.BatchNorm2d(num_features=3) | self.bn = nn.BatchNorm2d(num_features=3) | ||||
| self.mul = P.Mul() | self.mul = P.Mul() | ||||
| self.reshape = P.Reshape() | self.reshape = P.Reshape() | ||||
| self.relu = nn.PReLU() | self.relu = nn.PReLU() | ||||
| def construct(self, indices): | def construct(self, indices): | ||||
| x = self.embedding_lookup(self.param, indices) | |||||
| x = self.embedding_lookup(indices) | |||||
| x = self.reshape(x, (2, 3, 2, 2)) | x = self.reshape(x, (2, 3, 2, 2)) | ||||
| x = self.relu(x) | x = self.relu(x) | ||||
| x = self.bn(x) | x = self.bn(x) | ||||
| @@ -439,10 +438,9 @@ class EmbeddingLookUpBnNet(nn.Cell): | |||||
| def test_embedding_lookup_with_mix_precision(): | def test_embedding_lookup_with_mix_precision(): | ||||
| param_np = np.ones([8, 8]).astype(np.float32) | |||||
| data = Tensor(np.array([0, 1, 2]).astype(np.int32)) | data = Tensor(np.array([0, 1, 2]).astype(np.int32)) | ||||
| label = Tensor(np.random.randn(*(2, 3, 2, 2)).astype(np.float32)) | label = Tensor(np.random.randn(*(2, 3, 2, 2)).astype(np.float32)) | ||||
| net = EmbeddingLookUpBnNet(param_np, target='CPU') | |||||
| net = EmbeddingLookUpBnNet(8, 8, target='CPU') | |||||
| criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean') | criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean') | ||||
| optimizer = nn.Adam(params=net.trainable_params(), learning_rate=0.1) | optimizer = nn.Adam(params=net.trainable_params(), learning_rate=0.1) | ||||
| @@ -69,14 +69,12 @@ def test_bprop_with_sparse_feature_mirror(): | |||||
| super(Net, self).__init__() | super(Net, self).__init__() | ||||
| if shape is None: | if shape is None: | ||||
| shape = [8, 8] | shape = [8, 8] | ||||
| weight = Tensor(np.ones([64, 64]), dtype=ms.float32) | |||||
| self.weight = Parameter(weight, "w") | |||||
| self.index = Tensor(np.ones(shape), dtype=ms.int32) | self.index = Tensor(np.ones(shape), dtype=ms.int32) | ||||
| self.embeddinglookup = nn.EmbeddingLookup() | |||||
| self.embeddinglookup = nn.EmbeddingLookup(64, 64, param_init='ones') | |||||
| self.embeddinglookup.embeddinglookup.set_strategy(((1, 1), (8, 1))) | self.embeddinglookup.embeddinglookup.set_strategy(((1, 1), (8, 1))) | ||||
| def construct(self, x, b): | def construct(self, x, b): | ||||
| out = self.embeddinglookup(self.weight, self.index) | |||||
| out = self.embeddinglookup(self.index) | |||||
| return out | return out | ||||