|
- from __future__ import absolute_import
- from .Node import Op
- from .. import ndarray
- from .._base import DNNL_LIB
- import numpy as np
- from ..cpu_links import embedding_lookup as cpu_embedding_lookup
- from ..gpu_links import embedding_lookup
-
-
class EmbeddingLookUp(Op):
    """Embedding lookup operator.

    Inputs are ``[embedding, index]``; the output gathers rows of the
    embedding table selected by ``index``. Output shape is
    ``index.shape + (embed_dim,)`` (see ``infer_shape``). The concrete
    compute kernel is chosen at ``forward_hook`` time: DNNL, numpy, GPU,
    or a sparse-pull path through a parameter server / cache table.
    """

    def __init__(self, embedding, index, ctx=None):
        super().__init__(EmbeddingLookUp, [embedding, index], ctx)
        # Flag the parameter node as an embedding so the rest of the
        # runtime can treat it specially (presumably sparse updates /
        # PS placement — confirm against the Op/executor code).
        embedding.is_embed = True

    def _compute_cpu_dnnl(self, input_vals, output_val, stream_handle=None):
        # CPU path backed by the DNNL kernel (requires DNNL_LIB support).
        cpu_embedding_lookup(input_vals[0], input_vals[1], output_val)

    def _compute_cpu_numpy(self, input_vals, output_val, stream_handle=None):
        # Pure-numpy fallback: fancy-index the table with the flattened
        # indices, then reshape back to index.shape + (embed_dim,).
        # Indices are cast to int32; no bounds checking beyond numpy's own.
        flatten_index = input_vals[1].asnumpy().reshape(-1).astype(np.int32)
        output_val[:] = input_vals[0].asnumpy(
        )[flatten_index].reshape(output_val.shape)

    def _compute_gpu(self, input_vals, output_val, stream_handle=None):
        # GPU kernel; stream_handle selects the CUDA stream.
        embedding_lookup(input_vals[0], input_vals[1],
                         output_val, stream_handle)

    def _compute_sparsepull_from_ps(self, input_vals, output_val, stream_handle=None):
        # Sparse pull of the needed rows from the parameter server.
        # Order matters: wait for the previous PS op on this parameter,
        # optionally barrier for BSP, pull, then record a new event.
        self.event.sync()
        if self.bsp:
            self.comm.BarrierWorker()
        self.comm.SparsePull(
            self.ps_id, input_vals[1].handle, output_val.handle)
        self.event.update()

    def _compute_sparsepull_from_cache(self, input_vals, output_val, stream_handle=None):
        # Same protocol as the PS path, but served from the local cache
        # table attached to the embedding node; the cache returns a
        # timestamp which is recorded on the event.
        self.event.sync()
        if self.bsp:
            self.comm.BarrierWorker()
        ts = self.inputs[0].cache.embedding_lookup(input_vals[1], output_val)
        self.event.update_ts(ts)

    def gradient(self, output_grad):
        # Gradient w.r.t. the table is a sparse gradient op; the index
        # input gets no gradient. embed_shape is passed as None here and
        # filled in later by infer_shape below.
        self.grad_node = embedding_lookup_gradient_op(
            output_grad, self.inputs[1], None, ctx=self.inputs[0].ctx)
        return [self.grad_node, None]

    def infer_shape(self, input_shapes):
        assert len(input_shapes) == 2
        # Backfill the gradient op's dense shape now that the embedding
        # table shape is known (only if gradient() was called).
        if hasattr(self, 'grad_node'):
            self.grad_node.embed_shape = input_shapes[0]
        # Assumes a 2-D table (num_embeddings, embed_dim) — output is
        # index.shape + (embed_dim,).
        output_shape = list(input_shapes[1])
        output_shape.append(input_shapes[0][1])
        return tuple(output_shape)

    def forward_hook(self, config):
        super().forward_hook(config)
        # insert data transfer op if needed
        if config.use_sparse_pull or config.cstable_policy:
            # Share the embedding node's event so pulls serialize with
            # other PS traffic on the same parameter.
            self.event = self.inputs[0].event
            if not config.prefetch:
                self.bsp = config.bsp
                self.comm = config.ps_comm
                if config.cstable_policy:
                    self.compute = self._compute_sparsepull_from_cache
                else:
                    self.ps_id = self.inputs[0].id
                    self.compute = self._compute_sparsepull_from_ps
        else:
            # Local compute: prefer DNNL on CPU when available, else
            # numpy on CPU, else the GPU kernel.
            if self.on_cpu and DNNL_LIB['cpu_EmbeddingLookup']:
                self.compute = self._compute_cpu_dnnl
            elif self.on_cpu:
                self.compute = self._compute_cpu_numpy
            else:
                self.compute = self._compute_gpu

    def backward_hook(self, config):
        # insert data transfer op if needed
        # Embedding lookup must use the same communication mode as the
        # embedding parameter itself, and AllReduce is unsupported
        # (sparse gradients go through PS/Hybrid instead).
        local_comm_mode = config.node_strategy.get(self, config.comm_mode)
        assert local_comm_mode != 'AllReduce' and local_comm_mode == config.node_strategy.get(self.inputs[0], config.comm_mode), \
            'Embedding lookup communication mode invalid. Should conform with embedding parameter and not be AllReduce.'
        if local_comm_mode in ('PS', 'Hybrid'):
            # PS/Hybrid: pin this op and its inputs to CPU so the
            # parameter server can serve/consume them.
            cpu_ctx = ndarray.cpu(0)
            self.ctx = cpu_ctx
            for n in self.inputs:
                n.ctx = cpu_ctx
