@@ -16,16 +16,21 @@
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode
from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
from mindspore.ops.primitive import constexpr
from .basic import ClipByNorm
from ..cell import Cell

__all__ = ['Embedding', 'EmbeddingLookup']


class Embedding(Cell):
    r"""
    A simple lookup table that stores embeddings of a fixed dictionary and size.
@@ -45,7 +50,8 @@ class Embedding(Cell):
            Refer to class `initializer` for the values of string when a string
            is specified. Default: 'normal'.
        dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
        padding_idx (int, None): If not None, the embedding vector at index `padding_idx` is
            initialized to zero. Default: None (the feature is disabled).

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The elements of
          the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will

@@ -63,16 +69,24 @@ class Embedding(Cell):
        >>> output.shape
        (8, 128, 768)
    """
    def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32):
    def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal',
                 dtype=mstype.float32, padding_idx=None):
        super(Embedding, self).__init__()
        validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
        self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
        self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
        validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name)
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
        self.use_one_hot = use_one_hot
        self.embedding_table = Parameter(initializer(embedding_table, [vocab_size, embedding_size]),
                                         name='embedding_table')
        self.dtype = dtype
        self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
        self.padding_idx = padding_idx
        if padding_idx is not None:
            self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
                                                         "padding_idx", self.cls_name)
            self.init_tensor = self.init_tensor.to_tensor().asnumpy()
            self.init_tensor[self.padding_idx] = 0
        self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
        self.expand = P.ExpandDims()
        self.reshape_flat = P.Reshape()
        self.shp_flat = (-1,)
@@ -99,16 +113,17 @@ class Embedding(Cell):
        return output

    def extend_repr(self):
        s = 'vocab_size={}, embedding_size={},' \
            'use_one_hot={}, ' \
            'embedding_table={}, dtype={}'.format(
                self.vocab_size,
                self.embedding_size,
                self.use_one_hot,
                self.embedding_table,
                self.dtype)
        s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
            self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
        return s

@constexpr
def _make_axis_range(start, end):
    axis = tuple(range(start, end))
    return axis
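# Editorial note, not part of the patch: this constexpr helper just materializes
# the trailing axes of the gathered output at graph-compile time, e.g.
# tuple(range(2, 3)) == (2,) when the indices are 2-D and the lookup result is
# 3-D. EmbeddingLookup.construct() below passes that tuple to ClipByNorm, so each
# looked-up embedding vector is clipped to max_norm independently.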


class EmbeddingLookup(Cell):
    r"""
    Returns a slice of input tensor based on the specified indices.
@@ -120,8 +135,7 @@ class EmbeddingLookup(Cell):
    When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
    specifies 'axis = 0' to look up the table.
    In field slice mode, manual_shapes must be given. It is a tuple, where
    the element is vocab[i], vocab[i] is the row numbers for i-th
    part.
    the element is vocab[i], and vocab[i] is the number of rows of the i-th part.

    Args:
        vocab_size (int): Size of the dictionary of embeddings.

@@ -132,6 +146,8 @@ class EmbeddingLookup(Cell):
        slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must be obtained
            through nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
        manual_shapes (tuple): The accompaniment array in field slice mode.
        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
            or None. Default: None.

    Inputs:
        - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
@@ -152,7 +168,7 @@ class EmbeddingLookup(Cell):
    TABLE_COLUMN_SLICE = "table_column_slice"

    def __init__(self, vocab_size, embedding_size, param_init='normal',
                 target='CPU', slice_mode='batch_slice', manual_shapes=None):
                 target='CPU', slice_mode='batch_slice', manual_shapes=None, max_norm=None):
        super(EmbeddingLookup, self).__init__()
        self.target = target
        if target not in ('CPU', 'DEVICE'):

@@ -160,7 +176,9 @@ class EmbeddingLookup(Cell):
                             + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
        self.gatherv2 = P.GatherV2()
        self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
        self.embedding_table = Parameter(initializer(param_init, [vocab_size, embedding_size]),
        self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
        self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
                                         name='embedding_table')
        parallel_mode = _get_parallel_mode()
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)

@@ -188,10 +206,18 @@ class EmbeddingLookup(Cell):
            if is_auto_parallel:
                raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
                                 + str(slice_mode))
        self.max_norm = max_norm
        if self.max_norm is not None:
            self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
            self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
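    # Editorial note, not part of the patch: max_norm is validated as a positive
    # float and stored as a float32 Tensor so that construct() can pass it
    # directly to ClipByNorm as the clipping threshold.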
    def construct(self, indices):
        if self.target == "CPU":
            out = self.embeddinglookup(self.embedding_table, indices, 0)
        else:
            out = self.gatherv2(self.embedding_table, indices, 0)
        if self.max_norm is not None:
            axis = _make_axis_range(F.rank(indices), F.rank(out))
            clip_by_norm = ClipByNorm(axis)
            out = clip_by_norm(out, self.max_norm)
        return out
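

# Editorial usage sketch, not part of the patch. It assumes the nn.Embedding and
# nn.EmbeddingLookup interfaces shown in this diff (padding_idx and max_norm are
# the newly added arguments); it has not been run against a specific MindSpore build.
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor

# padding_idx: row 0 of the embedding table starts out as all zeros.
embed = nn.Embedding(vocab_size=5, embedding_size=3, padding_idx=0)
ids = Tensor(np.array([[0, 2]]), mstype.int32)
print(embed(ids).shape)   # (1, 2, 3); index 0 maps to the zero vector initially

# max_norm: every looked-up vector is clipped to an L2 norm of at most 1.0.
lookup = nn.EmbeddingLookup(vocab_size=10, embedding_size=4, max_norm=1.0)
out = lookup(Tensor(np.array([[1, 3], [4, 7]]), mstype.int32))
print(out.shape)          # (2, 2, 4)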