From: @huangxinjing Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -217,6 +217,8 @@ AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const Primit | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplAllReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplBroadcast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -367,6 +367,45 @@ AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, co | |||
| return sparse_tensor->dense_shape(); | |||
| } | |||
| AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 3); | |||
| auto tensor_in = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(tensor_in); | |||
| MS_EXCEPTION_IF_NULL(tensor_in->shape()); | |||
| auto tensor_in_shape = tensor_in->shape()->shape(); | |||
| auto send_size = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| MS_EXCEPTION_IF_NULL(send_size); | |||
| auto recv_size = CheckArg<AbstractTensor>(op_name, args_spec_list, 2); | |||
| MS_EXCEPTION_IF_NULL(recv_size); | |||
| // Get the content of the recv size | |||
| auto recv_size_value_ptr = recv_size->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(recv_size_value_ptr); | |||
| auto recv_size_tensor = recv_size_value_ptr->cast<tensor::TensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(recv_size_tensor); | |||
| auto data_pos = reinterpret_cast<int64_t *>(recv_size_tensor->data_c()); | |||
| MS_EXCEPTION_IF_NULL(data_pos); | |||
| int64_t infer_max_size = 0; | |||
| for (int64_t i = 0; i < recv_size_tensor->DataSize(); ++i) { | |||
| infer_max_size += *(data_pos + i); | |||
| } | |||
| ShapeVector tensor_out_shape = {Shape::SHP_ANY, tensor_in_shape[1]}; | |||
| ShapeVector min_shape = {1, tensor_in_shape[1]}; | |||
| ShapeVector max_shape = {infer_max_size / tensor_in_shape[1], tensor_in_shape[1]}; | |||
| auto tensor_out = std::make_shared<AbstractTensor>(tensor_in->element(), | |||
| std::make_shared<Shape>(tensor_out_shape, min_shape, max_shape)); | |||
| AbstractTensorPtr ret = std::make_shared<AbstractTensor>( | |||
| tensor_out->element(), std::make_shared<Shape>(tensor_out_shape, min_shape, max_shape)); | |||
| return ret; | |||
| } | |||
| AbstractBasePtr InferImplAllReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| @@ -135,6 +135,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimAllReduce, {InferImplAllReduce, true}}, | |||
| {prim::kPrimBroadcast, {InferImplBroadcast, true}}, | |||
| {prim::kPrimAllGather, {InferImplAllGather, true}}, | |||
| {prim::kPrimAllSwap, {InferImplAllSwap, true}}, | |||
| {prim::kPrimReduceScatter, {InferImplReduceScatter, true}}, | |||
| {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}}, | |||
| {prim::kPrimCast, {InferImplCast, true}}, | |||
| @@ -186,6 +186,7 @@ inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOper | |||
| inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | |||
| inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | |||
| inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); | |||
| inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap"); | |||
| inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast"); | |||
| inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGather"); | |||
| inline const PrimitivePtr kPrimReduceScatter = std::make_shared<Primitive>("ReduceScatter"); | |||
| @@ -21,7 +21,7 @@ from ...common.tensor import RowTensor | |||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | |||
| from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, | |||
| _GetTensorSlice, _MirrorOperator, ReduceOp, Send, Receive, | |||
| ReduceScatter, _HostReduceScatter, _VirtualDiv) | |||
| ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap) | |||
| from .grad_base import bprop_getters | |||
| @@ -155,6 +155,21 @@ def get_bprop_reduce_scatter(self): | |||
| return bprop | |||
@bprop_getters.register(AllSwap)
def get_bprop_allswap(self):
    """Generate bprop for AllSwap.

    The backward pass runs another AllSwap on the same group with the send and
    receive sizes exchanged, applied to the incoming gradient `dout`.
    `send_size` and `recv_size` receive zero gradients so that the returned
    tuple has one entry per forward input.
    """
    all_swap_grad = AllSwap(self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        # Name the gradient primitive created above (not an unrelated/undefined one).
        all_swap_grad.set_prim_instance_name(instance_name)

    def bprop(x, send_size, recv_size, out, dout):
        # What was received in the forward pass is sent back, and vice versa.
        dx = all_swap_grad(dout, recv_size, send_size)
        return (dx, zeros_like(send_size), zeros_like(recv_size))

    return bprop
| @bprop_getters.register(_HostReduceScatter) | |||
| def get_bprop_host_reduce_scatter(self): | |||
| """Generate bprop for _HostReduceScatter""" | |||
| @@ -34,7 +34,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||
| UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, | |||
| SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, | |||
| Unique, GatherD, Identity, RepeatElements) | |||
| from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast, | |||
| from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, | |||
| _MirrorOperator, ReduceOp, _VirtualDataset, | |||
| _VirtualDiv, _GetTensorSlice, Send, Receive, | |||
| _HostAllGather, _HostReduceScatter) | |||
| @@ -295,6 +295,7 @@ __all__ = [ | |||
| 'UnsortedSegmentProd', | |||
| "AllGather", | |||
| "AllReduce", | |||
| "AllSwap", | |||
| "ReduceScatter", | |||
| "Broadcast", | |||
| "ReduceOp", | |||
| @@ -20,7 +20,7 @@ from ..._checkparam import Validator as validator | |||
| from ..._checkparam import Rel | |||
| from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group | |||
| from ...common import dtype as mstype | |||
| from ..primitive import PrimitiveWithInfer, prim_attr_register | |||
| from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register | |||
| class ReduceOp: | |||
| @@ -518,6 +518,59 @@ class Broadcast(PrimitiveWithInfer): | |||
| return x_dtype | |||
class AllSwap(PrimitiveWithCheck):
    """
    AllSwap is a collective operation.

    AllSwap sends data from every process to every process in the specified group. It has two phases:

    - The scatter phase: On each process, the operand is split into blocks along the 0-th axis according
      to the send sizes, and the blocks are scattered to all processes, e.g., the ith block is sent to
      the ith process.
    - The gather phase: Each process concatenates the received blocks along the 0-th axis.

    Note:
        The tensors must have the same format in all processes of the collection.

    Args:
        group (str): The communication group name.

    Inputs:
        tensor_in (tensor): A 2-D tensor. On each process, it is divided into blocks according to the send size.
        send_size (tensor): A 1-D int64 tensor. Each element is the size of the data sent to one process.
        recv_size (tensor): A 1-D int64 tensor. Each element is the size of the data received from one process.

    Returns:
        tensor_out (tensor): The result tensor.

    Raises:
        TypeError: If group is not a string.
    """

    @prim_attr_register
    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
        """Initialize AllSwap"""
        group_name = _get_group(group)
        validator.check_value_type('group', group_name, (str,), self.name)
        self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out'])
        self.add_prim_attr('group', group_name)

    def __check__(self, tensor_in, send_size, recv_size):
        """Validate the input dtypes and ranks and describe the output.

        The output keeps the second dimension and dtype of `tensor_in`; its first
        dimension is reported as -1 and its value as None.
        """
        validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor, self.name)

        # Both size tensors must be int64.
        for arg_name, arg in (("send_size", send_size), ("recv_size", recv_size)):
            validator.check_tensor_dtype_valid(arg_name, arg['dtype'], [mstype.int64], self.name)

        # tensor_in must be 2-D, the size tensors 1-D.
        expected_ranks = (("tensor_in", tensor_in, 2), ("send_size", send_size, 1), ("recv_size", recv_size, 1))
        for arg_name, arg, rank in expected_ranks:
            validator.check_equal_int(len(arg['shape']), rank, arg_name, self.name)

        return {'shape': [-1, tensor_in['shape'][1]],
                'dtype': tensor_in['dtype'],
                'value': None}
