diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc
index 8374914dd5..9927798c0c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc
@@ -47,5 +47,15 @@ MS_REG_GPU_KERNEL_ONE(
 MS_REG_GPU_KERNEL_ONE(ReduceScatter,
                       KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                       NcclGpuKernel, int)
+
+MS_REG_GPU_KERNEL_ONE(
+  Broadcast, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  NcclGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  Broadcast, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  NcclGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Broadcast,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      NcclGpuKernel, int)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h
index 18caa149f6..a1a9eeecdb 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h
@@ -30,11 +30,18 @@ namespace mindspore {
 namespace kernel {
-enum NcclKernelType { NCCL_ALL_REDUCE = 0, NCCL_ALL_GATHER, NCCL_REDUCE_SCATTER, NCCL_INVALID_TYPE = 255 };
+enum NcclKernelType {
+  NCCL_ALL_REDUCE = 0,
+  NCCL_ALL_GATHER,
+  NCCL_REDUCE_SCATTER,
+  NCCL_BROADCAST,
+  NCCL_INVALID_TYPE = 255
+};
 const std::map<std::string, NcclKernelType> kNcclTypeMap = {
   {"AllReduce", NCCL_ALL_REDUCE},
   {"AllGather", NCCL_ALL_GATHER},
   {"ReduceScatter", NCCL_REDUCE_SCATTER},
+  {"Broadcast", NCCL_BROADCAST},
 };
 
 static std::map<TypeId, ncclDataType_t> kNcclDtypeMap = {
@@ -45,6 +52,7 @@ typedef ncclResult_t (*AllReduce)(const void *, void *, size_t, ncclDataType_t,
 typedef ncclResult_t (*AllGather)(const void *, void *, size_t, ncclDataType_t, cudaStream_t, const std::string &);
 typedef ncclResult_t (*ReduceScatter)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t,
                                       const std::string &);
+typedef ncclResult_t (*Broadcast)(const void *, void *, size_t, ncclDataType_t, int, cudaStream_t, const std::string &);
 
 template <typename T>
 class NcclGpuKernel : public GpuKernel {
@@ -55,6 +63,7 @@ class NcclGpuKernel : public GpuKernel {
         group_name_(""),
         input_size_(0),
         output_size_(0),
+        root_(0),
         collective_handle_(nullptr),
         comm_stream_(nullptr) {}
   ~NcclGpuKernel() override = default;
@@ -96,6 +105,15 @@
                                    "ncclReduceScatter failed");
         break;
       }
+      case NCCL_BROADCAST: {
+        auto broadcast_funcptr =
+          reinterpret_cast<kernel::Broadcast>(dlsym(const_cast<void *>(collective_handle_), "Broadcast"));
+        MS_EXCEPTION_IF_NULL(broadcast_funcptr);
+        CHECK_NCCL_RET_WITH_EXCEPT((*broadcast_funcptr)(input_addr, output_addr, output_size_ / sizeof(T),
+                                                        nccl_data_type_, root_, stream, group_name_),
+                                   "ncclBroadcast failed");
+        break;
+      }
       default: {
         MS_LOG(EXCEPTION) << "Kernel type " << nccl_kernel_type_ << " is not supported.";
       }
@@ -167,6 +185,11 @@
         MS_LOG(EXCEPTION) << "Nccl reduce type " << type << " is not supported.";
       }
     }
+
+    auto root_rank = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kAttrRootRank);
+    if (root_rank) {
+      root_ = GetValue<int>(root_rank);
+    }
     return;
   }
@@ -176,6 +199,7 @@
   std::string group_name_;
   size_t input_size_;
   size_t output_size_;
+  int root_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc
index d74f1ebea0..d029cf9615 100644
--- a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc
@@ -48,3 +48,8 @@ ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t cou
                            ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group) {
   return NCCLWrapper::instance().ReduceScatter(input_addr, output_addr, count, data_type, reduce_type, stream, group);
 }
+
+ncclResult_t Broadcast(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, int root,
+                       cudaStream_t stream, const std::string &group) {
+  return NCCLWrapper::instance().Broadcast(input_addr, output_addr, count, data_type, root, stream, group);
+}
diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.h b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.h
index e76ede4d38..8c7801b89a 100644
--- a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.h
+++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.h
@@ -45,3 +45,6 @@ extern "C" EXPORT_WRAPPER ncclResult_t AllGather(const void *input_addr, void *o
 extern "C" EXPORT_WRAPPER ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count,
                                                      ncclDataType_t data_type, ncclRedOp_t reduce_type,
                                                      cudaStream_t stream, const std::string &group);
