You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

embedding.py 32 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """embedding"""
  16. import mindspore.common.dtype as mstype
  17. from mindspore import log as logger
  18. from mindspore.common.tensor import Tensor
  19. from mindspore.ops import operations as P
  20. from mindspore.ops import functional as F
  21. from mindspore.common.parameter import Parameter
  22. from mindspore.common.initializer import initializer
  23. from mindspore.communication.management import get_group_size, get_rank
  24. from mindspore.context import ParallelMode
  25. from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
  26. from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context
  27. from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _set_rank_id
  28. from mindspore import context
  29. from mindspore._checkparam import Rel
  30. from mindspore._checkparam import Validator as validator
  31. from mindspore.ops.primitive import constexpr
  32. from .basic import ClipByNorm
  33. from .math import Range
  34. from ..cell import Cell
  35. __all__ = ['Embedding', 'EmbeddingLookup', 'MultiFieldEmbeddingLookup']
  36. @constexpr
  37. def _check_input_2d(input_shape, param_name, func_name):
  38. if len(input_shape) != 2:
  39. raise ValueError(f"{func_name} {param_name} should be 2d, but got shape {input_shape}")
  40. return True
  41. @constexpr
  42. def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
  43. validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
  44. class Embedding(Cell):
  45. r"""
  46. A simple lookup table that stores embeddings of a fixed dictionary and size.
  47. This module is often used to store word embeddings and retrieve them using
  48. indices. The input to the module is a list of indices, and the output is
  49. the corresponding word embeddings.
  50. Note:
  51. When 'use_one_hot' is set to True, the type of the input must be mindspore.int32.
  52. Args:
  53. vocab_size (int): Size of the dictionary of embeddings.
  54. embedding_size (int): The size of each embedding vector.
  55. use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: False.
  56. embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
  57. Refer to class `initializer` for the values of string when a string
  58. is specified. Default: 'normal'.
  59. dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
  60. padding_idx (int, None): When the padding_idx encounters index, the output embedding vector of this index
  61. will be initialized to zero. Default: None. The feature is inactivated.
  62. Inputs:
  63. - **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The elements of
  64. the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
  65. be zero.
  66. Outputs:
  67. Tensor of shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`.
  68. Raises:
  69. TypeError: If `vocab_size` or `embedding_size` is not an int.
  70. TypeError: If `use_one_hot` is not a bool.
  71. ValueError: If `padding_idx` is an int which not in range [0, `vocab_size`].
  72. Supported Platforms:
  73. ``Ascend`` ``GPU``
  74. Examples:
  75. >>> net = nn.Embedding(20000, 768, True)
  76. >>> input_data = Tensor(np.ones([8, 128]), mindspore.int32)
  77. >>>
  78. >>> # Maps the input word IDs to word embedding.
  79. >>> output = net(input_data)
  80. >>> result = output.shape
  81. >>> print(result)
  82. (8, 128, 768)
  83. """
  84. def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal',
  85. dtype=mstype.float32, padding_idx=None):
  86. super(Embedding, self).__init__()
  87. self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
  88. self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
  89. validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name)
  90. validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
  91. self.use_one_hot = use_one_hot
  92. self.dtype = dtype
  93. self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
  94. self.padding_idx = padding_idx
  95. if padding_idx is not None:
  96. self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
  97. "padding_idx", self.cls_name)
  98. if isinstance(self.init_tensor, Tensor) and self.init_tensor.init is not None:
  99. self.init_tensor = self.init_tensor.init_data()
  100. self.init_tensor = self.init_tensor.asnumpy()
  101. self.init_tensor[self.padding_idx] = 0
  102. self.init_tensor = Tensor(self.init_tensor)
  103. self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
  104. self.expand = P.ExpandDims()
  105. self.reshape_flat = P.Reshape()
  106. self.shp_flat = (-1,)
  107. self.gather = P.Gather()
  108. self.one_hot = P.OneHot()
  109. self.on_value = Tensor(1.0, self.dtype)
  110. self.off_value = Tensor(0.0, self.dtype)
  111. self.array_mul = P.MatMul()
  112. self.reshape = P.Reshape()
  113. self.get_shp = P.Shape()
  114. def construct(self, ids):
  115. extended_ids = self.expand(ids, -1)
  116. out_shape = self.get_shp(ids) + (self.embedding_size,)
  117. flat_ids = self.reshape_flat(extended_ids, self.shp_flat)
  118. if self.use_one_hot:
  119. one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
  120. output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
  121. else:
  122. output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
  123. output = self.reshape(output_for_reshape, out_shape)
  124. return output
  125. def extend_repr(self):
  126. s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
  127. self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
  128. return s
  129. @constexpr
  130. def _make_axis_range(start, end):
  131. axis = tuple(range(start, end))
  132. return axis
