|
|
|
@@ -25,7 +25,7 @@ from mindspore.parallel._utils import _get_parallel_mode |
|
|
|
from ..cell import Cell |
|
|
|
from ..._checkparam import Validator as validator, Rel |
|
|
|
|
|
|
|
__all__ = ['Embedding', 'EmbeddingLookup', 'EmbeddingLookUpSplitMode'] |
|
|
|
__all__ = ['Embedding', 'EmbeddingLookup'] |
|
|
|
|
|
|
|
class Embedding(Cell): |
|
|
|
r""" |
|
|
|
@@ -131,7 +131,7 @@ class EmbeddingLookup(Cell): |
|
|
|
target (str): Specify the target where the op is executed. The value should in |
|
|
|
['DEVICE', 'CPU']. Default: 'CPU'. |
|
|
|
slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value should get through |
|
|
|
nn.EmbeddingLookUpSplitMode. Default: nn.EmbeddingLookUpSplitMode.BATCH_SLICE. |
|
|
|
nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE. |
|
|
|
manual_shapes (tuple): The accompaniment array in field slice mode. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
@@ -147,6 +147,11 @@ class EmbeddingLookup(Cell): |
|
|
|
>>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32) |
|
|
|
>>> out = nn.EmbeddingLookup(4,2)(input_indices) |
|
|
|
""" |
|
|
|
BATCH_SLICE = "batch_slice" |
|
|
|
FIELD_SLICE = "field_slice" |
|
|
|
TABLE_ROW_SLICE = "table_row_slice" |
|
|
|
TABLE_COLUMN_SLICE = "table_column_slice" |
|
|
|
|
|
|
|
def __init__(self, vocab_size, embedding_size, param_init='normal', |
|
|
|
target='CPU', slice_mode='batch_slice', manual_shapes=None): |
|
|
|
super(EmbeddingLookup, self).__init__() |
|
|
|
@@ -160,7 +165,7 @@ class EmbeddingLookup(Cell): |
|
|
|
name='embedding_table') |
|
|
|
parallel_mode = _get_parallel_mode() |
|
|
|
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) |
|
|
|
if slice_mode == EmbeddingLookUpSplitMode.FIELD_SLICE and is_auto_parallel: |
|
|
|
if slice_mode == "field_slice" and is_auto_parallel: |
|
|
|
if not manual_shapes: |
|
|
|
raise ValueError("in slice field mode, the manual_shapes should not be none") |
|
|
|
if not isinstance(manual_shapes, tuple): |
|
|
|
@@ -171,18 +176,18 @@ class EmbeddingLookup(Cell): |
|
|
|
self.embeddinglookup.add_prim_attr("manual_split", manual_shapes) |
|
|
|
self.gatherv2.set_strategy(((get_group_size(), 1), (1, get_group_size()))) |
|
|
|
self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, get_group_size()))) |
|
|
|
elif slice_mode == EmbeddingLookUpSplitMode.TABLE_ROW_SLICE and is_auto_parallel: |
|
|
|
elif slice_mode == "table_row_slice" and is_auto_parallel: |
|
|
|
self.gatherv2.set_strategy(((get_group_size(), 1), (1, 1))) |
|
|
|
self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1))) |
|
|
|
elif slice_mode == EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE and is_auto_parallel: |
|
|
|
elif slice_mode == "table_column_slice" and is_auto_parallel: |
|
|
|
self.gatherv2.set_strategy(((1, get_group_size()), (1, 1))) |
|
|
|
self.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1))) |
|
|
|
elif slice_mode == EmbeddingLookUpSplitMode.BATCH_SLICE and is_auto_parallel: |
|
|
|
elif slice_mode == "batch_slice" and is_auto_parallel: |
|
|
|
self.gatherv2.set_strategy(((1, 1), (get_group_size(), 1))) |
|
|
|
self.embeddinglookup.set_strategy(((1, 1), (get_group_size(), 1))) |
|
|
|
else: |
|
|
|
if is_auto_parallel: |
|
|
|
raise ValueError("slice_mode should support mode in nn.EmbeddingLookUpSplitMode, but get " |
|
|
|
raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get " |
|
|
|
+ str(slice_mode)) |
|
|
|
|
|
|
|
def construct(self, indices): |
|
|
|
@@ -191,25 +196,3 @@ class EmbeddingLookup(Cell): |
|
|
|
else: |
|
|
|
out = self.gatherv2(self.embedding_table, indices, 0) |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
class EmbeddingLookUpSplitMode: |
|
|
|
""" |
|
|
|
EmbeddingLookUp slice options in auto parallel and semi auto parallel mode. |
|
|
|
|
|
|
|
There are five kinds of slice options, "BATCH_SLICE", "FIELD_SLICE", |
|
|
|
"TABLE_ROW_SLICE" and "TABLE_COLUMN_SLICE". Default: "BATCH_SLICE". |
|
|
|
|
|
|
|
- BATCH_SLICE: Slicing batch dimensions of indices. |
|
|
|
- FIELD_SLICE: Slicing field dimensions of indices. |
|
|
|
- TABLE_ROW_SLICE: Slicing row of table. |
|
|
|
- TABLE_COLUMN_SLICE: Slicing column of table. |
|
|
|
|
|
|
|
MODE_LIST: The list for all supported parallel modes. |
|
|
|
""" |
|
|
|
|
|
|
|
BATCH_SLICE = "batch_slice" |
|
|
|
FIELD_SLICE = "field_slice" |
|
|
|
TABLE_ROW_SLICE = "table_row_slice" |
|
|
|
TABLE_COLUMN_SLICE = "table_column_slice" |
|
|
|
MODE_LIST = [BATCH_SLICE, FIELD_SLICE, TABLE_ROW_SLICE, TABLE_COLUMN_SLICE] |