+extern "C" EXPORT_WRAPPER ncclResult_t Broadcast(const void *input_addr, void *output_addr, size_t count,
+                                                 ncclDataType_t data_type, int root, cudaStream_t stream,
+                                                 const std::string &group);
diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc
index 519a29a597..2def2d6db1 100644
--- a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc
@@ -66,6 +66,14 @@ ncclResult_t NCCLWrapper::ReduceScatter(const void *input_addr, void *output_add
   return ncclReduceScatter(input_addr, output_addr, count, data_type, reduce_type, group_comm, stream);
 }
+
+ncclResult_t NCCLWrapper::Broadcast(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                                    int root, cudaStream_t stream, const std::string &group_name) {
+  CHECK_RET(group_info_.count(group_name), 1,
+            "Failed to find NCCL communicator for Broadcast by the group name " + group_name);
+  ncclComm_t group_comm = group_info_[group_name].comm;
+  return ncclBroadcast(input_addr, output_addr, count, data_type, root, group_comm, stream);
+}
 
 void NCCLWrapper::AddGroupInfo(const std::string &group_name, NcclGroupInfo *group) {
   if (comm_init_done_) {
     CHECK_RET(ncclCommInitRank(&(group->comm), group->size, group->unique_id, group->rank), ncclSuccess,
diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h
index 94525ebe46..c019e0dcbf 100644
--- a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h
+++ b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h
@@ -40,6 +40,8 @@ class NCCLWrapper {
                          cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
   ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
                              ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
+  ncclResult_t Broadcast(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, int root,
+                         cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
   void AddGroupInfo(const std::string &group_name, NcclGroupInfo *group);
   void DestroyGroup(const std::string &group_name);
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index c529b86197..1f43a49e8e 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -224,6 +224,7 @@ constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active
 constexpr auto kAttrFusion = "fusion";
 constexpr auto kAttrGroup = "group";
 constexpr auto kAttrOp = "op";
+constexpr auto kAttrRootRank = "root_rank";
 constexpr auto kAttrIsTraining = "is_training";
 constexpr auto kAttrFusionId = "fusion_id";
 constexpr auto kAttrLabelIndex = "label_index";
diff --git a/tests/st/nccl/test_nccl_all.py b/tests/st/nccl/test_nccl_all.py
index adb662969c..0856f383bb 100644
--- a/tests/st/nccl/test_nccl_all.py
+++ b/tests/st/nccl/test_nccl_all.py
@@ -46,3 +46,10 @@ def test_nccl_all_gather_op():
 def test_nccl_reduce_scatter_op():
     return_code = os.system("mpirun -n 8 pytest -s test_nccl_reduce_scatter_op.py")
     assert return_code == 0
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_single
+def test_nccl_broadcast_op():
+    return_code = os.system("mpirun -n 8 pytest -s test_nccl_broadcast_op.py")
+    assert return_code == 0
diff --git a/tests/st/nccl/test_nccl_broadcast_op.py b/tests/st/nccl/test_nccl_broadcast_op.py
new file mode 100644
index 0000000000..4541bf6e6b
--- /dev/null
+++ b/tests/st/nccl/test_nccl_broadcast_op.py
@@ -0,0 +1,71 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.common.initializer import initializer
+from mindspore.common.parameter import Parameter
+from mindspore.communication.management import init, get_rank, get_group_size
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+
+init('nccl')
+rank = get_rank()
+size = get_group_size()
+x = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1)
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.x1 = Parameter(initializer(Tensor(x), x.shape), name='x1')
+        self.x2 = Parameter(initializer(Tensor(x), x.shape), name='x2')
+        self.x3 = Parameter(initializer(Tensor(x), x.shape), name='x3')
+
+        self.broadcast1 = P.Broadcast(0)
+        self.broadcast2 = P.Broadcast(1)
+        self.broadcast3 = P.Broadcast(2)
+
+    def construct(self):
+        return (self.broadcast1((self.x1,)),
+                self.broadcast2((self.x2,)),
+                self.broadcast3((self.x3,)))
+
+
+def test_Broadcast():
+    broadcast = Net()
+    output = broadcast()
+
+    expect0 = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * 1
+    expect1 = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * 2
+    expect2 = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * 3
+
+    diff0 = output[0][0].asnumpy() - expect0
+    error0 = np.ones(shape=expect0.shape) * 1.0e-5
+    assert np.all(diff0 < error0)
+    assert output[0][0].shape == expect0.shape
+
+    diff1 = output[1][0].asnumpy() - expect1
+    error1 = np.ones(shape=expect1.shape) * 1.0e-5
+    assert np.all(diff1 < error1)
+    assert output[1][0].shape == expect1.shape
+
+    diff2 = output[2][0].asnumpy() - expect2
+    error2 = np.ones(shape=expect2.shape) * 1.0e-5
+    assert np.all(diff2 < error2)
+    assert output[2][0].shape == expect2.shape