You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

embedding.py 9.5 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """embedding"""
  16. import mindspore.common.dtype as mstype
  17. from mindspore.common.tensor import Tensor
  18. from mindspore.ops import operations as P
  19. from mindspore.common.parameter import Parameter
  20. from mindspore.common.initializer import initializer
  21. from mindspore.communication.management import get_group_size
  22. from mindspore.context import ParallelMode
  23. from mindspore.parallel._utils import _get_parallel_mode
  24. from mindspore._checkparam import Validator as validator
  25. from ..cell import Cell
  26. __all__ = ['Embedding', 'EmbeddingLookup']
  27. class Embedding(Cell):
  28. r"""
  29. A simple lookup table that stores embeddings of a fixed dictionary and size.
  30. This module is often used to store word embeddings and retrieve them using
  31. indices. The input to the module is a list of indices, and the output is
  32. the corresponding word embeddings.
  33. Note:
  34. When 'use_one_hot' is set to True, the type of the input must be mindspore.int32.
  35. Args:
  36. vocab_size (int): Size of the dictionary of embeddings.
  37. embedding_size (int): The size of each embedding vector.
  38. use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: False.
  39. embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
  40. Refer to class `initializer` for the values of string when a string
  41. is specified. Default: 'normal'.
  42. dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
  43. Inputs:
  44. - **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The elements of
  45. the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
  46. be zero.
  47. Outputs:
  48. Tensor of shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`.
  49. Examples:
  50. >>> net = nn.Embedding(20000, 768, True)
  51. >>> input_data = Tensor(np.ones([8, 128]), mindspore.int32)
  52. >>>
  53. >>> # Maps the input word IDs to word embedding.
  54. >>> output = net(input_data)
  55. >>> output.shape
  56. (8, 128, 768)
  57. """
  58. def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32):
  59. super(Embedding, self).__init__()
  60. validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
  61. validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name)
  62. self.vocab_size = vocab_size
  63. self.embedding_size = embedding_size
  64. self.use_one_hot = use_one_hot
  65. self.embedding_table = Parameter(initializer(embedding_table, [vocab_size, embedding_size]),
  66. name='embedding_table')
  67. self.dtype = dtype
  68. self.expand = P.ExpandDims()
  69. self.reshape_flat = P.Reshape()
  70. self.shp_flat = (-1,)
  71. self.gather = P.GatherV2()
  72. self.one_hot = P.OneHot()
  73. self.on_value = Tensor(1.0, self.dtype)
  74. self.off_value = Tensor(0.0, self.dtype)
  75. self.array_mul = P.MatMul()
  76. self.reshape = P.Reshape()
  77. self.get_shp = P.Shape()
  78. def construct(self, ids):
  79. extended_ids = self.expand(ids, -1)
  80. out_shape = self.get_shp(ids) + (self.embedding_size,)
  81. flat_ids = self.reshape_flat(extended_ids, self.shp_flat)
  82. if self.use_one_hot:
  83. one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
  84. output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
  85. else:
  86. output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
  87. output = self.reshape(output_for_reshape, out_shape)
  88. return output
  89. def extend_repr(self):
  90. s = 'vocab_size={}, embedding_size={},' \
  91. 'use_one_hot={}, ' \
  92. 'embedding_table={}, dtype={}'.format(
  93. self.vocab_size,
  94. self.embedding_size,
  95. self.use_one_hot,
  96. self.embedding_table,
  97. self.dtype)
  98. return s
  99. class EmbeddingLookup(Cell):
  100. r"""
  101. Returns a slice of input tensor based on the specified indices.
  102. Note:
  103. When 'target' is set to 'CPU', this module will use
  104. P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
  105. specified 'offset = 0' to lookup table.
  106. When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
  107. specified 'axis = 0' to lookup table.
  108. In field slice mode, the manual_shapes must be given. It is a tuple ,where
  109. the element is vocab[i], vocab[i] is the row numbers for i-th
  110. part.
  111. Args:
  112. vocab_size (int): Size of the dictionary of embeddings.
  113. embedding_size (int): The size of each embedding vector.
  114. param_init (str): The initialize way of embedding table. Default: 'normal'.
  115. target (str): Specifies the target where the op is executed. The value must in
  116. ['DEVICE', 'CPU']. Default: 'CPU'.
  117. slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
  118. nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
  119. manual_shapes (tuple): The accompaniment array in field slice mode.
  120. Inputs:
  121. - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
  122. Specifies the indices of elements of the original Tensor. Values can be out of range of embedding_table,
  123. and the exceeding part will be filled with 0 in the output. Input_indices must only be a 2d tensor in
  124. this interface.
  125. Outputs:
  126. Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
  127. Examples:
  128. >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
  129. >>> out = nn.EmbeddingLookup(4,2)(input_indices)
  130. """
  131. BATCH_SLICE = "batch_slice"
  132. FIELD_SLICE = "field_slice"
  133. TABLE_ROW_SLICE = "table_row_slice"
  134. TABLE_COLUMN_SLICE = "table_column_slice"
  135. def __init__(self, vocab_size, embedding_size, param_init='normal',
  136. target='CPU', slice_mode='batch_slice', manual_shapes=None):
  137. super(EmbeddingLookup, self).__init__()
  138. self.target = target
  139. if target not in ('CPU', 'DEVICE'):
  140. raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
  141. + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
  142. self.gatherv2 = P.GatherV2()
  143. self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
  144. self.embedding_table = Parameter(initializer(param_init, [vocab_size, embedding_size]),
  145. name='embedding_table')
  146. parallel_mode = _get_parallel_mode()
  147. is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
  148. if slice_mode == "field_slice" and is_auto_parallel:
  149. if not manual_shapes:
  150. raise ValueError("in slice field mode, the manual_shapes should not be none")
  151. if not isinstance(manual_shapes, tuple):
  152. raise TypeError("manual_shapes type must be tuple(int) cannot be {}!".format(type(manual_shapes)))
  153. for dim in manual_shapes:
  154. validator.check_positive_int(dim, 'manual shape dim', self.cls_name)
  155. self.gatherv2.add_prim_attr("manual_split", manual_shapes)
  156. self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
  157. self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
  158. self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
  159. elif slice_mode == "table_row_slice" and is_auto_parallel:
  160. self.gatherv2.shard(((get_group_size(), 1), (1, 1)))
  161. self.embeddinglookup.shard(((get_group_size(), 1), (1, 1)))
  162. elif slice_mode == "table_column_slice" and is_auto_parallel:
  163. self.gatherv2.shard(((1, get_group_size()), (1, 1)))
  164. self.embeddinglookup.shard(((1, get_group_size()), (1, 1)))
  165. elif slice_mode == "batch_slice" and is_auto_parallel:
  166. self.gatherv2.shard(((1, 1), (get_group_size(), 1)))
  167. self.embeddinglookup.shard(((1, 1), (get_group_size(), 1)))
  168. else:
  169. if is_auto_parallel:
  170. raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
  171. + str(slice_mode))
  172. def construct(self, indices):
  173. if self.target == "CPU":
  174. out = self.embeddinglookup(self.embedding_table, indices, 0)
  175. else:
  176. out = self.gatherv2(self.embedding_table, indices, 0)
  177. return out