Merge pull request !2499 from yihuaijie/master
@@ -36,7 +36,7 @@ class AllGatherCPUKernel : public CPUKernel {
   std::vector<int> ranks_group_;
 };
-MS_REG_CPU_KERNEL(HostAllGather, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+MS_REG_CPU_KERNEL(_HostAllGather, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   AllGatherCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
@@ -37,7 +37,7 @@ class ReduceScatterCPUKernel : public CPUKernel {
   std::vector<int> ranks_group_;
 };
-MS_REG_CPU_KERNEL(HostReduceScatter, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+MS_REG_CPU_KERNEL(_HostReduceScatter, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ReduceScatterCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
@@ -145,7 +145,7 @@ constexpr char MIRROR_OPERATOR[] = "_MirrorOperator";
 constexpr char STRIDED_SLICE[] = "StridedSlice";
 constexpr char ALL_GATHER[] = "AllGather";
 constexpr char REDUCE_SCATTER[] = "ReduceScatter";
-constexpr char HOST_REDUCE_SCATTER[] = "HostReduceScatter";
+constexpr char HOST_REDUCE_SCATTER[] = "_HostReduceScatter";
 constexpr char EMBEDDING_LOOKUP[] = "EmbeddingLookup";
 constexpr char CONCAT[] = "Concat";
 constexpr char SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SoftmaxCrossEntropyWithLogits";
@@ -55,9 +55,7 @@ const char kNameSimpleMeanGrad[] = "SimpleMeanGrad";
 const char kNameAllReduce[] = "AllReduce";
 const char kNameBroadcast[] = "Broadcast";
 const char kNameAllgather[] = "AllGather";
-const char kNameHostAllgather[] = "HostAllGather";
 const char kNameReduceScatter[] = "ReduceScatter";
-const char kNameHostReduceScatter[] = "HostReduceScatter";
 const char kNameReduceSum[] = "ReduceSum";
 const char kNameIsFinite[] = "isFinite";
 const char kNameReciprocal[] = "Reciprocal";
@@ -18,9 +18,9 @@ import mindspore.common.dtype as mstype
 from mindspore.ops import functional as F
 from .. import operations as P
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
-from ..operations.comm_ops import (AllGather, HostAllGather, AllReduce, _AlltoAll, Broadcast,
+from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
                                    _GetTensorSlice, _MirrorOperator, ReduceOp,
-                                   ReduceScatter, HostReduceScatter, _VirtualDiv)
+                                   ReduceScatter, _HostReduceScatter, _VirtualDiv)
 from .grad_base import bprop_getters
@@ -93,10 +93,10 @@ def get_bprop_all_gather(self):
     return bprop


-@bprop_getters.register(HostAllGather)
+@bprop_getters.register(_HostAllGather)
 def get_bprop_host_all_gather(self):
-    """Generate bprop for HostAllGather"""
-    host_all_gather_grad = HostReduceScatter(ReduceOp.SUM, self.group)
+    """Generate bprop for _HostAllGather"""
+    host_all_gather_grad = _HostReduceScatter(ReduceOp.SUM, self.group)
     if self.instance_name:
         instance_name = "grad" + self.instance_name
         host_all_gather_grad.set_prim_instance_name(instance_name)
@@ -126,10 +126,10 @@ def get_bprop_reduce_scatter(self):
     return bprop


-@bprop_getters.register(HostReduceScatter)
+@bprop_getters.register(_HostReduceScatter)
 def get_bprop_host_reduce_scatter(self):
-    """Generate bprop for HostReduceScatter"""
-    host_reduce_scatter_grad = HostAllGather(self.group)
+    """Generate bprop for _HostReduceScatter"""
+    host_reduce_scatter_grad = _HostAllGather(self.group)
     if self.instance_name:
         instance_name = "grad" + self.instance_name
         host_reduce_scatter_grad.set_prim_instance_name(instance_name)
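The two bprop getters above are mirror images: the gradient of the host all-gather is a SUM reduce-scatter over the same group, and the gradient of the host reduce-scatter is an all-gather. A minimal NumPy sketch (illustrative only, not MindSpore code) of why the all-gather case holds:

import numpy as np

ranks, slice_len = 3, 2
xs = [np.random.rand(slice_len, 4).astype(np.float32) for _ in range(ranks)]       # per-rank inputs
gathered = np.concatenate(xs, axis=0)                                              # every rank sees concat(x_0..x_2)
dys = [np.random.rand(*gathered.shape).astype(np.float32) for _ in range(ranks)]   # upstream grads, one per rank

# Each rank's input slice appears once in every rank's gathered output, so its
# gradient is the sum over ranks of the matching slice of dy: a SUM reduce-scatter.
for r in range(ranks):
    dx_r = sum(dys)[r * slice_len:(r + 1) * slice_len]
    assert dx_r.shape == xs[r].shape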
@@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset,
                        _VirtualDiv, _GetTensorSlice,
-                       HostAllGather, HostReduceScatter)
+                       _HostAllGather, _HostReduceScatter)
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                         TensorSummary, HistogramSummary, Debug, Print)
 from .control_ops import ControlDepend, GeSwitch, Merge
@@ -244,10 +244,8 @@ __all__ = [
     'UnsortedSegmentSum',
     'UnsortedSegmentMin',
     "AllGather",
-    "HostAllGather",
     "AllReduce",
     "ReduceScatter",
-    "HostReduceScatter",
     "Broadcast",
     "ReduceOp",
     'ScalarCast',
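With the names underscore-prefixed and dropped from __all__, the host collectives become internal: code that still needs them imports them straight from comm_ops, as the bprop module above does. A hypothetical internal usage sketch (constructor arguments taken from the docstrings shown further down):

# Assumed internal usage after this change; the classes are no longer exported
# through mindspore.ops.operations.__all__.
from mindspore.ops.operations.comm_ops import _HostAllGather, _HostReduceScatter

gather_op = _HostAllGather(group=(0, 1))        # gathers along dim 0 across the host group
scatter_op = _HostReduceScatter(group=(0, 1))   # op defaults to ReduceOp.SUM, then scatters by rank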
@@ -1166,7 +1166,7 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
     Perform the gradient for the communication part of EmbeddingLookup operator.

     This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
-    this primitive is implemented by StridedSlice --> HostAllGather --> Concat. This primitive runs on host.
+    this primitive is implemented by StridedSlice --> _HostAllGather --> Concat. This primitive runs on host.
     """

     @prim_attr_register
     def __init__(self):
@@ -1177,8 +1177,8 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
         """
         This primitive is implemented by three steps:
             1) Split the 'dy' along dimension 0 into 'split_num' parts.
-            2) For each part, perform HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
-            3) After HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
+            2) For each part, perform _HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
+            3) After _HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
                along dimension 0.

         The output shape of this primitive: shape(output)[0] == shape(dy)[0] * 8
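A shape-only NumPy sketch of those three steps (illustrative; split_num = 4 is a hypothetical value, and in the real primitive every process contributes its own slice rather than a copy):

import numpy as np

world, split_num = 8, 4                        # fixed host group (0..7); split_num is hypothetical
dy = np.random.rand(16, 32).astype(np.float32)

parts = np.split(dy, split_num, axis=0)                            # 1) split dy along dim 0
gathered = [np.concatenate([p] * world, axis=0) for p in parts]    # 2) stand-in for _HostAllGather over 8 processes
out = np.concatenate(gathered, axis=0)                             # 3) concat the split_num parts along dim 0
assert out.shape[0] == dy.shape[0] * 8                             # matches the documented output shape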
@@ -176,13 +176,13 @@ class AllGather(PrimitiveWithInfer):
         raise NotImplementedError


-class HostAllGather(PrimitiveWithInfer):
+class _HostAllGather(PrimitiveWithInfer):
     """
     Gathers tensors from the specified communication group on host.

     Note:
         Tensor must have the same shape and format in all processes participating in the collective.
-        HostAllGather is a host-side operator, it depends on OpenMPI and must use build option -M on
+        _HostAllGather is a host-side operator, it depends on OpenMPI and must use build option -M on
         to enable it. Using mpirun command to run it:

         mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_all_gather.py
@@ -199,27 +199,6 @@ class HostAllGather(PrimitiveWithInfer):
     Outputs:
         Tensor. If the number of devices in the group is N,
         then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`.
-
-    Examples:
-        >>> import mindspore.nn as nn
-        >>> import mindspore.context as context
-        >>> import mindspore.ops.operations as P
-        >>> from mindspore import Tensor
-        >>>
-        >>> context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
-        >>> context.set_mpi_config(enable_mpi=True)
-        >>>
-        >>> class Net(nn.Cell):
-        >>>     def __init__(self):
-        >>>         super(Net, self).__init__()
-        >>>         self.hostallgather = P.HostAllGather(group=(0, 1, 2, 3))
-        >>>
-        >>>     def construct(self, x):
-        >>>         return self.hostallgather(x)
-        >>>
-        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
-        >>> net = Net()
-        >>> output = net(input_)
     """

     @prim_attr_register
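For reference, a sketch of driving the now-internal operator directly on the CPU backend, adapted from the removed docstring example above (assumes a build with the -M option and an mpirun launch as the note states):

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.operations.comm_ops import _HostAllGather

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
context.set_mpi_config(enable_mpi=True)

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.hostallgather = _HostAllGather(group=(0, 1, 2, 3))   # rank ids of the host group

    def construct(self, x):
        return self.hostallgather(x)

# Launched once per rank, e.g.:
#   mpirun -output-filename log -merge-stderr-to-stdout -np 4 python this_script.py
output = Net()(Tensor(np.ones([2, 8]).astype(np.float32)))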
@@ -308,13 +287,13 @@ class ReduceScatter(PrimitiveWithInfer):
         raise NotImplementedError


-class HostReduceScatter(PrimitiveWithInfer):
+class _HostReduceScatter(PrimitiveWithInfer):
     """
     Reduces and scatters tensors from the specified communication group on host.

     Note:
         Tensor must have the same shape and format in all processes participating in the collective.
-        HostReduceScatter is a host-side operator, it depends on OpenMPI and must use build option
+        _HostReduceScatter is a host-side operator, it depends on OpenMPI and must use build option
         -M on to enable it. Using mpirun command to run it:

         mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_reduce_scatter.py
@@ -328,28 +307,6 @@ class HostReduceScatter(PrimitiveWithInfer):
            or elements of group are not int.
         ValueError: If the first dimension of input can not be divided by group size,
            or group is not set, or rank_id not in [0, 7].
-
-    Examples:
-        >>> import mindspore.nn as nn
-        >>> import mindspore.context as context
-        >>> import mindspore.ops.operations as P
-        >>> from mindspore import Tensor
-        >>> from mindspore.ops.operations.comm_ops import ReduceOp
-        >>>
-        >>> context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
-        >>> context.set_mpi_config(enable_mpi=True)
-        >>>
-        >>> class Net(nn.Cell):
-        >>>     def __init__(self):
-        >>>         super(Net, self).__init__()
-        >>>         self.hostreducescatter = P.HostReduceScatter(ReduceOp.SUM, group=[0, 1, 2, 3])
-        >>>
-        >>>     def construct(self, x):
-        >>>         return self.hostreducescatter(x)
-        >>>
-        >>> input_ = Tensor(np.ones([8, 8]).astype(np.float32))
-        >>> net = Net()
-        >>> output = net(input_)
     """

     @prim_attr_register
     def __init__(self, op=ReduceOp.SUM, group=None):
@@ -1,76 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-import numpy as np
-import pytest
-
-import mindspore.context as context
-import mindspore.nn as nn
-from mindspore import Tensor
-from mindspore.common import dtype as mstype
-from mindspore.ops import operations as P
-import mindspore._ms_mpi as mpi
-
-# run comand:
-# mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_reduce_scatter.py
-
-context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
-context.set_mpi_config(enable_mpi=True)
-
-
-class Net(nn.Cell):
-    def __init__(self):
-        super(Net, self).__init__()
-        self.op = "sum"
-        self.reducescatter = P.HostReduceScatter(op=self.op, group=[0,1,2])
-
-    def construct(self, x):
-        return self.reducescatter(x)
-
-
-class AllGatherNet(nn.Cell):
-    def __init__(self):
-        super(AllGatherNet, self).__init__()
-        self.hostallgather = P.HostAllGather(group=(0, 1, 2))
-
-    def construct(self, x):
-        return self.hostallgather(x)
-
-
-def test_net_reduce_scatter():
-    x = np.arange(12).astype(np.float32) * 0.1
-    reducescatter = Net()
-    rankid = mpi.get_rank_id()
-    print("self rankid:", rankid)
-    output = reducescatter(Tensor(x, mstype.float32))
-    print("output:\n", output)
-
-    if rankid == 0:
-        expect_result = np.arange(4).astype(np.float32) * 0.3
-    if rankid == 1:
-        expect_result = np.arange(4, 8).astype(np.float32) * 0.3
-    if rankid == 2:
-        expect_result = np.arange(8, 12).astype(np.float32) * 0.3
-    diff = abs(output.asnumpy() - expect_result)
-    error = np.ones(shape=expect_result.shape) * 1.0e-6
-    assert np.all(diff < error)
-
-    allgather = AllGatherNet()
-    allgather_output = allgather(output)
-    print("allgather result:\n", allgather_output)
-    expect_allgather_result = np.arange(12).astype(np.float32) * 0.3
-    diff = abs(allgather_output.asnumpy() - expect_allgather_result)
-    error = np.ones(shape=expect_allgather_result.shape) * 1.0e-6
-    assert np.all(diff < error)
-
-
-if __name__ == '__main__':
-    test_net_reduce_scatter()
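The deleted test above chained both host collectives over a 3-rank group; its expected values follow from simple arithmetic, which this standalone NumPy check (no MPI required, purely illustrative) reproduces:

import numpy as np

ranks = 3
x = np.arange(12).astype(np.float32) * 0.1       # identical input on every rank
summed = x * ranks                               # SUM reduction across the 3 ranks
scattered = np.split(summed, ranks)              # reduce-scatter: rank r keeps its 4-element slice
assert np.allclose(scattered[1], np.arange(4, 8, dtype=np.float32) * 0.3)

regathered = np.concatenate(scattered)           # a follow-up all-gather over the same group
assert np.allclose(regathered, np.arange(12, dtype=np.float32) * 0.3)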
@@ -26,7 +26,6 @@ from mindspore.nn import Momentum
 from mindspore.nn import ReLU
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter
-from mindspore.ops.operations.comm_ops import HostAllGather, HostReduceScatter
 from mindspore.ops.operations.comm_ops import Broadcast

 # pylint: disable=W0212
@@ -87,21 +86,6 @@ class AllGatherNet(nn.Cell):
         return self.relu(x)


-class HostAllGatherNet(nn.Cell):
-    """HostAllGatherNet definition"""
-
-    def __init__(self, input_channel, output_channel):
-        super(HostAllGatherNet, self).__init__()
-        self.dense = Dense(input_channel, output_channel)
-        self.hostallgather = HostAllGather((0, 1))
-        self.relu = ReLU()
-
-    def construct(self, x):
-        x = self.dense(x)
-        x = self.hostallgather(x)
-        return self.relu(x)
-
-
 class ReduceScatterNet(nn.Cell):
     """ReduceScatterNet definition"""
@@ -117,21 +101,6 @@ class ReduceScatterNet(nn.Cell):
         return self.relu(x)


-class HostReduceScatterNet(nn.Cell):
-    """HostReduceScatterNet definition"""
-
-    def __init__(self, input_channel, out_channel, op):
-        super(HostReduceScatterNet, self).__init__()
-        self.dense = Dense(input_channel, out_channel)
-        self.hostreducescatter = HostReduceScatter(op, (0, 1))
-        self.relu = ReLU()
-
-    def construct(self, x):
-        x = self.dense(x)
-        x = self.hostreducescatter(x)
-        return self.relu(x)
-
-
 class AlltoAllNet(nn.Cell):
     """AlltoAllNet definition"""
@@ -185,21 +154,6 @@ def test_allgather():
     _executor.compile(network, input_tensor, label_tensor)


-def test_hostallgather():
-    """test_hostallgather"""
-    context.set_context(mode=context.GRAPH_MODE)
-    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
-    label_tensor = Tensor(np.array([[1.2], [2.2], [3.2], [4.2]], dtype=np.float32))
-    network = HostAllGatherNet(2, 1)
-    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
-    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
-                         learning_rate=0.1,
-                         momentum=0.9)
-    network = WithLossCell(network, loss_fn)
-    network = TrainOneStepCell(network, optimizer)
-    _executor.compile(network, input_tensor, label_tensor)
-
-
 def run_reducescatter(op):
     """run_reducescatter"""
     context.set_context(mode=context.GRAPH_MODE)
@@ -221,21 +175,6 @@ def test_reducescatter():
     run_reducescatter(ReduceOp.SUM)


-def test_hostreducescatter():
-    """test_hostreducescatter"""
-    context.set_context(mode=context.GRAPH_MODE)
-    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
-    label_tensor = Tensor(np.array([[1.2]], dtype=np.float32))
-    network = HostReduceScatterNet(2, 1, ReduceOp.SUM)
-    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
-    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
-                         learning_rate=0.1,
-                         momentum=0.9)
-    network = WithLossCell(network, loss_fn)
-    network = TrainOneStepCell(network, optimizer)
-    _executor.compile(network, input_tensor, label_tensor)
-
-
 def test_broadcast():
     """test_broadcast"""
     context.set_context(mode=context.GRAPH_MODE)