from __future__ import absolute_import
from .Node import Op
from .. import ndarray
from ..communicator.mpi_nccl_comm import ncclDataType_t
from ..stream import create_event_handle


class PipelineReceiveOp(Op):
    def __init__(self, source, comm, stream=None, ctx=None):
        assert ctx, "PipelineReceiveOp must be initialized with the ctx argument!"
        super().__init__(PipelineReceiveOp, [], ctx)
        self.const_attr = source
        self.comm = comm
        self.comm_stream = stream
        # Event used to signal completion of the asynchronous receive.
        self.event = None
        self.desc = self.name + \
            '(%s receives from %s)' % (str(self.ctx.device_id), str(source))
        # The output shape is unknown until the sender transmits it.
        self.shape = None
        self.shape_is_received = False

    def compute(self, input_vals, output_val, stream_handle=None):
        assert not self.on_cpu, "PipelineReceiveOp only supports P2P communication between GPUs"
        assert self.comm_stream, "communication stream must not be None"
        if self.event is None:
            self.event = create_event_handle(self.ctx)
        # Receive the tensor asynchronously on the communication stream and
        # record an event so that downstream ops can wait for completion.
        self.comm.dlarrayRecv(output_val,
                              ncclDataType_t.ncclFloat32,
                              self.const_attr,
                              self.comm_stream)
        self.event.record(self.comm_stream)

    def gradient(self, output_grad):
        # A receive op has no inputs, so no gradient flows through it.
        return []

    def infer_shape(self, input_shapes):
        if not self.shape_is_received:
            # Receive the shape itself, zero-padded to a fixed length of 3,
            # before the tensor data arrives.
            shape_arr = ndarray.array([0, 0, 0], self.ctx)
            self.comm.dlarrayRecv(shape_arr,
                                  ncclDataType_t.ncclFloat32,
                                  self.const_attr,
                                  self.comm_stream)
            # Wait for the receive to finish before reading on the host.
            self.comm_stream.sync()

            # Strip the zero padding and cache the shape.
            shape_arr = [int(x) for x in list(shape_arr.asnumpy()) if x != 0]
            self.shape = tuple(shape_arr)
            self.shape_is_received = True

        return self.shape

    def forward_hook(self, config):
        # This op must be placed on a GPU context; receiving on CPU is
        # unsupported (see the assertion in compute).
        self.on_gpu = ndarray.is_gpu_ctx(self.ctx)
        self.on_cpu = not self.on_gpu
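    # For reference, a minimal sketch of the matching sender-side handshake,
    # assuming a PipelineSendOp counterpart with a dlarraySend method of the
    # same shape as dlarrayRecv (names here are assumptions, not confirmed by
    # this file): the sender pads the shape with zeros to the same fixed
    # length of 3 and transmits it as float32 before the tensor, which is
    # exactly the padding infer_shape above strips off.
    #
    #     padded_shape = list(shape) + [0] * (3 - len(shape))
    #     shape_arr = ndarray.array(padded_shape, self.ctx)
    #     self.comm.dlarraySend(shape_arr,
    #                           ncclDataType_t.ncclFloat32,
    #                           self.const_attr,  # destination rank when sending
    #                           self.comm_stream)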


def pipeline_receive_op(source, comm, stream=None, ctx=None):
    """Make a new instance of PipelineReceiveOp and call the instance.

    Parameters:
    ----
    source : scalar value
        The index of the source GPU (the device to receive from).
    comm : communicator
        The NCCL communicator used for the P2P receive.
    stream : stream handle, optional
        The stream on which the communication is performed.
    ctx : DLContext
        The device context to place the op on; must not be None.

    Returns:
    ----
    A new Node instance created by Op.

    """
    return PipelineReceiveOp(source, comm, stream=stream, ctx=ctx)
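
# A minimal usage sketch (illustrative only, not part of this module). The
# communicator and stream setup below are placeholders: how `comm` and
# `comm_stream` are constructed depends on the surrounding framework, and
# `ndarray.gpu` is assumed to build a GPU context.
#
#     recv = pipeline_receive_op(source=0, comm=comm,
#                                stream=comm_stream, ctx=ndarray.gpu(1))
#     # During execution, infer_shape() first receives the tensor's shape from
#     # GPU 0, then compute() fills the output array with the received data.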