-
-
class EmbeddingLookUp_Gradient(Op):
    """Sparse gradient op for EmbeddingLookUp.

    Packs the dense gradient rows (``vectors``) together with their row
    indices (``index``) into a sparse gradient whose dense shape is the
    embedding table's shape, via ``output_val.update(...)``.
    """

    def __init__(self, vectors, index, embed_shape, ctx=None):
        super().__init__(EmbeddingLookUp_Gradient, [vectors, index], ctx)
        # embed_shape may be None at construction time;
        # EmbeddingLookUp.infer_shape backfills it once the embedding
        # table's shape is known.
        self.embed_shape = embed_shape

    def compute(self, input_vals, output_val, stream_handle=None):
        # embed_shape must have been set (directly or by the forward
        # op's infer_shape) before compute runs.
        assert self.embed_shape
        output_val.update(
            values=input_vals[0], indices=input_vals[1], dense_shape=self.embed_shape)

    def gradient(self, output_grad):
        # Second-order gradients through the lookup are not supported.
        raise NotImplementedError

    def infer_shape(self, input_shapes):
        assert self.embed_shape
        # The (sparse) gradient's logical shape is the full table shape.
        return self.embed_shape

    def backward_hook(self, config):
        # insert data transfer op if needed
        # Under PS/Hybrid communication the sparse gradient must live on
        # CPU for the parameter server. Membership test kept consistent
        # with EmbeddingLookUp.backward_hook.
        if config.comm_mode in ('PS', 'Hybrid'):
            self.ctx = ndarray.cpu(0)
-
-
def embedding_lookup_op(embedding, index, ctx=None):
    """Make a new instance of EmbeddingLookUp and call the instance.

    Parameters:
    ----
    embedding : Node
        The Node of Embedding.
    index : Node
        The index to be looked up.
    ctx : DLContext, optional
        The context to place the new op on; defaults to the framework's
        default placement when None.

    Returns:
    ----
    A new Node instance created by Op.

    """
    return EmbeddingLookUp(embedding, index, ctx=ctx)
-
-
def embedding_lookup_gradient_op(vectors, index, embed_shape, ctx=None):
    """Make a new instance of EmbeddingLookUp_Gradient and call the instance.

    Parameters:
    ----
    vectors : Node
        Vectors which looked up from Embedding.
    index : Node
        The index to be looked up.
    embed_shape : tuple or None
        Dense shape of the embedding table. May be None at construction;
        EmbeddingLookUp.infer_shape fills it in before compute runs.
    ctx : DLContext, optional
        The context to place the new op on; defaults to the framework's
        default placement when None.

    Returns:
    ----
    A new Node instance created by Op.

    """
    return EmbeddingLookUp_Gradient(vectors, index, embed_shape, ctx=ctx)
|