
EmbeddingLookUp.py

from __future__ import absolute_import

import numpy as np

from .Node import Op
from .. import ndarray
from .._base import DNNL_LIB
from ..cpu_links import embedding_lookup as cpu_embedding_lookup
from ..gpu_links import embedding_lookup


class EmbeddingLookUp(Op):
    def __init__(self, embedding, index, ctx=None):
        super().__init__(EmbeddingLookUp, [embedding, index], ctx)
        # Mark the first input as an embedding table so later passes
        # (sparse pull, sparse gradients) can treat it specially.
        embedding.is_embed = True

    def _compute_cpu_dnnl(self, input_vals, output_val, stream_handle=None):
        # CPU path backed by the DNNL kernel.
        cpu_embedding_lookup(input_vals[0], input_vals[1], output_val)

    def _compute_cpu_numpy(self, input_vals, output_val, stream_handle=None):
        # Pure-numpy CPU fallback: flatten the indices, gather rows from
        # the table, then reshape to the expected output shape.
        flatten_index = input_vals[1].asnumpy().reshape(-1).astype(np.int32)
        output_val[:] = input_vals[0].asnumpy()[
            flatten_index].reshape(output_val.shape)
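
    # Shape sketch for the numpy path (sizes are illustrative, not from the
    # source): a (10000, 128) embedding table gathered with a (32, 20) index
    # tensor yields 640 rows of width 128, reshaped back to the
    # (32, 20, 128) output that infer_shape computes below.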

    def _compute_gpu(self, input_vals, output_val, stream_handle=None):
        embedding_lookup(input_vals[0], input_vals[1],
                         output_val, stream_handle)

    def _compute_sparsepull_from_ps(self, input_vals, output_val, stream_handle=None):
        # Pull only the rows named by the indices from the parameter server.
        self.event.sync()
        if self.bsp:
            self.comm.BarrierWorker()
        self.comm.SparsePull(
            self.ps_id, input_vals[1].handle, output_val.handle)
        self.event.update()

    def _compute_sparsepull_from_cache(self, input_vals, output_val, stream_handle=None):
        # Serve the lookup from the local cache table; record the returned
        # timestamp so later syncs can wait on it.
        self.event.sync()
        if self.bsp:
            self.comm.BarrierWorker()
        ts = self.inputs[0].cache.embedding_lookup(input_vals[1], output_val)
        self.event.update_ts(ts)

    def gradient(self, output_grad):
        self.grad_node = embedding_lookup_gradient_op(
            output_grad, self.inputs[1], None, ctx=self.inputs[0].ctx)
        return [self.grad_node, None]
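
    # Note: the gradient node is created with embed_shape=None because the
    # table shape is unknown until shape inference runs; infer_shape below
    # fills in grad_node.embed_shape from input_shapes[0].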

    def infer_shape(self, input_shapes):
        assert len(input_shapes) == 2
        if hasattr(self, 'grad_node'):
            self.grad_node.embed_shape = input_shapes[0]
        # Output shape is the index shape with the embedding width appended.
        output_shape = list(input_shapes[1])
        output_shape.append(input_shapes[0][1])
        return tuple(output_shape)

    def forward_hook(self, config):
        super().forward_hook(config)
        # insert data transfer op if needed
        if config.use_sparse_pull or config.cstable_policy:
            self.event = self.inputs[0].event
            if not config.prefetch:
                self.bsp = config.bsp
                self.comm = config.ps_comm
                if config.cstable_policy:
                    self.compute = self._compute_sparsepull_from_cache
                else:
                    self.ps_id = self.inputs[0].id
                    self.compute = self._compute_sparsepull_from_ps
        else:
            if self.on_cpu and DNNL_LIB['cpu_EmbeddingLookup']:
                self.compute = self._compute_cpu_dnnl
            elif self.on_cpu:
                self.compute = self._compute_cpu_numpy
            else:
                self.compute = self._compute_gpu
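
    # Dispatch summary: when sparse pull or a cache table is configured
    # (and prefetching is off), compute is bound to a pull-based path;
    # otherwise a local kernel is chosen: DNNL on CPU when the library
    # provides it, numpy on CPU otherwise, and the GPU kernel elsewhere.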

    def backward_hook(self, config):
        # insert data transfer op if needed
        local_comm_mode = config.node_strategy.get(self, config.comm_mode)
        assert local_comm_mode != 'AllReduce' and local_comm_mode == config.node_strategy.get(self.inputs[0], config.comm_mode), \
            'Embedding lookup communication mode invalid. Should conform with embedding parameter and not be AllReduce.'
        if local_comm_mode in ('PS', 'Hybrid'):
            # PS-style communication keeps the lookup and its inputs on CPU.
            cpu_ctx = ndarray.cpu(0)
            self.ctx = cpu_ctx
            for n in self.inputs:
                n.ctx = cpu_ctx


class EmbeddingLookUp_Gradient(Op):
    def __init__(self, vectors, index, embed_shape, ctx=None):
        super().__init__(EmbeddingLookUp_Gradient, [vectors, index], ctx)
        self.embed_shape = embed_shape

    def compute(self, input_vals, output_val, stream_handle=None):
        assert self.embed_shape
        output_val.update(
            values=input_vals[0], indices=input_vals[1], dense_shape=self.embed_shape)
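
    # The gradient stays sparse: output_val records (values, indices,
    # dense_shape) instead of materializing a dense table-sized matrix,
    # so only the rows actually looked up carry updates.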

    def gradient(self, output_grad):
        raise NotImplementedError

    def infer_shape(self, input_shapes):
        assert self.embed_shape
        return self.embed_shape

    def backward_hook(self, config):
        # insert data transfer op if needed
        if config.comm_mode in ('PS', 'Hybrid'):
            self.ctx = ndarray.cpu(0)


def embedding_lookup_op(embedding, index, ctx=None):
    """Make a new instance of EmbeddingLookUp and return it.

    Parameters:
    ----
    embedding : Node
        The node of the embedding table.
    index : Node
        The indices to be looked up.

    Returns:
    ----
    A new Node instance created by Op.
    """
    return EmbeddingLookUp(embedding, index, ctx=ctx)
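
# A minimal usage sketch, kept as a comment. `init.random_normal` and
# `ad.Variable` are hypothetical stand-ins for however this framework
# creates the table and index nodes:
#
#     embed = init.random_normal((10000, 128), name='embed')  # table node
#     ids = ad.Variable(name='ids')                           # integer indices
#     out = embedding_lookup_op(embed, ids)                   # ids.shape + (128,)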


def embedding_lookup_gradient_op(vectors, index, embed_shape, ctx=None):
    """Make a new instance of EmbeddingLookUp_Gradient and return it.

    Parameters:
    ----
    vectors : Node
        The gradient vectors for the rows looked up from the embedding.
    index : Node
        The indices that were looked up.
    embed_shape : tuple
        Shape of the full embedding table; may be None and filled in
        later by EmbeddingLookUp.infer_shape.

    Returns:
    ----
    A new Node instance created by Op.
    """
    return EmbeddingLookUp_Gradient(vectors, index, embed_shape, ctx=ctx)
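
# Gradient-flow sketch, kept as a comment; `ad.gradients` is an assumed
# API name. Differentiating a loss built on `out` w.r.t. `embed` yields an
# EmbeddingLookUp_Gradient node whose compute() emits a sparse
# (values, indices, dense_shape) row update rather than a dense
# (10000, 128) matrix:
#
#     grads = ad.gradients(loss, [embed])
#     # grads[0] is an EmbeddingLookUp_Gradient node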