| class _AlltoAll(PrimitiveWithInfer): | |||
| """ | |||
| AlltoAll is a collective operation. | |||
| @@ -26,7 +26,9 @@ from mindspore.nn import Momentum | |||
| from mindspore.nn import ReLU | |||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter | |||
| from mindspore.ops.operations.comm_ops import Broadcast | |||
| from mindspore.ops.operations.comm_ops import Broadcast, AllSwap | |||
| from mindspore.ops.operations.math_ops import ReduceSum | |||
| import mindspore | |||
| # pylint: disable=W0212 | |||
| # W0212: protected-access | |||
| @@ -117,6 +119,25 @@ class AlltoAllNet(nn.Cell): | |||
| return self.relu(x) | |||
class AllSwapNet(nn.Cell):
    """AllSwapNet definition: Dense -> AllSwap -> ReLU."""

    def __init__(self, batch_size, input_channel, out_channel):
        super(AllSwapNet, self).__init__()
        self.dense = Dense(input_channel, out_channel)
        self.allswap = AllSwap()
        self.relu = ReLU()
        self.reduce = ReduceSum()
        # Floor division keeps the sizes integral; they are stored in int64 tensors below.
        part_slice = batch_size // 2
        block_size = part_slice * out_channel
        self.send_size = Tensor([0, block_size, block_size], mindspore.int64)
        self.recv_size = Tensor([block_size, block_size, 0], mindspore.int64)

    def construct(self, x):
        x = self.dense(x)
        x = self.allswap(x, self.send_size, self.recv_size)
        x = self.relu(x)
        return x
| def run_allreduce(op): | |||
| """run_allreduce""" | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| @@ -154,6 +175,13 @@ def test_allgather(): | |||
| network = TrainOneStepCell(network, optimizer) | |||
| _executor.compile(network, input_tensor, label_tensor) | |||
def test_allswap():
    """Compile AllSwapNet in graph mode with a (100, 20) float32 input of ones."""
    context.set_context(mode=context.GRAPH_MODE)
    net = AllSwapNet(100, 20, 20)
    data = Tensor(np.ones((100, 20)), dtype=mindspore.float32)
    _executor.compile(net, data)
| def run_reducescatter(op): | |||
| """run_reducescatter""" | |||