# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cache_ops"""
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register, PrimitiveWithCheck
from .. import signature as sig


class UpdateCache(PrimitiveWithCheck):
    """
    Update the value of `input_x`, similar to ScatterNdUpdate.
    The difference is that UpdateCache will not update when indices < 0 or indices >= max_num.

    Inputs:
        - **input_x** (Parameter) - Parameter which is going to be updated.
        - **indices** (Tensor) - Update indices of `input_x`.
        - **updates** (Tensor) - The update values.

    Outputs:
        - **out** (Tensor) - Returns a [1] Tensor, which is not useful.
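
    Examples:
        A minimal usage sketch; the shapes below and passing `max_num` as a
        plain int are assumptions for illustration, not taken from this file.
        Rows of `input_x` selected by non-negative, in-range indices are
        overwritten by the matching rows of `updates`; index -1 is skipped.
        >>> input_x = Parameter(Tensor(np.zeros((4, 2), np.float32)), name="cache")
        >>> indices = Tensor(np.array([1, 3, -1], np.int32))
        >>> updates = Tensor(np.ones((3, 2), np.float32))
        >>> max_num = 4
        >>> out = ops.UpdateCache()(input_x, indices, updates, max_num)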
  30. """
  31. __mindspore_signature__ = (
  32. sig.make_sig('input_x', sig.sig_rw.RW_WRITE,
  33. dtype=sig.sig_dtype.T),
  34. sig.make_sig('indices', dtype=sig.sig_dtype.T1),
  35. sig.make_sig('updates', dtype=sig.sig_dtype.T),
  36. sig.make_sig('max_num', dtype=sig.sig_dtype.T1)
  37. )
  38. @prim_attr_register
  39. def __init__(self):
  40. """init UpdateCache"""
  41. self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'],
  42. outputs=['out'])
  43. def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
  44. return [1]
  45. def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
  46. validator.check_tensor_dtype_valid(
  47. "indices", indices_dtype, mstype.int_type, self.name)
  48. return input_x_dtype


class SubAndFilter(PrimitiveWithCheck):
    """
    Dynamic kernel that subtracts an offset and
    returns the elements which are in the range [0, max_num).

    Inputs:
        - **input_x** (Tensor) - Input tensor.
        - **max_num** (int) - The upper bound (exclusive) for an element after `offset` is subtracted.
        - **offset** (int) - Specifies the offset value of this `input_x`.

    Outputs:
        tuple(Tensor), tuple of 2 tensors, filter_res and filter_idx.

        - **filter_res** (Tensor) - The elements of `input_x` minus `offset`
          that fall in the range [0, max_num).
        - **filter_idx** (Tensor) - A tensor containing the indices of elements in the input
          corresponding to the output tensor.

    Supported Platforms:
        `CPU`

    Examples:
        >>> x = Tensor(np.array([1, 3, 5, 8, 9, 16]), mindspore.int32)
        >>> max_num = 10
        >>> offset = 5
        >>> output = ops.SubAndFilter()(x, max_num, offset)
        >>> print(output)
        (Tensor(shape=[3], dtype=Int32, value= [0, 3, 4]),
        Tensor(shape=[3], dtype=Int32, value= [2, 3, 4]))
    """
    @prim_attr_register
    def __init__(self):
        """init SubAndFilter"""
        self.init_prim_io_names(inputs=['input_x', 'max_num', 'offset'],
                                outputs=['sub_res', 'sub_idx'])

    def check_shape(self, input_x_shape, max_num_shape, offset_shape):
        return ((-1,), (-1,))

    def check_dtype(self, input_x_dtype, max_num_dtype, offset_dtype):
        validator.check_tensor_dtype_valid(
            "input_x", input_x_dtype, mstype.int_type, self.name)
        return input_x_dtype


class SearchCacheIdx(PrimitiveWithInfer):
    """
    Search the keys of a hashmap and return the corresponding values.

    Inputs:
        - **hashmap** (Parameter) - The hashmap has shape (n, 4), and its columns represent `key, value, step, tag`.

          `key, value`: Map the indices of the big table to the indices of the cache table.

          `step`: The most recent step. When a key is found, its step is updated at the same time.
          `step` makes sure the indices that were used in the last step will not be deleted from the hashmap.

          `tag`: Linear probing (`h(k, i) = (h(k) + i) % m`) is used to resolve hash conflicts.
          `tag` is the number of linear probing steps for the key. If `tag == 0`, the entry is empty.

          The hash function is
          `((0.6180339 * key) - floor(0.618033 * key)) * hashmap_length`, in order to avoid data clustering
          (see the sketch at the end of the Examples section).
        - **indices** (Tensor) - The indices which are keys of the hashmap.
        - **step** (int) - The current step when searching.
        - **emb_max_num** (int) - Max length of the big table.
          Searching is skipped when `indices >= emb_max_num`, and the value is set to `cache_max_num`.
        - **cache_max_num** (int) - Max length of the cache table.

    Outputs:
        - **cache_idx** (Tensor) - The searched values; if the search missed, the value is -1.
        - **miss_idx** (Tensor) - The indices into `indices` for which the search missed.
          If the search succeeded, miss_idx[i] = -1.
        - **miss_emb_idx** (Tensor) - The values of `indices` for which the search missed.
          If the search succeeded, miss_emb_idx[i] = -1.

    Examples:
        >>> hashmap = Parameter(Tensor(np.array([[0, 0, 0, 0],
        ...                                      [10, 5, -5, 1],
        ...                                      [2, 1, -5, 1],
        ...                                      [15, 7, -5, 2],
        ...                                      [0, 0, 0, 0],
        ...                                      [0, 0, 0, 0],
        ...                                      [0, 0, 0, 0],
        ...                                      [0, 0, 0, 0],
        ...                                      [3, 3, -5, 1],
        ...                                      [21, 9, -5, 1]], np.int32)), name="hashmap")
        >>> indices = Tensor(np.array([10, 2, 25, 5, 3], np.int32))
        >>> step, emb_max_num, cache_max_num = 0, 25, 10
        >>> search_cache_idx = ops.SearchCacheIdx()
        >>> cache_idx, miss_idx, miss_emb_idx = search_cache_idx(hashmap, indices, step, emb_max_num, cache_max_num)
        cache_idx : [5, 1, 10, -1, 3]
        miss_idx : [-1, -1, -1, 3, -1]
        miss_emb_idx : [-1, -1, -1, 5, -1]
        hashmap after search : [[0, 0, 0, 0],
                                [10, 5, 0, 1],
                                [2, 1, 0, 1],
                                [15, 7, -5, 2],
                                [0, 0, 0, 0],
                                [0, 0, 0, 0],
                                [0, 0, 0, 0],
                                [0, 0, 0, 0],
                                [3, 3, 0, 1],
                                [21, 9, -5, 1]]
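
        The hash function above can be sketched in pure Python (an illustration
        of the documented formula only, using 0.6180339 for both terms; the
        helper name `hash_start` is made up for this sketch):
        >>> import math
        >>> def hash_start(key, hashmap_length):
        ...     frac = 0.6180339 * key - math.floor(0.6180339 * key)
        ...     return int(frac * hashmap_length)
        >>> [hash_start(key, 10) for key in [10, 2, 15, 3, 21]]
        [1, 2, 2, 8, 9]

        These start positions match the hashmap above: key 15 collides with
        key 2 at slot 2 and is probed one slot forward to slot 3, which is why
        its `tag` is 2.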
  136. """
  137. __mindspore_signature__ = (
  138. sig.make_sig('hashmap', sig.sig_rw.RW_WRITE,
  139. dtype=sig.sig_dtype.T),
  140. sig.make_sig('indices', dtype=sig.sig_dtype.T),
  141. sig.make_sig('step', dtype=sig.sig_dtype.T),
  142. sig.make_sig('emb_max_num', dtype=sig.sig_dtype.T),
  143. sig.make_sig('cache_max_num', dtype=sig.sig_dtype.T)
  144. )
  145. @prim_attr_register
  146. def __init__(self):
  147. """init SearchCacheIdx"""
  148. self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'cache_max_num'],
  149. outputs=['cache_idx', 'miss_idx', 'miss_emb_idx'])
  150. def infer_shape(self, hashmap_shape, indices_shape, step_shape, emb_max_num_shape, cache_max_num_shape):
  151. if len(hashmap_shape) != 2:
  152. raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, "
  153. "but got %d." % len(hashmap_shape))
  154. out_shape = (indices_shape, indices_shape, indices_shape)
  155. return out_shape
  156. def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
  157. args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
  158. validator.check_tensors_dtypes_same_and_valid(
  159. args, mstype.int_type, self.name)
  160. out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype)
  161. return out_dtype


class MapUniform(PrimitiveWithCheck):
    """
    Map a tensor using the formula: value = key % `group_num` * `per_group_size` + key // `group_num`.

    Inputs:
        - **input** (Tensor) - Input Tensor.
        - **per_group_size** (int) - The size of each group.
        - **group_num** (int) - The number of groups.

    Outputs:
        Tensor, has the same dtype and shape as `input`.

    Supported Platforms:
        `CPU`

    Examples:
        >>> input_x = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7]))
        >>> per_group_size = 4
        >>> group_num = 2
        >>> map_uniform = ops.MapUniform()
        >>> output = map_uniform(input_x, per_group_size, group_num)
        >>> print(output)
        [0, 4, 1, 5, 2, 6, 3, 7]
    """
    @prim_attr_register
    def __init__(self):
        """init MapUniform"""
        self.init_prim_io_names(inputs=['input', 'per_group_size', 'group_num'],
                                outputs=['output'])

    def check_dtype(self, input_dtype, per_group_size_dtype, group_num_dtype):
        validator.check_tensor_dtype_valid(
            "input", input_dtype, mstype.int_type, self.name)
        validator.check_value_type(
            'per_group_size', per_group_size_dtype, [mstype.Int], self.name)
        validator.check_value_type(
            'group_num', group_num_dtype, [mstype.Int], self.name)


class CacheSwapHashmap(PrimitiveWithInfer):
    """
    Delete a hashmap entry and insert a new key into the hashmap, returning the key and value
    of the deleted entry.

    Inputs:
        - **hashmap** (Parameter) - Same as in the SearchCacheIdx operation.
        - **miss_emb_idx** (Tensor) - The keys to insert; -1 is skipped. It is the `miss_emb_idx`
          result of SearchCacheIdx.
        - **step** (int) - The current step.

    Outputs:
        - **swap_cache_idx** (Tensor) - The value of the deleted entry; -1 is skipped.
        - **old_emb_idx** (Tensor) - The key of the deleted entry; -1 is skipped.
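
    Examples:
        A minimal usage sketch; `hashmap` and `miss_emb_idx` follow the
        SearchCacheIdx example above, and the outputs are omitted here because
        they depend on the hashmap state.
        >>> miss_emb_idx = Tensor(np.array([-1, -1, -1, 5, -1], np.int32))
        >>> step = 0
        >>> cache_swap_hashmap = ops.CacheSwapHashmap()
        >>> swap_cache_idx, old_emb_idx = cache_swap_hashmap(hashmap, miss_emb_idx, step)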
  204. """
  205. __mindspore_signature__ = (
  206. sig.make_sig('hashmap', sig.sig_rw.RW_WRITE,
  207. dtype=sig.sig_dtype.T),
  208. sig.make_sig('miss_emb_idx', dtype=sig.sig_dtype.T),
  209. sig.make_sig('step', dtype=sig.sig_dtype.T)
  210. )
  211. @prim_attr_register
  212. def __init__(self):
  213. """init CacheSwapHashmap"""
  214. self.init_prim_io_names(inputs=['hashmap', 'miss_emb_idx', 'step'],
  215. outputs=['swap_cache_idx', 'old_emb_idx'])
  216. def infer_shape(self, hashmap_shape, miss_emb_idx_shape, step_shape):
  217. if len(hashmap_shape) != 2:
  218. raise ValueError("The dimension of 'hashmap' in CacheSwapHashmap must be 2, "
  219. "but got %d." % len(hashmap_shape))
  220. out_shape = (miss_emb_idx_shape, miss_emb_idx_shape)
  221. return out_shape
  222. def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype):
  223. validator.check_tensor_dtype_valid(
  224. "miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name)
  225. out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype)
  226. return out_dtype


class CacheSwapTable(PrimitiveWithCheck):
    """
    Swap rows of the cache table: write `miss_value` into `cache_table` at `swap_cache_idx`
    and return the rows that are swapped out.

    Inputs:
        - **cache_table** (Parameter) - The cache table which is on device.
        - **swap_cache_idx** (Tensor) - The indices of the table rows that need to be swapped; -1 is skipped.
        - **miss_value** (Tensor) - The values which are going to be swapped into the cache table.

    Outputs:
        - **old_value** (Tensor) - The values which are swapped out.
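
    Examples:
        A minimal usage sketch; the shapes are assumptions for illustration.
        Rows of `cache_table` selected by non-negative `swap_cache_idx` are
        replaced by the rows of `miss_value`, and the old rows are returned.
        >>> cache_table = Parameter(Tensor(np.zeros((4, 2), np.float32)), name="cache_table")
        >>> swap_cache_idx = Tensor(np.array([1, -1, 3], np.int32))
        >>> miss_value = Tensor(np.ones((3, 2), np.float32))
        >>> old_value = ops.CacheSwapTable()(cache_table, swap_cache_idx, miss_value)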
  236. """
  237. __mindspore_signature__ = (
  238. sig.make_sig('cache_table', sig.sig_rw.RW_WRITE,
  239. dtype=sig.sig_dtype.T),
  240. sig.make_sig('swap_cache_idx', dtype=sig.sig_dtype.T1),
  241. sig.make_sig('miss_value', dtype=sig.sig_dtype.T)
  242. )
  243. @prim_attr_register
  244. def __init__(self):
  245. """init CacheSwapTable"""
  246. self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'],
  247. outputs=['old_value'])
  248. def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
  249. if len(cache_table_shape) != 2:
  250. raise ValueError(
  251. "cache table shape must be 2, but got %d" % len(cache_table_shape))
  252. return miss_value_shape
  253. def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
  254. validator.check_tensor_dtype_valid(
  255. "swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
  256. return miss_value_dtype


class MapCacheIdx(PrimitiveWithCheck):
    """
    MapCacheIdx merges SearchCacheIdx, CacheSwapHashmap and UpdateCache together.
    Given an indices tensor, it outputs the cache indices found by searching the hashmap.
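
    Examples:
        A minimal usage sketch; the inputs follow the SearchCacheIdx example
        above, and the outputs are omitted here because they depend on the
        hashmap state.
        >>> map_cache_idx = ops.MapCacheIdx()
        >>> cache_idx, old_emb_idx, miss_emb_idx, swap_cache_idx = map_cache_idx(
        ...     hashmap, indices, step, emb_max_num, cache_max_num)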
  261. """
  262. __mindspore_signature__ = (
  263. sig.make_sig('hashmap', sig.sig_rw.RW_WRITE,
  264. dtype=sig.sig_dtype.T),
  265. sig.make_sig('indices', dtype=sig.sig_dtype.T),
  266. sig.make_sig('step', dtype=sig.sig_dtype.T),
  267. sig.make_sig('emb_max_num', dtype=sig.sig_dtype.T),
  268. sig.make_sig('cache_max_num', dtype=sig.sig_dtype.T)
  269. )
  270. @prim_attr_register
  271. def __init__(self):
  272. """init MapCacheIdx"""
  273. self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'],
  274. outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx'])
  275. def __check__(self, hashmap, indices, step, emb_max_num, offset):
  276. hashmap_shape = hashmap['shape']
  277. if len(hashmap_shape) != 2:
  278. raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, "
  279. "but got %d." % len(hashmap_shape))
  280. out_shape = (indices['shape'], -1, -1, -1)
  281. hashmap_dtype = hashmap['dtype']
  282. indices_dtype = indices['dtype']
  283. args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
  284. validator.check_tensors_dtypes_same_and_valid(
  285. args, mstype.int_type, self.name)
  286. out_dtype = (hashmap_dtype, hashmap_dtype,
  287. hashmap_dtype, hashmap_dtype)
  288. out = {'shape': out_shape,
  289. 'dtype': out_dtype,
  290. 'value': None}
  291. if 'max_shape' in indices:
  292. out['max_shape'] = (indices['max_shape'], indices['max_shape'],
  293. indices['max_shape'], indices['max_shape'])
  294. else:
  295. out['max_shape'] = (indices['shape'], indices['shape'],
  296. indices['shape'], indices['shape'])
  297. if 'min_shape' in indices:
  298. out['min_shape'] = (indices['min_shape'], 0, 0, 0)
  299. else:
  300. out['min_shape'] = (0, 0, 0, 0)
  301. return out


class DynamicAssign(PrimitiveWithCheck):
    """
    Assigns a value to a `Parameter`; the `value` can have a dynamic shape.

    Inputs:
        - **variable** (Parameter) - The `Parameter` to be assigned.
        - **value** (Tensor) - The value to be assigned.

    Outputs:
        Tensor, has the same type as the original `variable`.

    Supported Platforms:
        `CPU`
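
    Examples:
        A minimal usage sketch; the names and shapes are assumptions for
        illustration.
        >>> variable = Parameter(Tensor(np.zeros((4,), np.float32)), name="variable")
        >>> value = Tensor(np.array([1.0, 2.0, 3.0, 4.0], np.float32))
        >>> output = ops.DynamicAssign()(variable, value)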
  312. """
  313. __mindspore_signature__ = (
  314. sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
  315. sig.make_sig('value', dtype=sig.sig_dtype.T)
  316. )
  317. @prim_attr_register
  318. def __init__(self):
  319. self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])
  320. def check_dtype(self, variable, value):
  321. if variable != mstype.type_refkey:
  322. validator.check_tensor_dtype_valid(
  323. "variable", variable, mstype.number_type, self.name)
  324. validator.check_scalar_or_tensor_types_same(
  325. {"value": value}, mstype.number_type, self.name)


class PadAndShift(PrimitiveWithCheck):
    """
    Pad a tensor with -1 and shift it by a length.

    Inputs:
        - **input_x** (Tensor) - The input Tensor, which will be copied
          to `output`.
        - **cum_sum_arr** (Tensor) - The last value of cum_sum_arr is
          the pad length of the output tensor, cum_sum_arr[shift_idx] is
          the start of the shift, and cum_sum_arr[shift_idx+1] is the end.
        - **shift_idx** (int) - The index into cum_sum_arr.

    In Python, PadAndShift is equivalent to:

        output = [-1] * cum_sum_arr[-1]
        start = cum_sum_arr[shift_idx]
        end = cum_sum_arr[shift_idx + 1]
        output[start:end] = input_x[:(end-start)]

    Outputs:
        Tensor, has the same type as the original `input_x`.

    Supported Platforms:
        `CPU`

    Examples:
        >>> input_x = Tensor(np.array([9, 13, -1, -1, -1, -1, -1, -1]), mstype.int32)
        >>> cum_sum_arr = Tensor(np.array([0, 3, 5]), mstype.int32)
        >>> shift_idx = 1
        >>> pad_and_shift = ops.PadAndShift()
        >>> output = pad_and_shift(input_x, cum_sum_arr, shift_idx)
        >>> print(output)
        [-1, -1, -1, 9, 13]
    """
    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(
            inputs=['input_x', 'cum_sum_arr', 'shift_idx'], outputs=['output'])

    def check_shape(self, input_x_shape, cum_sum_arr_shape, shift_idx_shape):
        return input_x_shape

    def check_dtype(self, input_x_dtype, cum_sum_arr_dtype, shift_idx_dtype):
        return input_x_dtype