You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

embedding.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """embedding"""
  16. import mindspore.common.dtype as mstype
  17. from mindspore.common.tensor import Tensor
  18. from mindspore.ops import operations as P
  19. from mindspore.common.parameter import Parameter
  20. from mindspore.common.initializer import initializer
  21. from ..cell import Cell
  22. from ..._checkparam import Validator as validator
  23. class Embedding(Cell):
  24. r"""
  25. A simple lookup table that stores embeddings of a fixed dictionary and size.
  26. This module is often used to store word embeddings and retrieve them using
  27. indices. The input to the module is a list of indices, and the output is
  28. the corresponding word embeddings.
  29. Note:
  30. When 'use_one_hot' is set to True, the input should be of type mindspore.int32.
  31. Args:
  32. vocab_size (int): Size of the dictionary of embeddings.
  33. embedding_size (int): The size of each embedding vector.
  34. use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: False.
  35. embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
  36. Refer to class `initializer` for the values of string when a string
  37. is specified. Default: 'normal'.
  38. dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
  39. Inputs:
  40. - **input** (Tensor) - Tensor of shape :math:`(\text{vocab_size})`.
  41. Outputs:
  42. Tensor of shape :math:`(\text{vocab_size}, \text{embedding_size})`.
  43. Examples:
  44. >>> net = nn.Embedding(20000, 768, True)
  45. >>> input_data = Tensor(np.ones([8, 128]), mindspore.int32)
  46. >>>
  47. >>> # Maps the input word IDs to word embedding.
  48. >>> output = net(input_data)
  49. >>> output.shape()
  50. (8, 128, 768)
  51. """
  52. def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32):
  53. super(Embedding, self).__init__()
  54. validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
  55. self.vocab_size = vocab_size
  56. self.embedding_size = embedding_size
  57. self.use_one_hot = use_one_hot
  58. self.embedding_table = Parameter(initializer(embedding_table, [vocab_size, embedding_size]),
  59. name='embedding_table')
  60. self.dtype = dtype
  61. self.expand = P.ExpandDims()
  62. self.reshape_flat = P.Reshape()
  63. self.shp_flat = (-1,)
  64. self.gather = P.GatherV2()
  65. self.one_hot = P.OneHot()
  66. self.on_value = Tensor(1.0, self.dtype)
  67. self.off_value = Tensor(0.0, self.dtype)
  68. self.array_mul = P.MatMul()
  69. self.reshape = P.Reshape()
  70. self.get_shp = P.Shape()
  71. def construct(self, ids):
  72. extended_ids = self.expand(ids, -1)
  73. out_shape = self.get_shp(ids) + (self.embedding_size,)
  74. flat_ids = self.reshape_flat(extended_ids, self.shp_flat)
  75. if self.use_one_hot:
  76. one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
  77. output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
  78. else:
  79. output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
  80. output = self.reshape(output_for_reshape, out_shape)
  81. return output
  82. def extend_repr(self):
  83. s = 'vocab_size={}, embedding_size={},' \
  84. 'use_one_hot={}, ' \
  85. 'embedding_table={}, dtype={}'.format(
  86. self.vocab_size,
  87. self.embedding_size,
  88. self.use_one_hot,
  89. self.embedding_table,
  90. self.dtype)
  91. return s