| @@ -55,7 +55,9 @@ const char kNameSimpleMeanGrad[] = "SimpleMeanGrad"; | |||
| const char kNameAllReduce[] = "AllReduce"; | |||
| const char kNameBroadcast[] = "Broadcast"; | |||
| const char kNameAllgather[] = "AllGather"; | |||
| const char kNameHostAllgather[] = "HostAllGather"; | |||
| const char kNameReduceScatter[] = "ReduceScatter"; | |||
| const char kNameHostReduceScatter[] = "HostReduceScatter"; | |||
| const char kNameReduceSum[] = "ReduceSum"; | |||
| const char kNameIsFinite[] = "isFinite"; | |||
| const char kNameReciprocal[] = "Reciprocal"; | |||
| @@ -45,8 +45,10 @@ constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean"; | |||
| constexpr auto kGetNextOpName = "GetNext"; | |||
| constexpr auto kAllReduceOpName = "AllReduce"; | |||
| constexpr auto kAllGatherOpName = "AllGather"; | |||
| constexpr auto kHostAllGatherOpName = "HostAllGather"; | |||
| constexpr auto kBroadcastOpName = "Broadcast"; | |||
| constexpr auto kReduceScatterOpName = "ReduceScatter"; | |||
| constexpr auto kHostReduceScatterOpName = "HostReduceScatter"; | |||
| constexpr auto kMemCpyAsyncOpName = "memcpy_async"; | |||
| constexpr auto kTopKOpName = "TopK"; | |||
| constexpr auto kExtractImagePatchesOpName = "ExtractImagePatches"; | |||
| @@ -18,9 +18,9 @@ import mindspore.common.dtype as mstype | |||
| from mindspore.ops import functional as F | |||
| from .. import operations as P | |||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | |||
| from ..operations.comm_ops import (AllGather, AllReduce, _AlltoAll, Broadcast, | |||
| from ..operations.comm_ops import (AllGather, HostAllGather, AllReduce, _AlltoAll, Broadcast, | |||
| _GetTensorSlice, _MirrorOperator, ReduceOp, | |||
| ReduceScatter, _VirtualDiv) | |||
| ReduceScatter, HostReduceScatter, _VirtualDiv) | |||
| from .grad_base import bprop_getters | |||
| @@ -79,6 +79,21 @@ def get_bprop_all_gather(self): | |||
| return bprop | |||
| @bprop_getters.register(HostAllGather) | |||
| def get_bprop_host_all_gather(self): | |||
| """Generate bprop for HostAllGather""" | |||
| host_all_gather_grad = HostReduceScatter(ReduceOp.SUM, self.group) | |||
| if self.instance_name: | |||
| instance_name = "grad" + self.instance_name | |||
| host_all_gather_grad.set_prim_instance_name(instance_name) | |||
| def bprop(x, out, dout): | |||
| dx = host_all_gather_grad(dout) | |||
| return (dx,) | |||
| return bprop | |||
| @bprop_getters.register(ReduceScatter) | |||
| def get_bprop_reduce_scatter(self): | |||
| """Generate bprop for ReduceScatter""" | |||
| @@ -97,6 +112,24 @@ def get_bprop_reduce_scatter(self): | |||
| return bprop | |||
| @bprop_getters.register(HostReduceScatter) | |||
| def get_bprop_host_reduce_scatter(self): | |||
| """Generate bprop for HostReduceScatter""" | |||
| host_reduce_scatter_grad = HostAllGather(self.group) | |||
| if self.instance_name: | |||
| instance_name = "grad" + self.instance_name | |||
| host_reduce_scatter_grad.set_prim_instance_name(instance_name) | |||
| if self.op != ReduceOp.SUM: | |||
| raise RuntimeError("The hostreducescatter bprop only support ReduceOp.SUM until now.") | |||
| def bprop(x, out, dout): | |||
| dx = host_reduce_scatter_grad(dout) | |||
| return (dx,) | |||
| return bprop | |||
| @bprop_getters.register(_AlltoAll) | |||
| def get_bprop_all_to_all(self): | |||
| """Generate bprop for AlltoAll.""" | |||
| @@ -32,7 +32,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||
| UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace) | |||
| from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast, | |||
| _MirrorOperator, ReduceOp, _VirtualDataset, | |||
| _VirtualDiv, _GetTensorSlice) | |||
| _VirtualDiv, _GetTensorSlice, | |||
| HostAllGather, HostReduceScatter) | |||
| from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, | |||
| TensorSummary, HistogramSummary, Print) | |||
| from .control_ops import ControlDepend, GeSwitch, Merge | |||
| @@ -217,8 +218,10 @@ __all__ = [ | |||
| 'UnsortedSegmentSum', | |||
| 'UnsortedSegmentMin', | |||
| "AllGather", | |||
| "HostAllGather", | |||
| "AllReduce", | |||
| "ReduceScatter", | |||
| "HostReduceScatter", | |||
| "Broadcast", | |||
| "ReduceOp", | |||
| 'ScalarCast', | |||
| @@ -169,6 +169,72 @@ class AllGather(PrimitiveWithInfer): | |||
| raise NotImplementedError | |||
| class HostAllGather(PrimitiveWithInfer): | |||
| """ | |||
| Gathers tensors from the specified communication group on host. | |||
| Note: | |||
| Tensor must have the same shape and format in all processes participating in the collective. | |||
| Args: | |||
| group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on. | |||
| Raises: | |||
| TypeError: If group is not a list nor tuple, or elements of group are not int. | |||
| ValueError: If the local rank id of the calling process not in group, | |||
| or rank_id from group not in [0, 7]. | |||
| Inputs: | |||
| - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. | |||
| Outputs: | |||
| Tensor. If the number of devices in the group is N, | |||
| then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`. | |||
| Examples: | |||
| >>> from mindspore.communication import init | |||
| >>> import mindspore.ops.operations as P | |||
| >>> init('nccl') | |||
| >>> class Net(nn.Cell): | |||
| >>> def __init__(self): | |||
| >>> super(Net, self).__init__() | |||
| >>> self.hostallgather = P.HostAllGather(group=(0, 1, 2, 3)) | |||
| >>> | |||
| >>> def construct(self, x): | |||
| >>> return self.hostallgather(x) | |||
| >>> | |||
| >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32)) | |||
| >>> net = Net() | |||
| >>> output = net(input_) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, group=None): | |||
| if group is None: | |||
| raise ValueError(f"For '{self.name}' group must be set.") | |||
| validator.check_value_type('group', group, (tuple, list), self.name) | |||
| validator.check_integer("group size", len(group), 2, Rel.GE, self.name) | |||
| for r in group: | |||
| validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name) | |||
| validator.check_value_type("rank_id", r, (int,), self.name) | |||
| self.group_size = len(group) | |||
| self.rank = get_rank() | |||
| validator.check('rank', self.rank, 'group', self.group, Rel.IN, self.name) | |||
| self.add_prim_attr('group', group) | |||
| def infer_shape(self, x_shape): | |||
| validator.check_integer("x shape", len(x_shape), 0, Rel.GT, self.name) | |||
| x_shape[0] = x_shape[0] * self.group_size | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) | |||
| return x_dtype | |||
| def __call__(self, tensor): | |||
| raise NotImplementedError | |||
| class ReduceScatter(PrimitiveWithInfer): | |||
| """ | |||
| Reduces and scatters tensors from the specified communication group. | |||
| @@ -226,6 +292,68 @@ class ReduceScatter(PrimitiveWithInfer): | |||
| raise NotImplementedError | |||
| class HostReduceScatter(PrimitiveWithInfer): | |||
| """ | |||
| Reduces and scatters tensors from the specified communication group on host. | |||
| Note: | |||
| Tensor must have the same shape and format in all processes participating in the collective. | |||
| Args: | |||
| op (str): Specifies an operation used for element-wise reductions, | |||
| like sum, max, avg. Default: ReduceOp.SUM. | |||
| group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on. | |||
| Raise: | |||
| TypeError: If op is not a string and group is not a list nor tuple, | |||
| or elements of group are not int. | |||
| ValueError: If the first dimension of input can not be divided by rank size, | |||
| or group is not set, or rank_id not in [1, 7]. | |||
| Examples: | |||
| >>> from mindspore.communication import init | |||
| >>> import mindspore.ops.operations as P | |||
| >>> init('nccl') | |||
| >>> class Net(nn.Cell): | |||
| >>> def __init__(self): | |||
| >>> super(Net, self).__init__() | |||
| >>> self.hostreducescatter = P.HostReduceScatter(ReduceOp.SUM, group=[0, 1, 2, 3]) | |||
| >>> | |||
| >>> def construct(self, x): | |||
| >>> return self.hostreducescatter(x) | |||
| >>> | |||
| >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32)) | |||
| >>> net = Net() | |||
| >>> output = net(input_) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, op=ReduceOp.SUM, group=None): | |||
| if group is None: | |||
| raise ValueError(f"For '{self.name}' group must be set.") | |||
| validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name) | |||
| validator.check_value_type('group', group, (tuple, list), self.name) | |||
| validator.check_integer("group size", len(group), 2, Rel.GE, self.name) | |||
| for r in group: | |||
| validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name) | |||
| validator.check_value_type("rank_id", r, (int,), self.name) | |||
| self.op = op | |||
| self.group_size = len(group) | |||
| self.add_prim_attr('group', group) | |||
| def infer_shape(self, x_shape): | |||
| if x_shape[0] % self.group_size != 0: | |||
| raise ValueError(f"For '{self.name}' the first dimension of x should be divided by group_size.") | |||
| x_shape[0] = int(x_shape[0]/self.group_size) | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) | |||
| return x_dtype | |||
| def __call__(self, tensor): | |||
| raise NotImplementedError | |||
| class Broadcast(PrimitiveWithInfer): | |||
| """ | |||
| Broadcasts the tensor to the whole group. | |||
| @@ -26,6 +26,7 @@ from mindspore.nn import Momentum | |||
| from mindspore.nn import ReLU | |||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter | |||
| from mindspore.ops.operations.comm_ops import HostAllGather, HostReduceScatter | |||
| from mindspore.ops.operations.comm_ops import Broadcast | |||
| # pylint: disable=W0212 | |||
| @@ -86,6 +87,21 @@ class AllGatherNet(nn.Cell): | |||
| return self.relu(x) | |||
| class HostAllGatherNet(nn.Cell): | |||
| """HostAllGatherNet definition""" | |||
| def __init__(self, input_channel, output_channel): | |||
| super(HostAllGatherNet, self).__init__() | |||
| self.dense = Dense(input_channel, output_channel) | |||
| self.hostallgather = HostAllGather((0, 1)) | |||
| self.relu = ReLU() | |||
| def construct(self, x): | |||
| x = self.dense(x) | |||
| x = self.hostallgather(x) | |||
| return self.relu(x) | |||
| class ReduceScatterNet(nn.Cell): | |||
| """ReduceScatterNet definition""" | |||
| @@ -101,6 +117,21 @@ class ReduceScatterNet(nn.Cell): | |||
| return self.relu(x) | |||
| class HostReduceScatterNet(nn.Cell): | |||
| """HostReduceScatterNet definition""" | |||
| def __init__(self, input_channel, out_channel, op): | |||
| super(HostReduceScatterNet, self).__init__() | |||
| self.dense = Dense(input_channel, out_channel) | |||
| self.hostreducescatter = HostReduceScatter(op, (0, 1)) | |||
| self.relu = ReLU() | |||
| def construct(self, x): | |||
| x = self.dense(x) | |||
| x = self.hostreducescatter(x) | |||
| return self.relu(x) | |||
| class AlltoAllNet(nn.Cell): | |||
| """AlltoAllNet definition""" | |||
| @@ -154,6 +185,21 @@ def test_allgather(): | |||
| _executor.compile(network, input_tensor, label_tensor) | |||
| def test_hostallgather(): | |||
| """test_hostallgather""" | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32)) | |||
| label_tensor = Tensor(np.array([[1.2], [2.2], [3.2], [4.2]], dtype=np.float32)) | |||
| network = HostAllGatherNet(2, 1) | |||
| loss_fn = nn.SoftmaxCrossEntropyWithLogits() | |||
| optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()), | |||
| learning_rate=0.1, | |||
| momentum=0.9) | |||
| network = WithLossCell(network, loss_fn) | |||
| network = TrainOneStepCell(network, optimizer) | |||
| _executor.compile(network, input_tensor, label_tensor) | |||
| def run_reducescatter(op): | |||
| """run_reducescatter""" | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| @@ -175,6 +221,21 @@ def test_reducescatter(): | |||
| run_reducescatter(ReduceOp.SUM) | |||
| def test_hostreducescatter(): | |||
| """test_hostreducescatter""" | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32)) | |||
| label_tensor = Tensor(np.array([[1.2]], dtype=np.float32)) | |||
| network = HostReduceScatterNet(2, 1, ReduceOp.SUM) | |||
| loss_fn = nn.SoftmaxCrossEntropyWithLogits() | |||
| optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()), | |||
| learning_rate=0.1, | |||
| momentum=0.9) | |||
| network = WithLossCell(network, loss_fn) | |||
| network = TrainOneStepCell(network, optimizer) | |||
| _executor.compile(network, input_tensor, label_tensor) | |||
| def test_broadcast(): | |||
| """test_broadcast""" | |||
| context.set_context(mode=context.GRAPH_MODE) | |||