You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

_cache_ops.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """cache_ops"""
  16. from ..._checkparam import Validator as validator
  17. from ...common import dtype as mstype
  18. from ..primitive import PrimitiveWithInfer, prim_attr_register
  19. from .. import signature as sig
  20. class UpdateCache(PrimitiveWithInfer):
  21. """
  22. Update the value fo input_x, similar to ScatterNdUpdate.
  23. The diffirent is that UpdateCache will not update when indices < 0 or indices >= max_num.
  24. Inputs:
  25. - **input_x** (Parameter) - Parameter which is going to be updated.
  26. - **indices** (Tensor) - Update indices of input_x.
  27. - **updates** (Tensor) - The update values.
  28. Outputs:
  29. - **out** (Tensor) - Returns a [1] Tensor, which is not usefull.
  30. """
  31. __mindspore_signature__ = (
  32. sig.make_sig('input_x', sig.sig_rw.RW_WRITE,
  33. dtype=sig.sig_dtype.T),
  34. sig.make_sig('indices', dtype=sig.sig_dtype.T1),
  35. sig.make_sig('updates', dtype=sig.sig_dtype.T),
  36. sig.make_sig('max_num', dtype=sig.sig_dtype.T1)
  37. )
  38. @prim_attr_register
  39. def __init__(self):
  40. """init UpdateCache"""
  41. self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'],
  42. outputs=['out'])
  43. def infer_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
  44. if len(indices_shape) < 2:
  45. raise ValueError("The dimension of 'indices' in UpdateCache must >= 2, "
  46. "but got %d." % len(indices_shape))
  47. return [1]
  48. def infer_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
  49. args = {"indices": indices_dtype}
  50. validator.check_tensor_type_same(args, mstype.int_type, self.name)
  51. return input_x_dtype
  52. class SearchCacheIdx(PrimitiveWithInfer):
  53. """
  54. Search the keys of a hashmap, and return the values.
  55. Inputs:
  56. - **hashmap** (Parameter) - The dim of hashmap is (n, 4), which cols represent the `key, value, step, tag`.
  57. `key, value`: Map the indices of big table and cache table.
  58. `step`: The resent step, when searching the key, it will be updated at the same time.
  59. `step` can make sure the indices which are using in the last step will not be deleted in hashmap.
  60. `tag`: We use linear probing(`h(k, i) = (h(k) + i) % m`) to solve hash conflicts.
  61. tag is the count of linear probing times of the key. If `tag == 0`, means that the entry is empty.
  62. The Hash Function is:
  63. `((0.6180339 * key) - floor(0.618033 * key)) * hashmap_length`, in order to avoid data clustering.
  64. - **indices** (Tensor) - The indices which are keys of hashmap.
  65. - **step** (int) - The current step when searching.
  66. - **emb_max_num** (int) - Max length of big table.
  67. To avoid searching when `indices >= emb_max_num`, and make value = `cache_max_num`.
  68. - **cache_max_num** (int) - Max length of cache table.
  69. Outputs:
  70. - **cache_idx** (Tensor) - Result of searched value, if search missed, value = -1.
  71. - **miss_idx** (Tensor) - The index of Tensor indices which search missed.
  72. If search success, miss_idx[i] = -1.
  73. - **miss_emb_idx** (Tensor) - The value of Tensor indices which search missed.
  74. If search success, miss_emb_idx[i] = -1.
  75. Examples:
  76. >>> hashmap = Parameter(Tensor(np.array([[0, 0, 0, 0],
  77. [10, 5, -5, 1],
  78. [2, 1, -5, 1],
  79. [15, 7, -5, 2],
  80. [0, 0, 0, 0],
  81. [0, 0, 0, 0],
  82. [0, 0, 0, 0],
  83. [0, 0, 0, 0],
  84. [3, 3, -5, 1],
  85. [21, 9, -5, 1]], np.int32)), name="hashmap")
  86. >>> indices = Tensor(np.array([10, 2, 25, 5, 3], np.int32))
  87. >>> step = 0, emb_max_num = 25, cache_max_num = 10
  88. >>> ops = P.SearchCacheIdx()
  89. >>> cache_idx, miss_idx, miss_emb_idx = ops(hashmap, indices, step, emb_max_num, cache_max_num)
  90. cache_idx : [5, 1, 10, -1, 3]
  91. miss_idx : [-1, -1, -1, 3, -1]
  92. miss_emb_idx : [-1, -1, -1, 5, -1]
  93. hashmap after search : [[0, 0, 0, 0],
  94. [10, 5, 0, 1],
  95. [2, 1, 0, 1],
  96. [15, 7, -5, 2],
  97. [0, 0, 0, 0],
  98. [0, 0, 0, 0],
  99. [0, 0, 0, 0],
  100. [0, 0, 0, 0],
  101. [3, 3, 0, 1],
  102. [21, 9, -5, 1]]
  103. """
  104. __mindspore_signature__ = (
  105. sig.make_sig('hashmap', sig.sig_rw.RW_WRITE,
  106. dtype=sig.sig_dtype.T),
  107. sig.make_sig('indices', dtype=sig.sig_dtype.T),
  108. sig.make_sig('step', dtype=sig.sig_dtype.T),
  109. sig.make_sig('emb_max_num', dtype=sig.sig_dtype.T),
  110. sig.make_sig('cache_max_num', dtype=sig.sig_dtype.T)
  111. )
  112. @prim_attr_register
  113. def __init__(self):
  114. """init SearchCacheIdx"""
  115. self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'cache_max_num'],
  116. outputs=['cache_idx', 'miss_idx', 'miss_emb_idx'])
  117. def infer_shape(self, hashmap_shape, indices_shape, step_shape, emb_max_num_shape, cache_max_num_shape):
  118. if len(hashmap_shape) != 2:
  119. raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, "
  120. "but got %d." % len(hashmap_shape))
  121. out_shape = (indices_shape, indices_shape, indices_shape)
  122. return out_shape
  123. def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
  124. args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
  125. validator.check_tensor_type_same(args, mstype.int_type, self.name)
  126. out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype)
  127. return out_dtype
  128. class CacheSwapHashmap(PrimitiveWithInfer):
  129. """
  130. Delete a hashmap entry,and insert a new key to hashmap, return the key and value of delete entry.
  131. Inputs:
  132. - **hashmap** (Parameter) - Same to operation SearchCacheIdx.
  133. - **miss_emb_idx** (Tensor) - The keys which are going to insert, -1 is skipped. It is the result
  134. - **step** (int) - The current step.
  135. Outputs:
  136. - **swap_cache_idx** (Tensor) - Deleted value of entry, -1 is skipped.
  137. - **old_emb_idx** (Tensor) - Deleted key of entry, -1 is skipped.
  138. """
  139. __mindspore_signature__ = (
  140. sig.make_sig('hashmap', sig.sig_rw.RW_WRITE,
  141. dtype=sig.sig_dtype.T),
  142. sig.make_sig('miss_emb_idx', dtype=sig.sig_dtype.T),
  143. sig.make_sig('step', dtype=sig.sig_dtype.T)
  144. )
  145. @prim_attr_register
  146. def __init__(self):
  147. """init CacheSwapHashmap"""
  148. self.init_prim_io_names(inputs=['hashmap', 'miss_emb_idx', 'step'],
  149. outputs=['swap_cache_idx', 'old_emb_idx'])
  150. def infer_shape(self, hashmap_shape, miss_emb_idx_shape, step_shape):
  151. if len(hashmap_shape) != 2:
  152. raise ValueError("The dimension of 'hashmap' in CacheSwapHashmap must be 2, "
  153. "but got %d." % len(hashmap_shape))
  154. out_shape = (miss_emb_idx_shape, miss_emb_idx_shape)
  155. return out_shape
  156. def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype):
  157. args = {"miss_emb_idx": miss_emb_idx_dtype}
  158. validator.check_tensor_type_same(args, mstype.int_type, self.name)
  159. out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype)
  160. return out_dtype
  161. class CacheSwapTable(PrimitiveWithInfer):
  162. """
  163. Delete a hashmap entry,and insert a new key to hashmap, return the key and value of delete entry.
  164. Inputs:
  165. - **cache_table** (Parameter) - The cache table which is on device.
  166. - **swap_cache_idx** (Tensor) - The index of table which need to swap. -1 is skipped.
  167. - **miss_value** (int) - The values which arg going to swap into cache table.
  168. Outputs:
  169. - **old_value** (Tensor) - The values which are swapped out.
  170. """
  171. __mindspore_signature__ = (
  172. sig.make_sig('cache_table', sig.sig_rw.RW_WRITE,
  173. dtype=sig.sig_dtype.T),
  174. sig.make_sig('swap_cache_idx', dtype=sig.sig_dtype.T1),
  175. sig.make_sig('miss_value', dtype=sig.sig_dtype.T)
  176. )
  177. @prim_attr_register
  178. def __init__(self):
  179. """init CacheSwapTable"""
  180. self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'],
  181. outputs=['old_value'])
  182. def infer_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
  183. if len(cache_table_shape) != 2:
  184. raise ValueError(
  185. "cache table shape must be 2, but got %d" % len(cache_table_shape))
  186. if swap_cache_idx_shape + cache_table_shape[1:] != miss_value_shape:
  187. raise ValueError(
  188. "swap_cache_idx_shape + cache_table_shape[1:] must equal to miss_value_shape")
  189. return miss_value_shape
  190. def infer_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
  191. args = {"swap_cache_idx": swap_cache_idx_dtype}
  192. validator.check_tensor_type_same(args, mstype.int_type, self.name)
  193. return miss_value_dtype
  194. class MapCacheIdx(PrimitiveWithInfer):
  195. """
  196. MapCacheIdx merge SearchCacheIdx, CacheSwapHashmap, UpdateCache together.
  197. When input an indices tensor, it will output the cache indices which search in hashmap.
  198. """
  199. __mindspore_signature__ = (
  200. sig.make_sig('hashmap', sig.sig_rw.RW_WRITE,
  201. dtype=sig.sig_dtype.T),
  202. sig.make_sig('indices', dtype=sig.sig_dtype.T),
  203. sig.make_sig('step', dtype=sig.sig_dtype.T),
  204. sig.make_sig('emb_max_num', dtype=sig.sig_dtype.T),
  205. sig.make_sig('cache_max_num', dtype=sig.sig_dtype.T)
  206. )
  207. @prim_attr_register
  208. def __init__(self):
  209. """init MapCacheIdx"""
  210. self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'cache_max_num'],
  211. outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx'])
  212. def infer_shape(self, hashmap_shape, indices_shape, step_shape, emb_max_num_shape, cache_max_num_shape):
  213. if len(hashmap_shape) != 2:
  214. raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, "
  215. "but got %d." % len(hashmap_shape))
  216. out_shape = (indices_shape, indices_shape,
  217. indices_shape, indices_shape)
  218. return out_shape
  219. def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
  220. args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
  221. validator.check_tensor_type_same(args, mstype.int_type, self.name)
  222. out_dtype = (hashmap_dtype, hashmap_dtype,
  223. hashmap_dtype, hashmap_dtype)
  224. return out_dtype