class EmbeddingLookup(Cell):
    r"""
    Returns a slice of the input tensor based on the specified indices.

    Note:
        When 'target' is set to 'CPU', this module will use
        P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
        specifies 'offset = 0' to lookup table.
        When 'target' is set to 'DEVICE', this module will use P.SparseGatherV2() if `sparse` is True,
        otherwise P.Gather(), which specifies 'axis = 0' to lookup table.
        In field slice mode, the manual_shapes must be given. It is a tuple, where
        the element is vocab[i], vocab[i] is the row numbers for i-th part.

    Args:
        vocab_size (int): Size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
            Refer to class `initializer` for the values of string when a string
            is specified. Default: 'normal'.
        target (str): Specifies the target where the op is executed. The value must in
            ['DEVICE', 'CPU']. Default: 'CPU'.
        slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
            nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
        manual_shapes (tuple): The accompaniment array in field slice mode.
        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
            or None. Default: None
        sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
        vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only in
            'DEVICE' target. And the moment parameter of corresponding optimizer will also be set to the cache size.
            In addition, it should be noted that it will cost the 'DEVICE'
            memory, so suggests setting a reasonable value to avoid insufficient memory.

    Inputs:
        - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
          Specifies the indices of elements of the original Tensor. Values can be out of range of embedding_table,
          and the exceeding part will be filled with 0 in the output. Negative values are not supported and the
          result is undefined if values are negative. Input_indices must only be a 2d tensor in
          this interface when run in semi auto parallel/auto parallel mode.

    Outputs:
        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.

    Raises:
        TypeError: If `vocab_size` or `embedding_size` or `vocab_cache_size` is not an int.
        TypeError: If `sparse` is not a bool or `manual_shapes` is not a tuple.
        ValueError: If `vocab_size` or `embedding_size` is less than 1.
        ValueError: If `vocab_cache_size` is less than 0.
        ValueError: If `target` is neither 'CPU' nor 'DEVICE'.
        ValueError: If `slice_mode` is not one of 'batch_slice' or 'field_slice' or
                    'table_row_slice' or 'table_column_slice'.
        ValueError: If `sparse` is False and `target` is 'CPU'.
        ValueError: If `slice_mode` is 'field_slice' and `manual_shapes` is None.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
        >>> result = nn.EmbeddingLookup(4,2)(input_indices)
        >>> print(result.shape)
        (2, 2, 2)
    """
    BATCH_SLICE = "batch_slice"
    FIELD_SLICE = "field_slice"
    TABLE_ROW_SLICE = "table_row_slice"
    TABLE_COLUMN_SLICE = "table_column_slice"

    def __init__(self, vocab_size, embedding_size, param_init='normal',
                 target='CPU', slice_mode='batch_slice', manual_shapes=None,
                 max_norm=None, sparse=True, vocab_cache_size=0):
        super(EmbeddingLookup, self).__init__()
        validator.check_value_type('sparse', sparse, [bool], self.cls_name)
        self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
        self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
        self.target = target
        self.sparse = sparse
        # May be reset by _process_vocab_cache() in parameter-server mode.
        self.cache_enable = self.vocab_cache_size > 0
        # When True, construct() looks up each distinct id once and scatters rows back (unique path).
        self.forward_unique = False
        if target not in ('CPU', 'DEVICE'):
            raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
                             + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
        if not sparse and target == 'CPU':
            raise ValueError('When target is CPU, embedding_lookup must be sparse.')
        if sparse:
            self.gatherv2 = P.SparseGatherV2()
        else:
            self.gatherv2 = P.Gather()
        self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
        enable_ps = _get_ps_context("enable_ps")
        if enable_ps:
            # Must run before the table is created: on a worker it may replace self.vocab_size
            # with the (possibly rank-scaled) cache size, which sizes the table below.
            self._process_vocab_cache(slice_mode)
        self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size')
        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
                                         name='embedding_table')
        parallel_mode = _get_parallel_mode()
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
        # Ops used by the unique path in construct(); some are replaced or sharded below.
        self.gather_revert = P.Gather()
        self.reshape_first = P.Reshape()
        self.reshape = P.Reshape()
        self.unique = P.Unique()
        self.shape = P.Shape()
        if is_auto_parallel:
            # Unique is given the strategy (1,): its input is not split.
            self.unique = P.Unique().shard(((1,),))
        if self.cache_enable and enable_ps:
            # Passes the constructor arguments, not the values _process_vocab_cache() may have changed.
            self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size)
            if is_auto_parallel:
                self.unique.add_prim_attr('cache_enable', True)
        # Rank of the indices assumed when building shard strategies; the unique paths below
        # flatten indices to 1-D and lower it to 1.
        indices_shape_size = 2
        if slice_mode == "field_slice" and is_auto_parallel:
            if not manual_shapes:
                raise ValueError("in slice field mode, the manual_shapes should not be none")
            if not isinstance(manual_shapes, tuple):
                raise TypeError("manual_shapes type must be tuple(int) cannot be {}!".format(type(manual_shapes)))
            for dim in manual_shapes:
                validator.check_positive_int(dim, 'manual shape dim', self.cls_name)
            self.gatherv2.add_prim_attr("manual_split", manual_shapes)
            self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
            self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
            self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
        elif slice_mode == "table_row_slice" and is_auto_parallel:
            full_batch = _get_full_batch()
            if (target == 'DEVICE' and not full_batch) or (self.cache_enable and enable_ps and sparse):
                indices_shape_size = 1
                self.gather_revert.shard(((1, 1), (get_group_size(),)))
                self.forward_unique = True
            indices_strategy = (1,)*indices_shape_size
            # Table strategy (group_size, 1): split along the first (row) dimension.
            self.gatherv2.shard(((get_group_size(), 1), indices_strategy))
            self.embeddinglookup.shard(((get_group_size(), 1), indices_strategy))
        elif slice_mode == "table_column_slice" and is_auto_parallel:
            if target == 'DEVICE':
                indices_shape_size = 1
                self.gather_revert.shard(((1, get_group_size()), (1,)))
                self.forward_unique = True
            indices_strategy = (1,)*indices_shape_size
            # Table strategy (1, group_size): split along the second (embedding) dimension.
            self.gatherv2.shard(((1, get_group_size()), indices_strategy))
            self.embeddinglookup.shard(((1, get_group_size()), indices_strategy))
        elif slice_mode == "batch_slice" and is_auto_parallel:
            # Split the first dimension of the indices; the table itself is not split.
            indices_strategy = [get_group_size()]
            indices_strategy.extend([1]*(indices_shape_size - 1))
            indices_strategy = tuple(indices_strategy)
            self.gatherv2.shard(((1, 1), indices_strategy))
            self.embeddinglookup.shard(((1, 1), indices_strategy))
        else:
            if is_auto_parallel:
                raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
                                 + str(slice_mode))
        if self.cache_enable and not enable_ps:
            if parallel_mode != ParallelMode.STAND_ALONE:
                raise ValueError("parallel mode haven't supported cache enable yet.")
            self._set_cache_enable()
        # NOTE(review): presumably read elsewhere in the framework (e.g. by optimizers) to know the
        # lookup path is unique-based — confirm against consumers of Parameter.unique.
        self.embedding_table.unique = self.forward_unique
        self.max_norm = max_norm
        if self.max_norm is not None:
            self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
            self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)

    # NOTE: this method has the same name as the module-level `_set_cache_enable` imported from
    # mindspore.parallel._ps_context. Inside this class, bare calls `_set_cache_enable(...)` resolve
    # to that imported function; this method is only reached through `self._set_cache_enable()`.
    def _set_cache_enable(self):
        """EmbeddingLookup cache check for not ps env, which is only support 'ascend'.

        Raises:
            ValueError: If `target` is not 'DEVICE', `sparse` is False, or the device target is not Ascend.
        """
        if self.target != 'DEVICE':
            raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target.")
        if not self.sparse:
            raise ValueError("The configuration of 'vocab_cache_size' is valid only 'sparse' is true.")
        if context.get_context("device_target") != 'Ascend':
            raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'ascend'.")

        logger.info("EmbeddingLookup cache enable takes effect.")
        self.forward_unique = True
        # The flatten + unique steps of the unique path are placed on CPU in cache mode.
        self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU')
        self.unique.add_prim_attr('cache_enable', True)
        self.embedding_table.cache_enable = self.cache_enable
        self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size)
        self.reshape_first = P.Reshape().add_prim_attr('primitive_target', 'CPU')

    def _process_vocab_cache(self, slice_mode):
        """PS embeddingLookup cache check and process.

        Sets self.cache_enable. In auto-parallel mode it scales self.vocab_cache_size by the group
        size, and on a worker it replaces self.vocab_size with the cache size.
        """
        self.cache_enable = False
        if self.vocab_cache_size > 0:
            if self.target == 'CPU':
                logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
                               "current target is CPU, so it will be ignored.")
                return
            enable_ps = _get_ps_context("enable_ps")
            if not enable_ps:
                logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server trainning "
                               "mode, current mode is not parameter server trainning mode, so it will be ignored.")
                return
            parallel_mode = _get_parallel_mode()
            is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
            if is_auto_parallel:
                rank_size = get_group_size()
                rank_id = get_rank()
                full_batch = _get_full_batch()
                if rank_size > 1 and not (full_batch and slice_mode == "table_row_slice"):
                    raise ValueError("The embeddingLookup cache of parameter server parallel only be used "
                                     "in 'full_batch' and 'table_row_slice' parallel strategy.")
                self.vocab_cache_size = self.vocab_cache_size * rank_size
                _set_rank_id(rank_id)
            self.cache_enable = True
            if _is_role_worker():
                # Workers only hold the cache-sized table.
                self.vocab_size = self.vocab_cache_size
                if context.get_context("enable_sparse") != self.sparse:
                    raise ValueError("The value of parameter 'sparse' must be same for all EmbeddingLookup "
                                     "kernels and equal the value of 'enable_sparse' in context setting in "
                                     "parameter server cache mode")

    def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
        """PS embeddingLookup cache enable set: mark the table as a PS cache parameter."""
        self.embedding_table.cache_enable = True
        self.embedding_table.is_param_ps = True
        # Module-level function from mindspore.parallel._ps_context, not the method of the same name.
        _set_cache_enable(True)
        if self.sparse:
            self.forward_unique = True
        if _is_role_worker():
            _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)

    def construct(self, indices):
        """Look up `indices` in the embedding table, optionally clipping each vector by `max_norm`."""
        if self.target == "CPU":
            out = self.embeddinglookup(self.embedding_table, indices, 0)
        else:
            if self.forward_unique:
                # Gather each distinct id once, then expand back to the original positions.
                shp = self.shape(indices) + (self.embedding_size,)
                indices_flatten = self.reshape_first(indices, (-1,))
                unique_id, unique_idx = self.unique(indices_flatten)
                weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
                weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
                out = self.reshape(weight_flatten, shp)
            else:
                out = self.gatherv2(self.embedding_table, indices, 0)
        if self.max_norm is not None:
            # Clip over the trailing (embedding) axes that the lookup appended after the index axes.
            axis = _make_axis_range(F.rank(indices), F.rank(out))
            clip_by_norm = ClipByNorm(axis)
            out = clip_by_norm(out, self.max_norm)
        return out
  353. class MultiFieldEmbeddingLookup(EmbeddingLookup):
  354. r"""
  355. Returns a slice of input tensor based on the specified indices and the field ids. This operation
  356. supports looking up embeddings using multi hot and one hot fields simultaneously.
  357. Note:
  358. When 'target' is set to 'CPU', this module will use
  359. P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
  360. specified 'offset = 0' to lookup table.
  361. When 'target' is set to 'DEVICE', this module will use P.Gather() which
  362. specified 'axis = 0' to lookup table.
  363. The vectors with the same field_ids will be combined by the 'operator', such as 'SUM', 'MAX' and
  364. 'MEAN'. Ensure the input_values of the padded id is zero, so that they can be ignored. The final
  365. output will be zeros if the sum of absolute weight of the field is zero. This class only
  366. supports ['table_row_slice', 'batch_slice' and 'table_column_slice']. For the operation 'MAX' on
  367. device Ascend, there is a constrain where batch_size * (seq_length + field_size) < 3500.
  368. Args:
  369. vocab_size (int): The size of the dictionary of embeddings.
  370. embedding_size (int): The size of each embedding vector.
  371. field_size (int): The field size of the final outputs.
  372. param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
  373. Refer to class `initializer` for the values of string when a string
  374. is specified. Default: 'normal'.
  375. target (str): Specifies the target where the op is executed. The value must in
  376. ['DEVICE', 'CPU']. Default: 'CPU'.
  377. slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
  378. nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
  379. feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently.
  380. max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
  381. or None. Default: None
  382. sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
  383. operator (str): The pooling method for the features in one field. Support 'SUM, 'MEAN' and 'MAX'
  384. Inputs:
  385. - **input_indices** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.
  386. Specifies the indices of elements of the original Tensor. Input_indices must be a 2d tensor in
  387. this interface. Type is Int32, Int64.
  388. - **input_values** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.
  389. Specifies the weights of elements of the input_indices. The lookout vector will multiply with
  390. the input_values. Type is Float32.
  391. - **field_ids** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.
  392. Specifies the field id of elements of the input_indices. Type is Int32.
  393. Outputs:
  394. Tensor, the shape of tensor is :math:`(batch\_size, field\_size, embedding\_size)`. Type is Float32.
  395. Raises:
  396. TypeError: If `vocab_size` or `embedding_size` or `field_size` is not an int.
  397. TypeError: If `sparse` is not a bool or `feature_num_list` is not a tuple.
  398. ValueError: If `vocab_size` or `embedding_size` or `field_size` is less than 1.
  399. ValueError: If `target` is neither 'CPU' nor 'DEVICE'.
  400. ValueError: If `slice_mode` is not one of 'batch_slice', 'field_slice', 'table_row_slice', 'table_column_slice'.
  401. ValueError: If `sparse` is False and `target` is 'CPU'.
  402. ValueError: If `slice_mode` is 'field_slice' and `feature_num_list` is None.
  403. ValueError: If `operator` is not one of 'SUM', 'MAX', 'MEAN'.
  404. Supported Platforms:
  405. ``Ascend`` ``GPU``
  406. Examples:
  407. >>> input_indices = Tensor([[2, 4, 6, 0, 0], [1, 3, 5, 0, 0]], mindspore.int32)
  408. >>> input_values = Tensor([[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]], mindspore.float32)
  409. >>> field_ids = Tensor([[0, 1, 1, 0, 0], [0, 0, 1, 0, 0]], mindspore.int32)
  410. >>> net = nn.MultiFieldEmbeddingLookup(10, 2, field_size=2, operator='SUM', target='DEVICE')
  411. >>> out = net(input_indices, input_values, field_ids)
  412. >>> print(out.shape)
  413. (2, 2, 2)
  414. """
  415. OPERATOR_SUM = 'SUM'
  416. OPERATOR_MEAN = 'MEAN'
  417. OPERATOR_MAX = 'MAX'
  418. def __init__(self, vocab_size, embedding_size, field_size, param_init='normal', target='CPU',
  419. slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
  420. super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
  421. slice_mode, feature_num_list, max_norm, sparse)
  422. self.field_size = validator.check_positive_int(field_size, 'field_size')
  423. self.operator = operator
  424. self.mul = P.Mul()
  425. self.inf_mask_mul = P.Mul()
  426. self.bias_add = P.Add()
  427. self.inf_add = P.Add()
  428. self.merge_op = None
  429. self.count_op = P.UnsortedSegmentSum()
  430. self.abs = P.Abs()
  431. self.equal = P.Equal()
  432. self.add = P.Add()
  433. self.cast = P.Cast()
  434. self.div_no_nan = P.DivNoNan()
  435. self.expand = P.ExpandDims()
  436. self.max_mask_mul = P.Mul()
  437. self.max_no_equal = P.NotEqual()
  438. if operator == MultiFieldEmbeddingLookup.OPERATOR_SUM:
  439. self.merge_op = P.UnsortedSegmentSum()
  440. elif operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
  441. self.merge_op = P.UnsortedSegmentMax()
  442. elif operator == MultiFieldEmbeddingLookup.OPERATOR_MEAN:
  443. self.merge_op = P.UnsortedSegmentSum()
  444. else:
  445. raise ValueError("The operator supports ['SUM', 'MAX', 'MEAN'], but found: "+str(operator))
  446. parallel_mode = _get_parallel_mode()
  447. is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
  448. if slice_mode in ["table_row_slice", "batch_slice"] and is_auto_parallel:
  449. self.merge_op.shard(((get_group_size(), 1, 1), (get_group_size(), 1)))
  450. self.expand.shard(((get_group_size(),),))
  451. self.bias_add.shard(((1, 1), (1, 1)))
  452. self.mul.shard(((get_group_size(), 1, 1), (get_group_size(), 1, 1)))
  453. self.count_op.shard(((get_group_size(), 1), (get_group_size(), 1)))
  454. self.add.shard(((get_group_size(),), (get_group_size(),)))
  455. self.div_no_nan.shard(((get_group_size(), 1), (get_group_size(), 1)))
  456. self.max_mask_mul.shard(((get_group_size(), 1), (get_group_size(), 1)))
  457. self.max_no_equal.shard(((1,), ()))
  458. if operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
  459. self.equal.shard(((get_group_size(), 1, 1), ()))
  460. self.inf_mask_mul.shard(((get_group_size(), 1, 1), ()))
  461. self.merge_op.shard(((get_group_size(), 1), (get_group_size(),)))
  462. self.count_op.shard(((get_group_size(),), (get_group_size(),)))
  463. self.inf_add.shard(((get_group_size(), 1, 1), (get_group_size(), 1, 1)))
  464. elif slice_mode == "table_column_slice" and is_auto_parallel:
  465. self.merge_op.shard(((1, 1, get_group_size()), (1, 1)))
  466. self.div_no_nan.shard(((1, get_group_size()), (1, 1)))
  467. self.bias_add.shard(((1, 1), (1, 1)))
  468. self.mul.shard(((1, 1, 1), (1, 1, get_group_size())))
  469. self.count_op.shard(((1, 1), (1, 1)))
  470. self.add.shard(((1,), (1,)))
  471. self.max_mask_mul.shard(((1, get_group_size()), (1, 1)))
  472. self.expand.shard(((1,),))
  473. self.max_no_equal.shard(((1,), ()))
  474. if operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
  475. self.equal.shard(((1, 1, 1), ()))
  476. self.inf_mask_mul.shard(((1, 1, 1), ()))
  477. self.merge_op.shard(((1, get_group_size()), (1,)))
  478. self.count_op.shard(((1,), (1,)))
  479. self.inf_add.shard(((1, 1, get_group_size()), (1, 1, 1)))
  480. else:
  481. if is_auto_parallel:
  482. raise ValueError("slice_mode should be ['table_row_slice', 'batch_slice' and \
  483. 'table_column_slice'], but get " + str(slice_mode))
  484. # Min value for fp32
  485. self.negative_inf_value = -3.402823466E+38
  486. def construct(self, input_indices, input_values, field_ids):
  487. _check_input_2d(F.shape(input_indices), "input_indices", self.cls_name)
  488. _check_input_2d(F.shape(input_values), "input_values", self.cls_name)
  489. _check_input_2d(F.shape(field_ids), "field_ids", self.cls_name)
  490. _check_input_dtype(F.dtype(input_indices), "input_indices", [mstype.int32, mstype.int64], self.cls_name)
  491. _check_input_dtype(F.dtype(input_values), "input_values", [mstype.float32], self.cls_name)
  492. _check_input_dtype(F.dtype(field_ids), "field_ids", [mstype.int32], self.cls_name)
  493. batch_size = self.shape(input_indices)[0]
  494. num_segments = batch_size * self.field_size
  495. bias = Range(0, num_segments, self.field_size)()
  496. bias = self.reshape(bias, (batch_size, -1))
  497. field_ids = self.bias_add(field_ids, bias)
  498. if self.target == "CPU":
  499. out = self.embeddinglookup(self.embedding_table, input_indices, 0)
  500. else:
  501. if self.forward_unique:
  502. shp = self.shape(input_indices) + (self.embedding_size,)
  503. indices_flatten = self.reshape(input_indices, (-1,))
  504. unique_id, unique_idx = self.unique(indices_flatten)
  505. weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
  506. weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
  507. out = self.reshape(weight_flatten, shp)
  508. else:
  509. out = self.gatherv2(self.embedding_table, input_indices, 0)
  510. if self.max_norm is not None:
  511. axis = _make_axis_range(F.rank(input_indices), F.rank(out))
  512. clip_by_norm = ClipByNorm(axis)
  513. out = clip_by_norm(out, self.max_norm)
  514. weights = self.reshape(input_values, (batch_size, self.shape(input_indices)[1], 1))
  515. embedding = self.mul(weights, out)
  516. if self.operator == 'MAX':
  517. # Fill the padding value to -inf, so the padded value will not influence the results
  518. negative_inf_mask = self.cast(self.equal(weights, 0), mstype.float32)
  519. inf_mask = self.inf_mask_mul(negative_inf_mask, self.negative_inf_value)
  520. embedding = self.inf_add(embedding, inf_mask)
  521. embedding = self.reshape(embedding, (-1, self.embedding_size))
  522. field_ids = self.reshape(field_ids, (-1,))
  523. merged_vectors = self.merge_op(embedding, field_ids, num_segments)
  524. if self.operator == 'MAX':
  525. value_count = self.count_op(self.abs(self.reshape(input_values, (-1,))), field_ids, num_segments)
  526. value_zeros = self.cast(self.max_no_equal(value_count, 0.0), mstype.float32)
  527. count = self.expand(value_zeros, -1)
  528. merged_vectors = self.max_mask_mul(merged_vectors, count)
  529. if self.operator == 'MEAN':
  530. value_count = self.count_op(self.abs(input_values), field_ids, num_segments)
  531. value_count = self.expand(value_count, -1)
  532. merged_vectors = self.div_no_nan(merged_vectors, value_count)
  533. merged_vectors = self.reshape(merged_vectors, (batch_size, self.field_size, -1))
  534. return merged_vectors