|
|
|
@@ -26,9 +26,10 @@ from mindspore._checkparam import Rel |
|
|
|
from mindspore._checkparam import Validator as validator |
|
|
|
from mindspore.ops.primitive import constexpr |
|
|
|
from .basic import ClipByNorm |
|
|
|
from .math import Range |
|
|
|
from ..cell import Cell |
|
|
|
|
|
|
|
# Public API of this module. (A stale duplicate assignment that omitted
# MultiFieldEmbeddingLookup was removed; the second assignment always won.)
__all__ = ['Embedding', 'EmbeddingLookup', 'MultiFieldEmbeddingLookup']
|
|
|
|
|
|
|
|
|
|
|
class Embedding(Cell): |
|
|
|
@@ -268,3 +269,190 @@ class EmbeddingLookup(Cell): |
|
|
|
clip_by_norm = ClipByNorm(axis) |
|
|
|
out = clip_by_norm(out, self.max_norm) |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
class MultiFieldEmbeddingLookup(EmbeddingLookup):
    r"""
    Returns a slice of the input tensor based on the specified indices and field ids. This operation
    supports looking up embeddings within multi hot and one hot fields simultaneously.

    Note:
        When 'target' is set to 'CPU', this module will use
        P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
        specified 'offset = 0' to lookup table.
        When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
        specified 'axis = 0' to lookup table.
        The vectors with the same field_ids will be combined by the `operator`, such as `SUM`, `MAX` and
        `MEAN`. Ensure the input_values of the padded id is zero, so that they can be ignored. The final
        output will be zeros if the summed absolute weight of the field is zero. This class only
        supports ['table_row_slice', 'batch_slice' and 'table_column_slice'].

    Args:
        vocab_size (int): Size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        field_size (int): The field size of the final outputs.
        param_init (str): The initialize way of embedding table. Default: 'normal'.
        target (str): Specifies the target where the op is executed. The value must in
            ['DEVICE', 'CPU']. Default: 'CPU'.
        slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
            nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
        feature_num_list (tuple): The accompaniment array in field slice mode.
        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
            or None. Default: None
        sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
        operator (str): The pooling method for the features in one field. Support 'SUM', 'MEAN' and 'MAX'.

    Inputs:
        - **input_indices** (Tensor) - The shape of tensor is :math:`(batch_size, seq_length)`.
          Specifies the indices of elements of the original Tensor. Values can be out of range of embedding_table,
          and the exceeding part will be filled with 0 in the output. Input_indices must be a 2d tensor in
          this interface. Type is Int16, Int32, Int64.
        - **input_values** (Tensor) - The shape of tensor is :math:`(batch_size, seq_length)`.
          Specifies the weights of elements of the input_indices. The looked up vector will be multiplied by
          the input_values. Type is Float32.
        - **field_ids** (Tensor) - The shape of tensor is :math:`(batch_size, seq_length)`.
          Specifies the field id of elements of the input_indices. Type is Int16, Int32, Int64.

    Outputs:
        Tensor, the shape of tensor is :math:`(batch_size, field_size, embedding_size)`. Type is Float32.

    Examples:
        >>> input_indices = Tensor([[2, 4, 6, 0, 0], [1, 3, 5, 0, 0]], mindspore.int32)
        >>> input_values = Tensor([[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]], mindspore.float32)
        >>> field_ids = Tensor([[0, 1, 1, 0, 0], [0, 0, 1, 0, 0]], mindspore.int32)
        >>> net = nn.MultiFieldEmbeddingLookup(10, 2, field_size=2, operator='SUM')
        >>> out = net(input_indices, input_values, field_ids)
        >>> print(out)
        [[[-0.00478983 -0.00772568]
          [-0.00968955 -0.00064902]]
         [[-0.01251151 -0.01251151]
          [-0.00196387 -0.00196387]]]
    """
    # Supported pooling operators for combining vectors within one field.
    OPERATOR_SUM = 'SUM'
    OPERATOR_MEAN = 'MEAN'
    OPERATOR_MAX = 'MAX'

    def __init__(self, vocab_size, embedding_size, field_size, param_init='normal', target='CPU',
                 slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
        super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
                                                        slice_mode, feature_num_list, max_norm, sparse)
        self.field_size = field_size
        self.operator = operator

        # Element-wise / segment primitives used in construct().
        self.mul = P.Mul()
        self.inf_mask_mul = P.Mul()
        self.bias_add = P.TensorAdd()
        self.inf_add = P.TensorAdd()
        self.merge_op = None
        self.count_op = P.UnsortedSegmentSum()
        self.abs = P.Abs()
        self.equal = P.Equal()
        self.add = P.TensorAdd()
        self.cast = P.Cast()
        self.div_no_nan = P.DivNoNan()
        self.expand = P.ExpandDims()
        self.max_mask_mul = P.Mul()
        self.max_no_equal = P.NotEqual()

        # Select the per-field merge primitive. MEAN is implemented as a
        # weighted segment SUM divided by the per-field weight count later.
        if operator == MultiFieldEmbeddingLookup.OPERATOR_SUM:
            self.merge_op = P.UnsortedSegmentSum()
        elif operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
            self.merge_op = P.UnsortedSegmentMax()
        elif operator == MultiFieldEmbeddingLookup.OPERATOR_MEAN:
            self.merge_op = P.UnsortedSegmentSum()
        else:
            raise ValueError("The operator supports ['SUM', 'MAX', 'MEAN'], but found: " + str(operator))

        # Configure operator sharding strategies under semi-auto/auto parallel.
        parallel_mode = _get_parallel_mode()
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
        if slice_mode in ["table_row_slice", "batch_slice"] and is_auto_parallel:
            self.merge_op.shard(((get_group_size(), 1, 1), (get_group_size(), 1)))
            self.expand.shard(((get_group_size(),),))
            self.bias_add.shard(((1, 1), (1, 1)))
            self.mul.shard(((get_group_size(), 1, 1), (get_group_size(), 1, 1)))
            self.count_op.shard(((get_group_size(), 1), (get_group_size(), 1)))
            self.add.shard(((get_group_size(),), (get_group_size(),)))
            self.div_no_nan.shard(((get_group_size(), 1), (get_group_size(), 1)))
            self.max_mask_mul.shard(((get_group_size(), 1), (get_group_size(), 1)))
            self.max_no_equal.shard(((1,), ()))
            if operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
                self.equal.shard(((get_group_size(), 1, 1), ()))
                self.inf_mask_mul.shard(((get_group_size(), 1, 1), ()))
                self.merge_op.shard(((get_group_size(), 1), (get_group_size(),)))
                self.count_op.shard(((get_group_size(),), (get_group_size(),)))
                self.inf_add.shard(((get_group_size(), 1, 1), (get_group_size(), 1, 1)))
        elif slice_mode == "table_column_slice" and is_auto_parallel:
            self.merge_op.shard(((1, 1, get_group_size()), (1, 1)))
            self.div_no_nan.shard(((1, get_group_size()), (1, 1)))
            self.bias_add.shard(((1, 1), (1, 1)))
            self.mul.shard(((1, 1, 1), (1, 1, get_group_size())))
            self.count_op.shard(((1, 1), (1, 1)))
            self.add.shard(((1,), (1,)))
            self.max_mask_mul.shard(((1, get_group_size()), (1, 1)))
            self.expand.shard(((1,),))
            self.max_no_equal.shard(((1,), ()))
            if operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
                self.equal.shard(((1, 1, 1), ()))
                self.inf_mask_mul.shard(((1, 1, 1), ()))
                self.merge_op.shard(((1, get_group_size()), (1,)))
                self.count_op.shard(((1,), (1,)))
                self.inf_add.shard(((1, 1, get_group_size()), (1, 1, 1)))
        else:
            if is_auto_parallel:
                # Fixed: the previous message was built with a backslash line
                # continuation and embedded a long run of literal spaces.
                raise ValueError("slice_mode should be ['table_row_slice', 'batch_slice' and "
                                 "'table_column_slice'], but get " + str(slice_mode))

        # Min value for fp32; used to mask padded positions for the MAX operator.
        self.negative_inf_value = -3.402823466E+38

    def construct(self, input_indices, input_values, field_ids):
        batch_size = self.shape(input_indices)[0]
        num_segments = batch_size * self.field_size

        # Offset each sample's field ids by batch_index * field_size so every
        # (sample, field) pair maps to a unique segment id in [0, num_segments).
        # BUGFIX: the bias holds batch_size per-sample offsets, so it must be
        # reshaped to (batch_size, 1) to broadcast one offset per row; the old
        # (self.field_size, -1) shape was only coincidentally correct when
        # batch_size == field_size.
        bias = Range(0, num_segments, self.field_size)()
        bias = self.reshape(bias, (batch_size, -1))
        field_ids = self.bias_add(field_ids, bias)

        # Look up the raw embedding vectors, mirroring EmbeddingLookup.construct.
        if self.target == "CPU":
            out = self.embeddinglookup(self.embedding_table, input_indices, 0)
        else:
            if self.forward_unique:
                shp = self.shape(input_indices) + (self.embedding_size,)
                indices_flatten = self.reshape(input_indices, (-1,))
                unique_id, unique_idx = self.unique(indices_flatten)
                weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
                weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
                out = self.reshape(weight_flatten, shp)
            else:
                out = self.gatherv2(self.embedding_table, input_indices, 0)
            if self.max_norm is not None:
                axis = _make_axis_range(F.rank(input_indices), F.rank(out))
                clip_by_norm = ClipByNorm(axis)
                out = clip_by_norm(out, self.max_norm)

        # Weight each looked-up vector by its input value.
        weights = self.reshape(input_values, (batch_size, self.shape(input_indices)[1], 1))
        embedding = self.mul(weights, out)

        if self.operator == 'MAX':
            # Fill the padded positions (weight == 0) with -inf so they cannot
            # win the segment max.
            negative_inf_mask = self.cast(self.equal(weights, 0), mstype.float32)
            inf_mask = self.inf_mask_mul(negative_inf_mask, self.negative_inf_value)
            embedding = self.inf_add(embedding, inf_mask)
        embedding = self.reshape(embedding, (-1, self.embedding_size))
        field_ids = self.reshape(field_ids, (-1,))

        merged_vectors = self.merge_op(embedding, field_ids, num_segments)

        if self.operator == 'MAX':
            # Zero out segments whose total absolute weight is zero
            # (fields containing only padding), which would otherwise hold -inf.
            value_count = self.count_op(self.abs(self.reshape(input_values, (-1,))), field_ids, num_segments)
            value_zeros = self.cast(self.max_no_equal(value_count, 0.0), mstype.float32)
            count = self.expand(value_zeros, -1)
            merged_vectors = self.max_mask_mul(merged_vectors, count)

        if self.operator == 'MEAN':
            # Divide the segment sum by the per-field absolute-weight count;
            # DivNoNan yields zeros for empty fields.
            # NOTE(review): field_ids is 1-D here while input_values is still
            # 2-D — verify this matches UnsortedSegmentSum's segment_ids shape
            # contract (the MAX branch flattens input_values first).
            value_count = self.count_op(self.abs(input_values), field_ids, num_segments)
            value_count = self.expand(value_count, -1)
            merged_vectors = self.div_no_nan(merged_vectors, value_count)

        merged_vectors = self.reshape(merged_vectors, (batch_size, self.field_size, -1))
        return merged_vectors