from __future__ import absolute_import
from .Node import Op
from .. import ndarray
from .._base import _LIB, check_call
from ..stream import create_event_handle


class AllReduceCommunicateOp(Op):
    def __init__(self, nodeA, comm):
        super().__init__(AllReduceCommunicateOp, [nodeA], nodeA.ctx)
        self.on_gpu = ndarray.is_gpu_ctx(self.ctx)
        self.on_cpu = not self.on_gpu
        self.comm = comm

    def compute(self, input_vals, output_val, stream_handle=None):
        if self.on_cpu:
            # Sparse inputs are only handled on the GPU path.
            assert not isinstance(
                input_vals[0], (ndarray.IndexedSlices, ndarray.ND_Sparse_Array))
            self.comm.dlarrayNcclAllReduce(
                input_vals[0], output_val, self.dtype, self.reduce_op)
        else:
            if self.event is None:
                self.event = create_event_handle(input_vals[0].ctx)
            if isinstance(input_vals[0], ndarray.NDArray):
                self.comm.dlarrayNcclAllReduce(
                    input_vals[0], output_val, self.dtype, self.reduce_op, stream_handle)
                self.event.record(stream_handle)
            elif isinstance(input_vals[0], ndarray.IndexedSlices):
                # TODO: should this use allgather instead for sparse inputs?
                self.comm.dlarrayNcclAllReduce(
                    input_vals[0].indices, output_val.indices, self.dtype,
                    self.reduce_op, stream_handle)
                self.comm.dlarrayNcclAllReduce(
                    input_vals[0].values, output_val.values, self.dtype,
                    self.reduce_op, stream_handle)
                self.event.record(stream_handle)

    def gradient(self, output_grad):
        raise NotImplementedError

    def infer_shape(self, input_shapes):
        return input_shapes[0]

    def forward_hook(self, config):
        from ..communicator.mpi_nccl_comm import ncclDataType_t, ncclRedOp_t
        self.ctx = self.inputs[0].ctx
        self.on_gpu = ndarray.is_gpu_ctx(self.ctx)
        self.on_cpu = not self.on_gpu
        if self.on_gpu and self.inputs[0].event is None:
            self.inputs[0].event = create_event_handle(self.ctx)

        # Disable inplace if not using lazy execution.
        # Previously we used the array reshape lazy callback to do this,
        # which is deprecated (not efficient).
        self.inputs[0].inplace = False
        self.dtype = ncclDataType_t.ncclFloat32
        self.reduce_op = ncclRedOp_t.ncclSum


def allreduceCommunicate_op(node, comm):
    """Make a new instance of AllReduceCommunicateOp and call the instance.

    Parameters:
    ----
    node : Node
        The Node to allreduce.
    comm :
        The communicator used to perform the allreduce.

    Returns:
    ----
    A new Node instance created by Op.
    """
    return AllReduceCommunicateOp(node, comm)


class GroupAllReduceCommunicateOp(Op):
    def __init__(self, nodeA, group_comm):
        super().__init__(GroupAllReduceCommunicateOp, [nodeA], nodeA.ctx)
        self.group_comm = group_comm

    def compute(self, input_vals, output_val, stream_handle=None):
        from ..communicator.mpi_nccl_comm import ncclDataType_t, ncclRedOp_t
        # Copy the input into the output buffer, then allreduce in place
        # within the group communicator.
        input_vals[0].copyto(output_val)
        self.group_comm.dlarrayNcclAllReduce(
            output_val, output_val, ncclDataType_t.ncclFloat32,
            ncclRedOp_t.ncclSum, stream_handle)

    def gradient(self, output_grad):
        raise NotImplementedError

    def infer_shape(self, input_shapes):
        return input_shapes[0]


def groupallreduceCommunicate_op(node, group_comm):
    """Make a new instance of GroupAllReduceCommunicateOp and call the instance.

    Parameters:
    ----
    node : Node
        The Node to groupallreduce.
    group_comm :
        The group communicator used to perform the allreduce within the group.

    Returns:
    ----
    A new Node instance created by Op.
    """
    return GroupAllReduceCommunicateOp(node, group_comm)
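

# Illustrative usage sketch (not part of the library API). It assumes a
# gradient node `grad_node` and already-initialized communicators `comm`
# and `group_comm`; these names are hypothetical placeholders for objects
# the surrounding framework would provide.
#
#     # Wrap a gradient node so its value is summed across all workers
#     # before the optimizer consumes it.
#     reduced_grad = allreduceCommunicate_op(grad_node, comm)
#
#     # Alternatively, restrict the reduction to a subset of workers by
#     # passing a group communicator to the grouped variant.
#     reduced_grad = groupallreduceCommunicate_op(grad_node, group_comm)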