|
|
|
@@ -273,8 +273,8 @@ class EmbeddingLookup(Cell): |
|
|
|
|
|
|
|
class MultiFieldEmbeddingLookup(EmbeddingLookup): |
|
|
|
r""" |
|
|
|
Returns a slice of input tensor based on the specified indices based on the field ids. This operation |
|
|
|
supports looking up embeddings within multi hot and one hot fields simultaneously. |
|
|
|
Returns a slice of input tensor based on the specified indices and the field ids. This operation |
|
|
|
supports looking up embeddings using multi hot and one hot fields simultaneously. |
|
|
|
|
|
|
|
Note: |
|
|
|
When 'target' is set to 'CPU', this module will use |
|
|
|
@@ -282,13 +282,13 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup): |
|
|
|
specified 'offset = 0' to lookup table. |
|
|
|
When 'target' is set to 'DEVICE', this module will use P.GatherV2() which |
|
|
|
specified 'axis = 0' to lookup table. |
|
|
|
The vectors with the same field_ids will be combined by the `operator`, such as `SUM`, `MAX` and |
|
|
|
`MEAN`. Ensure the input_values of the padded id is zero, so that they can be ignored. The final |
|
|
|
The vectors with the same field_ids will be combined by the 'operator', such as 'SUM', 'MAX' and |
|
|
|
'MEAN'. Ensure the input_values of the padded id is zero, so that they can be ignored. The final |
|
|
|
output will be zeros if the sum of absolute weight of the field is zero. This class only |
|
|
|
supports ['table_row_slice', 'batch_slice' and 'table_column_slice'] |
|
|
|
|
|
|
|
Args: |
|
|
|
vocab_size (int): Size of the dictionary of embeddings. |
|
|
|
vocab_size (int): The size of the dictionary of embeddings. |
|
|
|
embedding_size (int): The size of each embedding vector. |
|
|
|
field_size (int): The field size of the final outputs. |
|
|
|
param_init (str): The initialize way of embedding table. Default: 'normal'. |
|
|
|
@@ -296,7 +296,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup): |
|
|
|
['DEVICE', 'CPU']. Default: 'CPU'. |
|
|
|
slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through |
|
|
|
nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE. |
|
|
|
feature_num_list (tuple): The accompaniment array in field slice mode. |
|
|
|
feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently. |
|
|
|
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32 |
|
|
|
or None. Default: None |
|
|
|
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True. |
|
|
|
@@ -410,7 +410,6 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup): |
|
|
|
|
|
|
|
batch_size = self.shape(input_indices)[0] |
|
|
|
num_segments = batch_size * self.field_size |
|
|
|
|
|
|
|
bias = Range(0, num_segments, self.field_size)() |
|
|
|
bias = self.reshape(bias, (self.field_size, -1)) |
|
|
|
field_ids = self.bias_add(field_ids, bias) |
|
|
|
|