@@ -55,7 +55,9 @@ const char kNameSimpleMeanGrad[] = "SimpleMeanGrad";
 const char kNameAllReduce[] = "AllReduce";
 const char kNameBroadcast[] = "Broadcast";
 const char kNameAllgather[] = "AllGather";
+const char kNameHostAllgather[] = "HostAllGather";
 const char kNameReduceScatter[] = "ReduceScatter";
+const char kNameHostReduceScatter[] = "HostReduceScatter";
 const char kNameReduceSum[] = "ReduceSum";
 const char kNameIsFinite[] = "isFinite";
 const char kNameReciprocal[] = "Reciprocal";
@@ -45,8 +45,10 @@ constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean";
 constexpr auto kGetNextOpName = "GetNext";
 constexpr auto kAllReduceOpName = "AllReduce";
 constexpr auto kAllGatherOpName = "AllGather";
+constexpr auto kHostAllGatherOpName = "HostAllGather";
 constexpr auto kBroadcastOpName = "Broadcast";
 constexpr auto kReduceScatterOpName = "ReduceScatter";
+constexpr auto kHostReduceScatterOpName = "HostReduceScatter";
 constexpr auto kMemCpyAsyncOpName = "memcpy_async";
 constexpr auto kTopKOpName = "TopK";
 constexpr auto kExtractImagePatchesOpName = "ExtractImagePatches";
@@ -18,9 +18,9 @@ import mindspore.common.dtype as mstype
 from mindspore.ops import functional as F
 from .. import operations as P
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
-from ..operations.comm_ops import (AllGather, AllReduce, _AlltoAll, Broadcast,
+from ..operations.comm_ops import (AllGather, HostAllGather, AllReduce, _AlltoAll, Broadcast,
                                    _GetTensorSlice, _MirrorOperator, ReduceOp,
-                                   ReduceScatter, _VirtualDiv)
+                                   ReduceScatter, HostReduceScatter, _VirtualDiv)
 from .grad_base import bprop_getters
@@ -79,6 +79,21 @@ def get_bprop_all_gather(self):
     return bprop
 
 
+@bprop_getters.register(HostAllGather)
+def get_bprop_host_all_gather(self):
+    """Generate bprop for HostAllGather"""
+    host_all_gather_grad = HostReduceScatter(ReduceOp.SUM, self.group)
+    if self.instance_name:
+        instance_name = "grad" + self.instance_name
+        host_all_gather_grad.set_prim_instance_name(instance_name)
+
+    def bprop(x, out, dout):
+        dx = host_all_gather_grad(dout)
+        return (dx,)
+
+    return bprop
+
+
 @bprop_getters.register(ReduceScatter)
 def get_bprop_reduce_scatter(self):
     """Generate bprop for ReduceScatter"""
@@ -97,6 +112,24 @@ def get_bprop_reduce_scatter(self):
     return bprop
+@bprop_getters.register(HostReduceScatter)
+def get_bprop_host_reduce_scatter(self):
+    """Generate bprop for HostReduceScatter"""
+    host_reduce_scatter_grad = HostAllGather(self.group)
+    if self.instance_name:
+        instance_name = "grad" + self.instance_name
+        host_reduce_scatter_grad.set_prim_instance_name(instance_name)
+
+    if self.op != ReduceOp.SUM:
+        raise RuntimeError("The HostReduceScatter bprop only supports ReduceOp.SUM for now.")
+
+    def bprop(x, out, dout):
+        dx = host_reduce_scatter_grad(dout)
+        return (dx,)
+
+    return bprop
+
+
 @bprop_getters.register(_AlltoAll)
 def get_bprop_all_to_all(self):
     """Generate bprop for AlltoAll."""
@@ -32,7 +32,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace)
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset,
-                       _VirtualDiv, _GetTensorSlice)
+                       _VirtualDiv, _GetTensorSlice,
+                       HostAllGather, HostReduceScatter)
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                         TensorSummary, HistogramSummary, Print)
 from .control_ops import ControlDepend, GeSwitch, Merge
@@ -217,8 +218,10 @@ __all__ = [
     'UnsortedSegmentSum',
     'UnsortedSegmentMin',
     "AllGather",
+    "HostAllGather",
     "AllReduce",
     "ReduceScatter",
+    "HostReduceScatter",
     "Broadcast",
     "ReduceOp",
     'ScalarCast',
@@ -169,6 +169,72 @@ class AllGather(PrimitiveWithInfer):
         raise NotImplementedError
 
 
+class HostAllGather(PrimitiveWithInfer):
+    """
+    Gathers tensors from the specified communication group on host.
+
+    Note:
+        Tensor must have the same shape and format in all processes participating in the collective.
+
+    Args:
+        group (Union[tuple[int], list[int]]): The rank ids of the communication group to work on.
+
+    Raises:
+        TypeError: If group is not a list or a tuple, or any element of group is not an int.
+        ValueError: If the local rank id of the calling process is not in group,
+            or a rank_id in group is not in [0, 7].
+
+    Inputs:
+        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+
+    Outputs:
+        Tensor. If the number of devices in the group is N,
+        then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`.
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor
+        >>> from mindspore.communication import init
+        >>> import mindspore.ops.operations as P
+        >>> init('nccl')
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.hostallgather = P.HostAllGather(group=(0, 1, 2, 3))
+        >>>
+        >>>     def construct(self, x):
+        >>>         return self.hostallgather(x)
+        >>>
+        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
+        >>> net = Net()
+        >>> output = net(input_)
+    """
+
+    @prim_attr_register
+    def __init__(self, group=None):
+        if group is None:
+            raise ValueError(f"For '{self.name}' group must be set.")
+        validator.check_value_type('group', group, (tuple, list), self.name)
+        validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
+        for r in group:
+            validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name)
+            validator.check_value_type("rank_id", r, (int,), self.name)
+        self.group_size = len(group)
+        self.rank = get_rank()
+        validator.check('rank', self.rank, 'group', self.group, Rel.IN, self.name)
+        self.add_prim_attr('group', group)
+
+    def infer_shape(self, x_shape):
+        validator.check_integer("x shape", len(x_shape), 0, Rel.GT, self.name)
+        x_shape[0] = x_shape[0] * self.group_size
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
+        return x_dtype
+
+    def __call__(self, tensor):
+        raise NotImplementedError
+
+
 class ReduceScatter(PrimitiveWithInfer):
     """
     Reduces and scatters tensors from the specified communication group.
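As implemented in `infer_shape` above, HostAllGather stacks the per-rank tensors along the first dimension, so the output's first dimension is the input's multiplied by the group size: a [2, 8] input with a group of four ranks yields an [8, 8] output. A minimal sketch of that arithmetic (a standalone helper for illustration, not part of the operator):

```python
def host_all_gather_out_shape(x_shape, group):
    # Mirrors HostAllGather.infer_shape: the first dimension grows by the group size.
    out = list(x_shape)
    out[0] *= len(group)
    return out

assert host_all_gather_out_shape([2, 8], group=(0, 1, 2, 3)) == [8, 8]
```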
@@ -226,6 +292,68 @@ class ReduceScatter(PrimitiveWithInfer):
         raise NotImplementedError
 
 
+class HostReduceScatter(PrimitiveWithInfer):
+    """
+    Reduces and scatters tensors from the specified communication group on host.
+
+    Note:
+        Tensor must have the same shape and format in all processes participating in the collective.
+
+    Args:
+        op (str): Specifies an operation used for element-wise reductions,
+                  like sum, max, avg. Default: ReduceOp.SUM.
+        group (Union[tuple[int], list[int]]): The rank ids of the communication group to work on.
+
+    Raises:
+        TypeError: If op is not a string, group is not a list or a tuple,
+            or any element of group is not an int.
+        ValueError: If the first dimension of the input cannot be divided by the group size,
+            group is not set, or a rank_id in group is not in [0, 7].
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor
+        >>> from mindspore.communication import init
+        >>> from mindspore.ops.operations import ReduceOp
+        >>> import mindspore.ops.operations as P
+        >>> init('nccl')
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.hostreducescatter = P.HostReduceScatter(ReduceOp.SUM, group=[0, 1, 2, 3])
+        >>>
+        >>>     def construct(self, x):
+        >>>         return self.hostreducescatter(x)
+        >>>
+        >>> input_ = Tensor(np.ones([8, 8]).astype(np.float32))
+        >>> net = Net()
+        >>> output = net(input_)
+    """
+
+    @prim_attr_register
+    def __init__(self, op=ReduceOp.SUM, group=None):
+        if group is None:
+            raise ValueError(f"For '{self.name}' group must be set.")
+        validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
+        validator.check_value_type('group', group, (tuple, list), self.name)
+        validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
+        for r in group:
+            validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name)
+            validator.check_value_type("rank_id", r, (int,), self.name)
+        self.op = op
+        self.group_size = len(group)
+        self.add_prim_attr('group', group)
+
+    def infer_shape(self, x_shape):
+        if x_shape[0] % self.group_size != 0:
+            raise ValueError(f"For '{self.name}' the first dimension of x must be divisible by group_size.")
+        x_shape[0] = int(x_shape[0] / self.group_size)
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
+        return x_dtype
+
+    def __call__(self, tensor):
+        raise NotImplementedError
+
+
 class Broadcast(PrimitiveWithInfer):
     """
     Broadcasts the tensor to the whole group.
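HostReduceScatter goes the other way: its `infer_shape` divides the first dimension by the group size and rejects inputs whose first dimension is not a multiple of it. A small sketch of the same check and arithmetic (an illustrative helper, not MindSpore API):

```python
def host_reduce_scatter_out_shape(x_shape, group):
    # Mirrors HostReduceScatter.infer_shape: the first dimension shrinks by the
    # group size and must be divisible by it.
    group_size = len(group)
    if x_shape[0] % group_size != 0:
        raise ValueError("first dimension must be divisible by the group size")
    out = list(x_shape)
    out[0] //= group_size
    return out

assert host_reduce_scatter_out_shape([8, 8], group=(0, 1, 2, 3)) == [2, 8]
```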
@@ -26,6 +26,7 @@ from mindspore.nn import Momentum
 from mindspore.nn import ReLU
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter
+from mindspore.ops.operations.comm_ops import HostAllGather, HostReduceScatter
 from mindspore.ops.operations.comm_ops import Broadcast
 
 # pylint: disable=W0212
@@ -86,6 +87,21 @@ class AllGatherNet(nn.Cell):
         return self.relu(x)
 
 
+class HostAllGatherNet(nn.Cell):
+    """HostAllGatherNet definition"""
+
+    def __init__(self, input_channel, output_channel):
+        super(HostAllGatherNet, self).__init__()
+        self.dense = Dense(input_channel, output_channel)
+        self.hostallgather = HostAllGather((0, 1))
+        self.relu = ReLU()
+
+    def construct(self, x):
+        x = self.dense(x)
+        x = self.hostallgather(x)
+        return self.relu(x)
+
+
 class ReduceScatterNet(nn.Cell):
     """ReduceScatterNet definition"""
@@ -101,6 +117,21 @@ class ReduceScatterNet(nn.Cell):
         return self.relu(x)
 
 
+class HostReduceScatterNet(nn.Cell):
+    """HostReduceScatterNet definition"""
+
+    def __init__(self, input_channel, out_channel, op):
+        super(HostReduceScatterNet, self).__init__()
+        self.dense = Dense(input_channel, out_channel)
+        self.hostreducescatter = HostReduceScatter(op, (0, 1))
+        self.relu = ReLU()
+
+    def construct(self, x):
+        x = self.dense(x)
+        x = self.hostreducescatter(x)
+        return self.relu(x)
+
+
 class AlltoAllNet(nn.Cell):
     """AlltoAllNet definition"""
@@ -154,6 +185,21 @@ def test_allgather():
     _executor.compile(network, input_tensor, label_tensor)
 
 
+def test_hostallgather():
+    """test_hostallgather"""
+    context.set_context(mode=context.GRAPH_MODE)
+    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
+    label_tensor = Tensor(np.array([[1.2], [2.2], [3.2], [4.2]], dtype=np.float32))
+    network = HostAllGatherNet(2, 1)
+    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
+                         learning_rate=0.1,
+                         momentum=0.9)
+    network = WithLossCell(network, loss_fn)
+    network = TrainOneStepCell(network, optimizer)
+    _executor.compile(network, input_tensor, label_tensor)
+
+
 def run_reducescatter(op):
     """run_reducescatter"""
     context.set_context(mode=context.GRAPH_MODE)
@@ -175,6 +221,21 @@ def test_reducescatter():
     run_reducescatter(ReduceOp.SUM)
 
 
+def test_hostreducescatter():
+    """test_hostreducescatter"""
+    context.set_context(mode=context.GRAPH_MODE)
+    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
+    label_tensor = Tensor(np.array([[1.2]], dtype=np.float32))
+    network = HostReduceScatterNet(2, 1, ReduceOp.SUM)
+    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
+                         learning_rate=0.1,
+                         momentum=0.9)
+    network = WithLossCell(network, loss_fn)
+    network = TrainOneStepCell(network, optimizer)
+    _executor.compile(network, input_tensor, label_tensor)
+
+
 def test_broadcast():
     """test_broadcast"""
     context.set_context(mode=context.GRAPH_MODE)