
AllReduceCommunicate.py

from __future__ import absolute_import
from .Node import Op
from .. import ndarray
from .._base import _LIB, check_call
from ..stream import create_event_handle


class AllReduceCommunicateOp(Op):
    def __init__(self, nodeA, comm):
        super().__init__(AllReduceCommunicateOp, [nodeA], nodeA.ctx)
        self.on_gpu = ndarray.is_gpu_ctx(self.ctx)
        self.on_cpu = not self.on_gpu
        self.comm = comm

    def compute(self, input_vals, output_val, stream_handle=None):
        if self.on_cpu:
            # Sparse inputs are only handled on GPU.
            assert not isinstance(
                input_vals[0], (ndarray.IndexedSlices, ndarray.ND_Sparse_Array))
            self.comm.dlarrayNcclAllReduce(
                input_vals[0], output_val, self.dtype, self.reduce_op)
        else:
            if self.event is None:
                self.event = create_event_handle(input_vals[0].ctx)
            if isinstance(input_vals[0], ndarray.NDArray):
                self.comm.dlarrayNcclAllReduce(
                    input_vals[0], output_val, self.dtype, self.reduce_op,
                    stream_handle)
                self.event.record(stream_handle)
            elif isinstance(input_vals[0], ndarray.IndexedSlices):
                # TODO: should this use allgather instead?
                self.comm.dlarrayNcclAllReduce(
                    input_vals[0].indices, output_val.indices, self.dtype,
                    self.reduce_op, stream_handle)
                self.comm.dlarrayNcclAllReduce(
                    input_vals[0].values, output_val.values, self.dtype,
                    self.reduce_op, stream_handle)
                self.event.record(stream_handle)

    def gradient(self, output_grad):
        raise NotImplementedError

    def infer_shape(self, input_shapes):
        return input_shapes[0]

    def forward_hook(self, config):
        from ..communicator.mpi_nccl_comm import ncclDataType_t, ncclRedOp_t
        self.ctx = self.inputs[0].ctx
        self.on_gpu = ndarray.is_gpu_ctx(self.ctx)
        self.on_cpu = not self.on_gpu
        if self.on_gpu and self.inputs[0].event is None:
            self.inputs[0].event = create_event_handle(self.ctx)
        # Disable inplace if not in lazy execution mode.
        # Previously an array-reshape lazy callback handled this, which is
        # deprecated (not efficient).
        self.inputs[0].inplace = False
        self.dtype = ncclDataType_t.ncclFloat32
        self.reduce_op = ncclRedOp_t.ncclSum
def allreduceCommunicate_op(node, comm):
    """Make a new instance of AllReduceCommunicateOp and call the instance.

    Parameters
    ----------
    node : Node
        The Node to allreduce.
    comm :
        The communicator to perform the allreduce with.

    Returns
    -------
    A new Node instance created by Op.
    """
    return AllReduceCommunicateOp(node, comm)
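

# Usage sketch (assumption: a NCCL communicator object exposing
# dlarrayNcclAllReduce, e.g. built via the helpers in
# ..communicator.mpi_nccl_comm; `get_comm` and `grad_node` below are
# illustrative placeholders, not part of this module):
#
#     comm = get_comm()                              # one communicator per worker
#     synced = allreduceCommunicate_op(grad_node, comm)
#
# When executed, `synced` holds the elementwise ncclSum of `grad_node`
# across all workers in `comm` (dtype/reduce_op are set in forward_hook).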
class GroupAllReduceCommunicateOp(Op):
    def __init__(self, nodeA, group_comm):
        super().__init__(GroupAllReduceCommunicateOp, [nodeA], nodeA.ctx)
        self.group_comm = group_comm

    def compute(self, input_vals, output_val, stream_handle=None):
        from ..communicator.mpi_nccl_comm import ncclDataType_t, ncclRedOp_t
        # Stage the input into the output, then allreduce the output in place
        # within the group communicator.
        input_vals[0].copyto(output_val)
        self.group_comm.dlarrayNcclAllReduce(
            output_val, output_val, ncclDataType_t.ncclFloat32,
            ncclRedOp_t.ncclSum, stream_handle)

    def gradient(self, output_grad):
        raise NotImplementedError

    def infer_shape(self, input_shapes):
        return input_shapes[0]
def groupallreduceCommunicate_op(node, group_comm):
    """Make a new instance of GroupAllReduceCommunicateOp and call the instance.

    Parameters
    ----------
    node : Node
        The Node to group-allreduce.
    group_comm :
        The group communicator to reduce within.

    Returns
    -------
    A new Node instance created by Op.
    """
    return GroupAllReduceCommunicateOp(node, group_comm)
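

# Design note: unlike AllReduceCommunicateOp, this op stages the input via
# copyto and reduces the output in place, hard-coding ncclFloat32/ncclSum
# rather than configuring them in a forward_hook. The `group_comm` name
# suggests a communicator spanning only a subset of workers. A minimal
# sketch (assumption: `new_group_comm` is an illustrative placeholder for
# however the group communicator is actually constructed):
#
#     group_comm = new_group_comm([0, 1])            # hypothetical helper
#     synced = groupallreduceCommunicate_op(grad_node, group_comm)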