| @@ -57,7 +57,9 @@ if (ENABLE_GPU) | |||
| ) | |||
| file(GLOB_RECURSE GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc") | |||
| list(REMOVE_ITEM GPU_SRC_LIST "gpu/nccl/nccl_gpu_kernel.cc") | |||
| list(REMOVE_ITEM GPU_SRC_LIST "gpu/nccl/nccl_collective_gpu_kernel.cc") | |||
| list(REMOVE_ITEM GPU_SRC_LIST "gpu/nccl/nccl_send_gpu_kernel.cc") | |||
| list(REMOVE_ITEM GPU_SRC_LIST "gpu/nccl/nccl_recv_gpu_kernel.cc") | |||
| if (ENABLE_MPI) | |||
| include(ExternalProject) | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,48 +14,48 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| AllReduce, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| NcclGpuKernel, float) | |||
| NcclCollectiveGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| AllReduce, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| NcclGpuKernel, half) | |||
| NcclCollectiveGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(AllReduce, | |||
| KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| NcclGpuKernel, int) | |||
| NcclCollectiveGpuKernel, int) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| NcclGpuKernel, float) | |||
| NcclCollectiveGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| NcclGpuKernel, half) | |||
| NcclCollectiveGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(AllGather, | |||
| KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| NcclGpuKernel, int) | |||
| NcclCollectiveGpuKernel, int) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| NcclGpuKernel, float) | |||
| NcclCollectiveGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| NcclGpuKernel, half) | |||
| NcclCollectiveGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(ReduceScatter, | |||
| KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| NcclGpuKernel, int) | |||
| NcclCollectiveGpuKernel, int) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Broadcast, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| NcclGpuKernel, float) | |||
| NcclCollectiveGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Broadcast, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| NcclGpuKernel, half) | |||
| NcclCollectiveGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(Broadcast, | |||
| KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| NcclGpuKernel, int) | |||
| NcclCollectiveGpuKernel, int) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,211 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_COLLECTIVE_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_COLLECTIVE_GPU_KERNEL_H_ | |||
| #include <dlfcn.h> | |||
| #include <stdint.h> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <map> | |||
| #include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
// Collective primitives an NCCL collective kernel can dispatch to. The numeric values are
// written out explicitly; NCCL_INVALID_TYPE is a sentinel for "type not resolved".
enum NcclKernelType {
  NCCL_ALL_REDUCE = 0,
  NCCL_ALL_GATHER = 1,
  NCCL_REDUCE_SCATTER = 2,
  NCCL_BROADCAST = 3,
  NCCL_INVALID_TYPE = 255
};

// Maps an operator name to the collective primitive it stands for.
const std::map<std::string, NcclKernelType> kNcclTypeMap{
  {"AllReduce", NCCL_ALL_REDUCE},
  {"AllGather", NCCL_ALL_GATHER},
  {"ReduceScatter", NCCL_REDUCE_SCATTER},
  {"Broadcast", NCCL_BROADCAST},
};
| template <typename T> | |||
| class NcclCollectiveGpuKernel : public NcclGpuKernel { | |||
| public: | |||
| NcclCollectiveGpuKernel() | |||
| : nccl_kernel_type_(NCCL_INVALID_TYPE), | |||
| nccl_reduce_type_(ncclSum), | |||
| input_size_(0), | |||
| output_size_(0), | |||
| root_(0), | |||
| collective_handle_(nullptr), | |||
| comm_stream_(nullptr) {} | |||
| ~NcclCollectiveGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *input_addr = GetDeviceAddress<T>(inputs, 0); | |||
| T *output_addr = GetDeviceAddress<T>(outputs, 0); | |||
| cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr); | |||
| switch (nccl_kernel_type_) { | |||
| case NCCL_ALL_REDUCE: { | |||
| auto all_reduce_funcptr = | |||
| reinterpret_cast<AllReduce>(dlsym(const_cast<void *>(collective_handle_), "AllReduce")); | |||
| MS_EXCEPTION_IF_NULL(all_reduce_funcptr); | |||
| CHECK_NCCL_RET_WITH_EXCEPT((*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), | |||
| nccl_data_type_, nccl_reduce_type_, stream, group_name_), | |||
| "ncclAllReduce failed"); | |||
| break; | |||
| } | |||
| case NCCL_ALL_GATHER: { | |||
| auto all_gather_funcptr = | |||
| reinterpret_cast<AllGather>(dlsym(const_cast<void *>(collective_handle_), "AllGather")); | |||
| MS_EXCEPTION_IF_NULL(all_gather_funcptr); | |||
| CHECK_NCCL_RET_WITH_EXCEPT( | |||
| (*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream, group_name_), | |||
| "ncclAllGather failed"); | |||
| break; | |||
| } | |||
| case NCCL_REDUCE_SCATTER: { | |||
| auto reduce_scatter_funcptr = | |||
| reinterpret_cast<ReduceScatter>(dlsym(const_cast<void *>(collective_handle_), "ReduceScatter")); | |||
| MS_EXCEPTION_IF_NULL(reduce_scatter_funcptr); | |||
| CHECK_NCCL_RET_WITH_EXCEPT((*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), | |||
| nccl_data_type_, nccl_reduce_type_, stream, group_name_), | |||
| "ncclReduceScatter failed"); | |||
| break; | |||
| } | |||
| case NCCL_BROADCAST: { | |||
| auto broadcast_funcptr = | |||
| reinterpret_cast<Broadcast>(dlsym(const_cast<void *>(collective_handle_), "Broadcast")); | |||
| MS_EXCEPTION_IF_NULL(broadcast_funcptr); | |||
| for (int i = 0; i < SizeToInt(input_size_list_.size()); ++i) { | |||
| input_addr = GetDeviceAddress<T>(inputs, i); | |||
| output_addr = GetDeviceAddress<T>(outputs, i); | |||
| CHECK_NCCL_RET_WITH_EXCEPT((*broadcast_funcptr)(input_addr, output_addr, output_size_list_[i] / sizeof(T), | |||
| nccl_data_type_, root_, stream, group_name_), | |||
| "ncclBroadcast failed"); | |||
| } | |||
| break; | |||
| } | |||
| default: { | |||
| MS_LOG(EXCEPTION) << "Kernel type " << nccl_kernel_type_ << " is not supported."; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)); | |||
| InferCommType(kernel_node); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); | |||
| size_t size = sizeof(T); | |||
| for (size_t j = 0; j < shape.size(); j++) { | |||
| size *= IntToSize(shape[j]); | |||
| } | |||
| size_t aligned_size = (nccl_kernel_type_ != NCCL_ALL_REDUCE) ? size : AlignMemorySize(size); | |||
| input_size_list_.push_back(aligned_size); | |||
| input_size_ += aligned_size; | |||
| } | |||
| for (size_t i = 0; i < output_num; ++i) { | |||
| auto shape = AnfAlgo::GetOutputInferShape(kernel_node, i); | |||
| size_t size = sizeof(T); | |||
| for (size_t j = 0; j < shape.size(); j++) { | |||
| size *= IntToSize(shape[j]); | |||
| } | |||
| size_t aligned_size = (nccl_kernel_type_ != NCCL_ALL_REDUCE) ? size : AlignMemorySize(size); | |||
| output_size_list_.push_back(aligned_size); | |||
| output_size_ += aligned_size; | |||
| } | |||
| group_name_ = GetAttr<std::string>(kernel_node, kAttrGroup); | |||
| MS_LOG(INFO) << AnfAlgo::GetCNodeName(kernel_node) << " for group " << group_name_; | |||
| auto comm_stream_attr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stream_id"); | |||
| if (comm_stream_attr) { | |||
| comm_stream_ = reinterpret_cast<cudaStream_t>(GetValue<uintptr_t>(comm_stream_attr)); | |||
| MS_EXCEPTION_IF_NULL(comm_stream_); | |||
| } | |||
| collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); | |||
| MS_EXCEPTION_IF_NULL(collective_handle_); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { return; } | |||
| private: | |||
| void InferCommType(const CNodePtr &kernel_node) { | |||
| std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| auto iter = kNcclTypeMap.find(kernel_name); | |||
| if (iter == kNcclTypeMap.end()) { | |||
| MS_LOG(EXCEPTION) << "Kernel " << kernel_name << " is not supported."; | |||
| } else { | |||
| nccl_kernel_type_ = iter->second; | |||
| } | |||
| auto reduce_op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kAttrOp); | |||
| if (reduce_op) { | |||
| std::string type = GetValue<std::string>(reduce_op); | |||
| if (type == "sum") { | |||
| nccl_reduce_type_ = ncclSum; | |||
| } else if (type == "max") { | |||
| nccl_reduce_type_ = ncclMax; | |||
| } else if (type == "min") { | |||
| nccl_reduce_type_ = ncclMin; | |||
| } else if (type == "prod") { | |||
| nccl_reduce_type_ = ncclProd; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Nccl reduce type " << type << " is not supported."; | |||
| } | |||
| } | |||
| auto root_rank = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kAttrRootRank); | |||
| if (root_rank) { | |||
| root_ = static_cast<int>(GetValue<int64_t>(root_rank)); | |||
| } | |||
| return; | |||
| } | |||
| size_t AlignMemorySize(size_t size) const { | |||
| if (size == 0) { | |||
| return COMMUNICATION_MEM_ALIGN_SIZE; | |||
| } | |||
| return ((size + COMMUNICATION_MEM_ALIGN_SIZE - 1) / COMMUNICATION_MEM_ALIGN_SIZE) * COMMUNICATION_MEM_ALIGN_SIZE; | |||
| } | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| NcclKernelType nccl_kernel_type_; | |||
| ncclRedOp_t nccl_reduce_type_; | |||
| size_t input_size_; | |||
| size_t output_size_; | |||
| int root_; | |||
| const void *collective_handle_; | |||
| cudaStream_t comm_stream_; | |||
| static const size_t COMMUNICATION_MEM_ALIGN_SIZE = 16; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_COLLECTIVE_GPU_KERNEL_H_ | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -18,11 +18,9 @@ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_GPU_KERNEL_H_ | |||
| #include <nccl.h> | |||
| #include <dlfcn.h> | |||
| #include <stdint.h> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/kernel_constants.h" | |||
| @@ -30,20 +28,6 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| enum NcclKernelType { | |||
| NCCL_ALL_REDUCE = 0, | |||
| NCCL_ALL_GATHER, | |||
| NCCL_REDUCE_SCATTER, | |||
| NCCL_BROADCAST, | |||
| NCCL_INVALID_TYPE = 255 | |||
| }; | |||
| const std::map<std::string, NcclKernelType> kNcclTypeMap = { | |||
| {"AllReduce", NCCL_ALL_REDUCE}, | |||
| {"AllGather", NCCL_ALL_GATHER}, | |||
| {"ReduceScatter", NCCL_REDUCE_SCATTER}, | |||
| {"Broadcast", NCCL_BROADCAST}, | |||
| }; | |||
| static std::map<std::string, ncclDataType_t> kNcclDtypeMap = { | |||
| {"kNumberTypeFloat32", ncclFloat}, {"kNumberTypeFloat16", ncclHalf}, {"kNumberTypeInt32", ncclInt}}; | |||
| @@ -53,174 +37,22 @@ typedef ncclResult_t (*AllGather)(const void *, void *, size_t, ncclDataType_t, | |||
| typedef ncclResult_t (*ReduceScatter)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t, | |||
| const std::string &); | |||
| typedef ncclResult_t (*Broadcast)(const void *, void *, size_t, ncclDataType_t, int, cudaStream_t, const std::string &); | |||
| typedef ncclResult_t (*Send)(const void *, size_t, ncclDataType_t, int, cudaStream_t, const std::string &); | |||
| typedef ncclResult_t (*Recv)(void *, size_t, ncclDataType_t, int, cudaStream_t, const std::string &); | |||
| typedef ncclResult_t (*GroupStart)(); | |||
| typedef ncclResult_t (*GroupEnd)(); | |||
| typedef std::vector<int> (*GetGroupRanks)(const std::string &); | |||
| template <typename T> | |||
| class NcclGpuKernel : public GpuKernel { | |||
| public: | |||
| NcclGpuKernel() | |||
| : nccl_kernel_type_(NCCL_INVALID_TYPE), | |||
| nccl_reduce_type_(ncclSum), | |||
| group_name_(""), | |||
| input_size_(0), | |||
| output_size_(0), | |||
| root_(0), | |||
| collective_handle_(nullptr), | |||
| comm_stream_(nullptr) {} | |||
| NcclGpuKernel() : group_name_(""), nccl_data_type_(ncclHalf) {} | |||
| ~NcclGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *input_addr = GetDeviceAddress<T>(inputs, 0); | |||
| T *output_addr = GetDeviceAddress<T>(outputs, 0); | |||
| cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr); | |||
| switch (nccl_kernel_type_) { | |||
| case NCCL_ALL_REDUCE: { | |||
| auto all_reduce_funcptr = | |||
| reinterpret_cast<AllReduce>(dlsym(const_cast<void *>(collective_handle_), "AllReduce")); | |||
| MS_EXCEPTION_IF_NULL(all_reduce_funcptr); | |||
| CHECK_NCCL_RET_WITH_EXCEPT((*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), | |||
| nccl_data_type_, nccl_reduce_type_, stream, group_name_), | |||
| "ncclAllReduce failed"); | |||
| break; | |||
| } | |||
| case NCCL_ALL_GATHER: { | |||
| auto all_gather_funcptr = | |||
| reinterpret_cast<AllGather>(dlsym(const_cast<void *>(collective_handle_), "AllGather")); | |||
| MS_EXCEPTION_IF_NULL(all_gather_funcptr); | |||
| CHECK_NCCL_RET_WITH_EXCEPT( | |||
| (*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream, group_name_), | |||
| "ncclAllGather failed"); | |||
| break; | |||
| } | |||
| case NCCL_REDUCE_SCATTER: { | |||
| auto reduce_scatter_funcptr = | |||
| reinterpret_cast<ReduceScatter>(dlsym(const_cast<void *>(collective_handle_), "ReduceScatter")); | |||
| MS_EXCEPTION_IF_NULL(reduce_scatter_funcptr); | |||
| CHECK_NCCL_RET_WITH_EXCEPT((*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), | |||
| nccl_data_type_, nccl_reduce_type_, stream, group_name_), | |||
| "ncclReduceScatter failed"); | |||
| break; | |||
| } | |||
| case NCCL_BROADCAST: { | |||
| auto broadcast_funcptr = | |||
| reinterpret_cast<Broadcast>(dlsym(const_cast<void *>(collective_handle_), "Broadcast")); | |||
| MS_EXCEPTION_IF_NULL(broadcast_funcptr); | |||
| for (int i = 0; i < SizeToInt(input_size_list_.size()); ++i) { | |||
| input_addr = GetDeviceAddress<T>(inputs, i); | |||
| output_addr = GetDeviceAddress<T>(outputs, i); | |||
| CHECK_NCCL_RET_WITH_EXCEPT((*broadcast_funcptr)(input_addr, output_addr, output_size_list_[i] / sizeof(T), | |||
| nccl_data_type_, root_, stream, group_name_), | |||
| "ncclBroadcast failed"); | |||
| } | |||
| break; | |||
| } | |||
| default: { | |||
| MS_LOG(EXCEPTION) << "Kernel type " << nccl_kernel_type_ << " is not supported."; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| nccl_data_type_ = kNcclDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||
| InferCommType(kernel_node); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); | |||
| size_t size = sizeof(T); | |||
| for (size_t j = 0; j < shape.size(); j++) { | |||
| size *= IntToSize(shape[j]); | |||
| } | |||
| size_t aligned_size = (nccl_kernel_type_ != NCCL_ALL_REDUCE) ? size : AlignMemorySize(size); | |||
| input_size_list_.push_back(aligned_size); | |||
| input_size_ += aligned_size; | |||
| } | |||
| for (size_t i = 0; i < output_num; ++i) { | |||
| auto shape = AnfAlgo::GetOutputInferShape(kernel_node, i); | |||
| size_t size = sizeof(T); | |||
| for (size_t j = 0; j < shape.size(); j++) { | |||
| size *= IntToSize(shape[j]); | |||
| } | |||
| size_t aligned_size = (nccl_kernel_type_ != NCCL_ALL_REDUCE) ? size : AlignMemorySize(size); | |||
| output_size_list_.push_back(aligned_size); | |||
| output_size_ += aligned_size; | |||
| } | |||
| group_name_ = GetAttr<std::string>(kernel_node, kAttrGroup); | |||
| MS_LOG(INFO) << AnfAlgo::GetCNodeName(kernel_node) << " for group " << group_name_; | |||
| auto comm_stream_attr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stream_id"); | |||
| if (comm_stream_attr) { | |||
| comm_stream_ = reinterpret_cast<cudaStream_t>(GetValue<uintptr_t>(comm_stream_attr)); | |||
| MS_EXCEPTION_IF_NULL(comm_stream_); | |||
| } | |||
| collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); | |||
| MS_EXCEPTION_IF_NULL(collective_handle_); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { return; } | |||
| private: | |||
| void InferCommType(const CNodePtr &kernel_node) { | |||
| std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| auto iter = kNcclTypeMap.find(kernel_name); | |||
| if (iter == kNcclTypeMap.end()) { | |||
| MS_LOG(EXCEPTION) << "Kernel " << kernel_name << " is not supported."; | |||
| } else { | |||
| nccl_kernel_type_ = iter->second; | |||
| } | |||
| auto reduce_op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kAttrOp); | |||
| if (reduce_op) { | |||
| std::string type = GetValue<std::string>(reduce_op); | |||
| if (type == "sum") { | |||
| nccl_reduce_type_ = ncclSum; | |||
| } else if (type == "max") { | |||
| nccl_reduce_type_ = ncclMax; | |||
| } else if (type == "min") { | |||
| nccl_reduce_type_ = ncclMin; | |||
| } else if (type == "prod") { | |||
| nccl_reduce_type_ = ncclProd; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Nccl reduce type " << type << " is not supported."; | |||
| } | |||
| } | |||
| auto root_rank = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kAttrRootRank); | |||
| if (root_rank) { | |||
| root_ = static_cast<int>(GetValue<int64_t>(root_rank)); | |||
| } | |||
| return; | |||
| } | |||
| ncclDataType_t nccl_dtype(const TypeId &type_id) { return kNcclDtypeMap[TypeIdLabel(type_id)]; } | |||
| size_t AlignMemorySize(size_t size) const { | |||
| if (size == 0) { | |||
| return COMMUNICATION_MEM_ALIGN_SIZE; | |||
| } | |||
| return ((size + COMMUNICATION_MEM_ALIGN_SIZE - 1) / COMMUNICATION_MEM_ALIGN_SIZE) * COMMUNICATION_MEM_ALIGN_SIZE; | |||
| } | |||
| NcclKernelType nccl_kernel_type_; | |||
| ncclRedOp_t nccl_reduce_type_; | |||
| ncclDataType_t nccl_data_type_; | |||
| std::string group_name_; | |||
| size_t input_size_; | |||
| size_t output_size_; | |||
| int root_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| const void *collective_handle_; | |||
| cudaStream_t comm_stream_; | |||
| static const size_t COMMUNICATION_MEM_ALIGN_SIZE = 16; | |||
| ncclDataType_t nccl_data_type_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,28 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nccl/nccl_recv_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(Receive, KernelAttr().AddAllSameAttr(true).AddOutputAttr(kNumberTypeFloat32), NcclRecvGpuKernel, | |||
| float); | |||
| MS_REG_GPU_KERNEL_ONE(Receive, KernelAttr().AddAllSameAttr(true).AddOutputAttr(kNumberTypeFloat16), NcclRecvGpuKernel, | |||
| half); | |||
| MS_REG_GPU_KERNEL_ONE(Receive, KernelAttr().AddAllSameAttr(true).AddOutputAttr(kNumberTypeInt32), NcclRecvGpuKernel, | |||
| int); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,88 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_RECV_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_RECV_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <functional> | |||
| #include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class NcclRecvGpuKernel : public NcclGpuKernel { | |||
| public: | |||
| NcclRecvGpuKernel() : src_rank_(-1), collective_handle_(nullptr) {} | |||
| ~NcclRecvGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &outputs, | |||
| void *stream_ptr) override { | |||
| T *output_addr = GetDeviceAddress<T>(outputs, 0); | |||
| auto nccl_recv_func = reinterpret_cast<Recv>(dlsym(const_cast<void *>(collective_handle_), "Recv")); | |||
| MS_EXCEPTION_IF_NULL(nccl_recv_func); | |||
| CHECK_NCCL_RET_WITH_EXCEPT((*nccl_recv_func)(output_addr, output_size_list_[0] / sizeof(T), nccl_data_type_, | |||
| src_rank_, reinterpret_cast<cudaStream_t>(stream_ptr), group_name_), | |||
| "ncclRecv failed"); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 0) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but NCCL receive needs 0 input."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but NCCL receive needs 1 output."; | |||
| return false; | |||
| } | |||
| src_rank_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "src_rank")); | |||
| group_name_ = GetAttr<std::string>(kernel_node, kAttrGroup); | |||
| nccl_data_type_ = nccl_dtype(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0)); | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| size_t output_size = | |||
| std::accumulate(output_shape.begin(), output_shape.end(), sizeof(T), std::multiplies<size_t>()); | |||
| output_size_list_.push_back(output_size); | |||
| MS_LOG(INFO) << "NcclRecv source rank is " << src_rank_ << ", group name is " << group_name_; | |||
| collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); | |||
| MS_EXCEPTION_IF_NULL(collective_handle_); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override {} | |||
| private: | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| int src_rank_; | |||
| const void *collective_handle_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_RECV_GPU_KERNEL_H_ | |||
| @@ -0,0 +1,31 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nccl/nccl_send_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Send, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| NcclSendGpuKernel, float); | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Send, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| NcclSendGpuKernel, half); | |||
| MS_REG_GPU_KERNEL_ONE(Send, | |||
| KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| NcclSendGpuKernel, int); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,84 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SEND_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SEND_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <functional> | |||
| #include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| // NCCL point-to-point send kernel for element type T. | |||
| // Init() reads the "dest_rank" and group attributes of the node and resolves the "Send" symbol from the | |||
| // collective library handle; Launch() then issues the send of input 0 on the supplied CUDA stream. | |||
| template <typename T> | |||
| class NcclSendGpuKernel : public NcclGpuKernel { | |||
|  public: | |||
|   NcclSendGpuKernel() : dest_rank_(-1), collective_handle_(nullptr), nccl_send_func_(nullptr) {} | |||
|   ~NcclSendGpuKernel() override = default; | |||
|   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
|   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
|   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
|   // Sends input 0 to dest_rank_ within group_name_. The workspace and output address lists are not used. | |||
|   // Always returns true; a failed NCCL call is reported through CHECK_NCCL_RET_WITH_EXCEPT. | |||
|   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
|               const std::vector<AddressPtr> &, void *stream_ptr) override { | |||
|     // The symbol is resolved in Init(); checking it first also guards the size-list access below. | |||
|     MS_EXCEPTION_IF_NULL(nccl_send_func_); | |||
|     T *input_addr = GetDeviceAddress<T>(inputs, 0); | |||
|     // input_size_list_ stores bytes, while the send call takes an element count. | |||
|     const size_t element_count = input_size_list_[0] / sizeof(T); | |||
|     CHECK_NCCL_RET_WITH_EXCEPT((*nccl_send_func_)(input_addr, element_count, nccl_data_type_, dest_rank_, | |||
|                                                   reinterpret_cast<cudaStream_t>(stream_ptr), group_name_), | |||
|                                "ncclSend failed"); | |||
|     return true; | |||
|   } | |||
|   // Validates that the node has exactly one input, caches its attributes and sizes, and resolves the | |||
|   // "Send" entry point. Returns false when the input count is wrong. | |||
|   bool Init(const CNodePtr &kernel_node) override { | |||
|     MS_EXCEPTION_IF_NULL(kernel_node); | |||
|     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
|     if (input_num != 1) { | |||
|       MS_LOG(ERROR) << "Input number is " << input_num << ", but NCCL send needs 1 input."; | |||
|       return false; | |||
|     } | |||
|     dest_rank_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "dest_rank")); | |||
|     group_name_ = GetAttr<std::string>(kernel_node, kAttrGroup); | |||
|     nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)); | |||
|     MS_LOG(INFO) << "NcclSend dest rank is " << dest_rank_ << ", group name is " << group_name_; | |||
|     // NOTE(review): the input byte size is derived from the inferred shape of the node's output 0; | |||
|     // confirm this always equals the shape of the tensor being sent. | |||
|     auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
|     // Seeding the product with sizeof(T) yields the size in bytes. | |||
|     size_t input_size = std::accumulate(input_shape.begin(), input_shape.end(), sizeof(T), std::multiplies<size_t>()); | |||
|     input_size_list_.push_back(input_size); | |||
|     // One output entry of size 0 is registered. | |||
|     output_size_list_.push_back(0); | |||
|     collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); | |||
|     MS_EXCEPTION_IF_NULL(collective_handle_); | |||
|     // Resolve the exported "Send" entry point once here instead of on every Launch(). | |||
|     nccl_send_func_ = reinterpret_cast<Send>(dlsym(const_cast<void *>(collective_handle_), "Send")); | |||
|     MS_EXCEPTION_IF_NULL(nccl_send_func_); | |||
|     return true; | |||
|   } | |||
|  protected: | |||
|   // Size lists are filled directly in Init(), so there is nothing to do here. | |||
|   void InitSizeLists() override {} | |||
|  private: | |||
|   std::vector<size_t> input_size_list_; | |||
|   std::vector<size_t> output_size_list_; | |||
|   std::vector<size_t> workspace_size_list_; | |||
|   // Peer rank read from the "dest_rank" attribute; -1 until Init() runs. | |||
|   int dest_rank_; | |||
|   // Handle of the collective library that exports "Send". | |||
|   const void *collective_handle_; | |||
|   // Cached result of the dlsym lookup performed in Init(). | |||
|   Send nccl_send_func_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SEND_GPU_KERNEL_H_ | |||
| @@ -207,7 +207,7 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod | |||
| auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; | |||
| Attr attr_rank = std::make_pair("dest_rank", MakeValue(dest_rank)); | |||
| OperatorAttrs attrs = {attr_tag, attr_rank}; | |||
| auto send_op = CreatOpInstance(attrs, "_Send", "send"); | |||
| auto send_op = CreatOpInstance(attrs, "Send", "send"); | |||
| auto send_node = NewValueNode(send_op); | |||
| auto prim = GetValueNode<PrimitivePtr>(send_node); | |||
| auto shape_type_pair = GetShapeType(parameter); | |||
| @@ -233,7 +233,7 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode | |||
| Attr attr_shape = std::make_pair("shape", shape_type_pair.first); | |||
| Attr attr_dtype = std::make_pair("dtype", shape_type_pair.second); | |||
| OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype}; | |||
| auto recv_op = CreatOpInstance(attrs, "_Receive", "recv"); | |||
| auto recv_op = CreatOpInstance(attrs, "Receive", "recv"); | |||
| std::vector<AnfNodePtr> recv_input = {NewValueNode(recv_op), virtual_param_}; | |||
| auto recv = graph->NewCNode(recv_input); | |||
| manager_->SetEdge(use_node, index, recv); | |||
| @@ -18,6 +18,7 @@ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_COLLECTIVE_COMMON_H_ | |||
| #include <nccl.h> | |||
| #include <vector> | |||
| #include <sstream> | |||
| #include "pybind11/pybind11.h" | |||
| @@ -31,6 +32,7 @@ struct NcclGroupInfo { | |||
| int rank; | |||
| ncclUniqueId unique_id; | |||
| ncclComm_t comm; | |||
| std::vector<int> group_ranks; | |||
| }; | |||
| #define CHECK_RET(expression, result, message) \ | |||
| { \ | |||
| @@ -53,3 +53,21 @@ ncclResult_t Broadcast(const void *input_addr, void *output_addr, size_t count, | |||
| cudaStream_t stream, const std::string &group) { | |||
| return NCCLWrapper::instance().Broadcast(input_addr, output_addr, count, data_type, root, stream, group); | |||
| } | |||
| // Forwards a point-to-point send request to the NCCLWrapper singleton and returns its result. | |||
| ncclResult_t Send(const void *send_addr, size_t count, ncclDataType_t data_type, int peer_rank, cudaStream_t stream, | |||
|                   const std::string &group_name) { | |||
|   auto &nccl_wrapper = NCCLWrapper::instance(); | |||
|   return nccl_wrapper.Send(send_addr, count, data_type, peer_rank, stream, group_name); | |||
| } | |||
| // Forwards a point-to-point receive request to the NCCLWrapper singleton and returns its result. | |||
| ncclResult_t Recv(void *recv_addr, size_t count, ncclDataType_t data_type, int peer_rank, cudaStream_t stream, | |||
|                   const std::string &group_name) { | |||
|   auto &nccl_wrapper = NCCLWrapper::instance(); | |||
|   return nccl_wrapper.Recv(recv_addr, count, data_type, peer_rank, stream, group_name); | |||
| } | |||
| // Forwards to NCCLWrapper::GroupStart() on the singleton. | |||
| ncclResult_t GroupStart() { | |||
|   auto &nccl_wrapper = NCCLWrapper::instance(); | |||
|   return nccl_wrapper.GroupStart(); | |||
| } | |||
| // Forwards to NCCLWrapper::GroupEnd() on the singleton. | |||
| ncclResult_t GroupEnd() { | |||
|   auto &nccl_wrapper = NCCLWrapper::instance(); | |||
|   return nccl_wrapper.GroupEnd(); | |||
| } | |||
| // Returns the rank list that the NCCLWrapper singleton holds for the named group. | |||
| std::vector<int> GetGroupRanks(const std::string &group_name) { | |||
|   auto &nccl_wrapper = NCCLWrapper::instance(); | |||
|   return nccl_wrapper.GetGroupRanks(group_name); | |||
| } | |||
| @@ -48,3 +48,10 @@ extern "C" EXPORT_WRAPPER ncclResult_t ReduceScatter(const void *input_addr, voi | |||
| extern "C" EXPORT_WRAPPER ncclResult_t Broadcast(const void *input_addr, void *output_addr, size_t count, | |||
| ncclDataType_t data_type, int root, cudaStream_t stream, | |||
| const std::string &group); | |||
| // Point-to-point send of `count` elements of `data_type` to `peer_rank` within `group_name`, exported with C linkage. | |||
| extern "C" EXPORT_WRAPPER ncclResult_t Send(const void *send_addr, size_t count, ncclDataType_t data_type, | |||
| int peer_rank, cudaStream_t stream, const std::string &group_name); | |||
| // Point-to-point receive of `count` elements of `data_type` from `peer_rank` within `group_name`. | |||
| extern "C" EXPORT_WRAPPER ncclResult_t Recv(void *recv_addr, size_t count, ncclDataType_t data_type, int peer_rank, | |||
| cudaStream_t stream, const std::string &group_name); | |||
| // Group-call start/end wrappers. | |||
| extern "C" EXPORT_WRAPPER ncclResult_t GroupStart(); | |||
| extern "C" EXPORT_WRAPPER ncclResult_t GroupEnd(); | |||
| // NOTE(review): this returns a C++ type (std::vector<int>) from an extern "C" function; some compilers warn | |||
| // about that combination - confirm it is intended. | |||
| extern "C" EXPORT_WRAPPER std::vector<int> GetGroupRanks(const std::string &group_name); | |||
| @@ -68,7 +68,7 @@ bool MPIWrapper::CreateCommGroup(const std::string &group_name, const std::vecto | |||
| return false; | |||
| } | |||
| NcclGroupInfo nccl_group = {static_cast<int>(ranks.size()), group_rank[0], group_unique_id, nullptr}; | |||
| NcclGroupInfo nccl_group = {static_cast<int>(ranks.size()), group_rank[0], group_unique_id, nullptr, ranks}; | |||
| NCCLWrapper::instance().AddGroupInfo(group_name, &nccl_group); | |||
| return true; | |||
| } | |||
| @@ -122,7 +122,11 @@ void MPIWrapper::Init() { | |||
| CHECK_RET(MPI_Bcast(reinterpret_cast<void *>(&unique_id), sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD), | |||
| MPI_SUCCESS, "Failed to broadcast nccl unique id."); | |||
| NcclGroupInfo world_group = {rank_size_, rank_id_, unique_id, nullptr}; | |||
| std::vector<int> world_group_ranks = {}; | |||
| for (int global_rank = 0; global_rank < rank_size_; global_rank++) { | |||
| world_group_ranks.push_back(global_rank); | |||
| } | |||
| NcclGroupInfo world_group = {rank_size_, rank_id_, unique_id, nullptr, world_group_ranks}; | |||
| NCCLWrapper::instance().AddGroupInfo(NCCL_WORLD_GROUP, &world_group); | |||
| return; | |||
| } | |||
| @@ -14,6 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include "runtime/device/gpu/distribution/nccl_wrapper.h" | |||
| namespace mindspore { | |||
| @@ -74,6 +75,24 @@ ncclResult_t NCCLWrapper::Broadcast(const void *input_addr, void *output_addr, s | |||
| return ncclBroadcast(input_addr, output_addr, count, data_type, root, group_comm, stream); | |||
| } | |||
| // Sends `count` elements from `send_addr` to `peer_rank` using the communicator registered for `group_name`. | |||
| // The group must already be present in group_info_; CHECK_RET reports the failure otherwise. | |||
| ncclResult_t NCCLWrapper::Send(const void *send_addr, size_t count, ncclDataType_t data_type, int peer_rank, | |||
|                                cudaStream_t stream, const std::string &group_name) { | |||
|   CHECK_RET(group_info_.count(group_name), 1, "Failed to find group info for Send by the group name " + group_name); | |||
|   const auto &target_group = group_info_[group_name]; | |||
|   return ncclSend(send_addr, count, data_type, peer_rank, target_group.comm, stream); | |||
| } | |||
| // Receives `count` elements into `recv_addr` from `peer_rank` using the communicator registered for `group_name`. | |||
| // The group must already be present in group_info_; CHECK_RET reports the failure otherwise. | |||
| ncclResult_t NCCLWrapper::Recv(void *recv_addr, size_t count, ncclDataType_t data_type, int peer_rank, | |||
|                                cudaStream_t stream, const std::string &group_name) { | |||
|   CHECK_RET(group_info_.count(group_name), 1, "Failed to find group info for Recv by the group name " + group_name); | |||
|   const auto &source_group = group_info_[group_name]; | |||
|   return ncclRecv(recv_addr, count, data_type, peer_rank, source_group.comm, stream); | |||
| } | |||
| // Thin wrapper over ncclGroupStart(); returns its status unchanged. | |||
| ncclResult_t NCCLWrapper::GroupStart() { | |||
|   const ncclResult_t status = ncclGroupStart(); | |||
|   return status; | |||
| } | |||
| // Thin wrapper over ncclGroupEnd(); returns its status unchanged. | |||
| ncclResult_t NCCLWrapper::GroupEnd() { | |||
|   const ncclResult_t status = ncclGroupEnd(); | |||
|   return status; | |||
| } | |||
| void NCCLWrapper::AddGroupInfo(const std::string &group_name, NcclGroupInfo *group) { | |||
| if (comm_init_done_) { | |||
| CHECK_RET(ncclCommInitRank(&(group->comm), group->size, group->unique_id, group->rank), ncclSuccess, | |||
| @@ -92,6 +111,12 @@ void NCCLWrapper::DestroyGroup(const std::string &group_name) { | |||
| group_info_.erase(group_iter); | |||
| return; | |||
| } | |||
| // Returns a copy of the rank list stored for `group_name`. | |||
| // The group must already be present in group_info_; CHECK_RET reports the failure otherwise. | |||
| std::vector<int> NCCLWrapper::GetGroupRanks(const std::string &group_name) { | |||
|   CHECK_RET(group_info_.count(group_name), 1, | |||
|             "Failed to find group info for GetGroupRanks by the group name " + group_name); | |||
|   const auto &queried_group = group_info_[group_name]; | |||
|   return queried_group.group_ranks; | |||
| } | |||
| } // namespace gpu | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -21,6 +21,7 @@ | |||
| #include <stdlib.h> | |||
| #include <nccl.h> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <map> | |||
| #include "runtime/device/gpu/distribution/collective_common.h" | |||
| @@ -34,16 +35,23 @@ class NCCLWrapper { | |||
| static NCCLWrapper &instance(); | |||
| ncclUniqueId nccl_unique_id() const; | |||
| void InitNCCLComm(); | |||
| ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, | |||
| ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP); | |||
| ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, | |||
| cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP); | |||
| ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, | |||
| ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP); | |||
| ncclResult_t Broadcast(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, int root, | |||
| cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP); | |||
| ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, | |||
| ncclRedOp_t op, cudaStream_t stream, const std::string &group_name); | |||
| ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, | |||
| cudaStream_t stream, const std::string &group_name); | |||
| ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, | |||
| ncclRedOp_t op, cudaStream_t stream, const std::string &group_name); | |||
| ncclResult_t Broadcast(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, int root, | |||
| cudaStream_t stream, const std::string &group_name); | |||
| ncclResult_t Send(const void *send_addr, size_t count, ncclDataType_t data_type, int peer_rank, cudaStream_t stream, | |||
| const std::string &group_name); | |||
| ncclResult_t Recv(void *recv_addr, size_t count, ncclDataType_t data_type, int peer_rank, cudaStream_t stream, | |||
| const std::string &group_name); | |||
| ncclResult_t GroupStart(); | |||
| ncclResult_t GroupEnd(); | |||
| void AddGroupInfo(const std::string &group_name, NcclGroupInfo *group); | |||
| void DestroyGroup(const std::string &group_name); | |||
| std::vector<int> GetGroupRanks(const std::string &group_name); | |||
| private: | |||
| NCCLWrapper() : comm_init_done_(false) {} | |||
| @@ -143,17 +143,17 @@ void InsertStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_ | |||
| size_t recv_node_offset = pair.recv_node_offset; | |||
| CNodePtr send_node = nullptr; | |||
| CNodePtr recv_node = nullptr; | |||
| // Step 1: generate Send and Recv CNodes. | |||
| // Step 1: Generate stream Send and Recv CNodes. | |||
| if (stream_switch_type == kAllReduceStreamSwitch) { | |||
| if (!GenSendRecvCNodesForAllReduce(kernel_graph, mock_send_node, mock_recv_node, &send_node, &recv_node)) { | |||
| MS_LOG(EXCEPTION) << "Generating CNodes for send and recv failed. Stream switch type: kAllReduceStreamSwitch"; | |||
| } | |||
| } | |||
| // Step 2: sort send and recv CNodes by offset. | |||
| // Step 2: Sort send and recv CNodes by offset. | |||
| ordered_stream_switch_nodes.insert({send_node_offset, send_node}); | |||
| ordered_stream_switch_nodes.insert({recv_node_offset, recv_node}); | |||
| } | |||
| // Step 3: insert stream switch CNodes into execution kernel list. | |||
| // Step 3: Insert stream switch CNodes into execution kernel list. | |||
| auto execution_kernels = kernel_graph->execution_order(); | |||
| for (auto node = ordered_stream_switch_nodes.rbegin(); node != ordered_stream_switch_nodes.rend(); node++) { | |||
| execution_kernels.insert(execution_kernels.begin() + node->offset, node->cnode); | |||
| @@ -185,7 +185,7 @@ inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>("SGD"); | |||
| inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | |||
| inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | |||
| inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | |||
| inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("_Receive"); | |||
| inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("Receive"); | |||
| inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); | |||
| inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap"); | |||
| inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast"); | |||
| @@ -20,7 +20,7 @@ from .. import operations as P | |||
| from ...common.tensor import RowTensor | |||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | |||
| from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, | |||
| _GetTensorSlice, _MirrorOperator, ReduceOp, _Send, _Receive, | |||
| _GetTensorSlice, _MirrorOperator, ReduceOp, Send, Receive, | |||
| ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap) | |||
| from .grad_base import bprop_getters | |||
| @@ -77,12 +77,12 @@ def get_bprop_all_reduce(self): | |||
| return bprop | |||
| @bprop_getters.register(_Send) | |||
| @bprop_getters.register(Send) | |||
| def get_bprop_send(self): | |||
| """Generate bprop for Send.""" | |||
| shape = self.get_attr_dict()["shape"] | |||
| dtype = self.get_attr_dict()["dtype"] | |||
| send_grad = _Receive(self.sr_tag, self.rank, shape, dtype, self.group) | |||
| send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group) | |||
| def bprop(x, out, dout): | |||
| dx = send_grad() | |||
| @@ -90,10 +90,10 @@ def get_bprop_send(self): | |||
| return bprop | |||
| @bprop_getters.register(_Receive) | |||
| @bprop_getters.register(Receive) | |||
| def get_bprop_receive(self): | |||
| """Generate bprop for Receive.""" | |||
| receive_grad = _Send(self.tag, self.rank, self.group) | |||
| receive_grad = Send(self.tag, self.rank, self.group) | |||
| depend = P.Depend() | |||
| cast = P.Cast() | |||
| @@ -36,7 +36,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||
| Unique, GatherD, Identity, RepeatElements) | |||
| from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, | |||
| _MirrorOperator, ReduceOp, _VirtualDataset, | |||
| _VirtualDiv, _GetTensorSlice, _Send, _Receive, | |||
| _VirtualDiv, _GetTensorSlice, Send, Receive, | |||
| _HostAllGather, _HostReduceScatter) | |||
| from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, | |||
| TensorSummary, HistogramSummary, Print, Assert) | |||
| @@ -116,7 +116,7 @@ class AllReduce(PrimitiveWithInfer): | |||
| return x_dtype | |||
| class _Send(PrimitiveWithInfer): | |||
| class Send(PrimitiveWithInfer): | |||
| """ | |||
| Send tensors from src_rank to the specified dest_rank. | |||
| @@ -145,7 +145,7 @@ class _Send(PrimitiveWithInfer): | |||
| >>> def __init__(self): | |||
| >>> super(Net, self).__init__() | |||
| >>> self.depend = P.Depend() | |||
| >>> self.send = P._Send(st_tag=0, dest_rank=8, group="hccl_world_group") | |||
| >>> self.send = P.Send(st_tag=0, dest_rank=8, group="hccl_world_group") | |||
| >>> | |||
| >>> def construct(self, x): | |||
| >>> out = self.depend(x, self.send(x)) | |||
| @@ -170,7 +170,7 @@ class _Send(PrimitiveWithInfer): | |||
| return x_dtype | |||
| class _Receive(PrimitiveWithInfer): | |||
| class Receive(PrimitiveWithInfer): | |||
| """ | |||
| receive tensors from src_rank. | |||
| @@ -201,7 +201,7 @@ class _Receive(PrimitiveWithInfer): | |||
| >>> class Net(nn.Cell): | |||
| >>> def __init__(self): | |||
| >>> super(Net, self).__init__() | |||
| >>> self.recv = P._Receive(st_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32, | |||
| >>> self.recv = P.Receive(st_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32, | |||
| >>> group="hccl_world_group") | |||
| >>> | |||
| >>> def construct(self, x): | |||
| @@ -53,3 +53,10 @@ def test_nccl_reduce_scatter_op(): | |||
| def test_nccl_broadcast_op(): | |||
| return_code = os.system("mpirun -n 8 pytest -s test_nccl_broadcast_op.py") | |||
| assert return_code == 0 | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_single | |||
| def test_nccl_send_recv_op(): | |||
| return_code = os.system("mpirun -n 8 pytest -s test_nccl_send_recv_op.py") | |||
| assert return_code == 0 | |||
| @@ -48,7 +48,7 @@ def test_AllGather(): | |||
| for i in range(size - 1): | |||
| tmp = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * (i + 2) | |||
| expect = np.concatenate((expect, tmp)) | |||
| diff = output.asnumpy() - expect | |||
| diff = np.absolute(output.asnumpy() - expect) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| assert np.all(diff < error) | |||
| assert output.shape == expect.shape | |||
| @@ -0,0 +1,69 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common import dtype as mstype | |||
| # Run in graph mode with the GPU device target. | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| # init() comes from mindspore.communication.management; rank and group size are queried after it. | |||
| init() | |||
| rank = get_rank() | |||
| size = get_group_size() | |||
| # SendNet targets rank + size // 2 and RecvNet reads from rank - size // 2, so an even group size is required. | |||
| if size % 2 != 0: | |||
| raise RuntimeError("Group size should be divided by 2 exactly.") | |||
| # Per-rank payload: a float32 [3, 3, 3, 3] array of ones scaled by 0.01 and then by (rank + 1). | |||
| x = np.ones([3, 3, 3, 3]).astype(np.float32) * 0.01 * (rank + 1) | |||
| class SendNet(nn.Cell): | |||
| def __init__(self): | |||
| super(SendNet, self).__init__() | |||
| self.x = Parameter(initializer(Tensor(x), x.shape), name='x') | |||
| self.depend = P.Depend() | |||
| self.send = P.Send(sr_tag=0, dest_rank=rank+size//2, group=NCCL_WORLD_COMM_GROUP) | |||
| def construct(self): | |||
| out = self.depend(self.x, self.send(self.x)) | |||
| return out | |||
| class RecvNet(nn.Cell): | |||
| def __init__(self): | |||
| super(RecvNet, self).__init__() | |||
| self.recv = P.Receive(sr_tag=0, src_rank=rank-size//2, shape=[3, 3, 3, 3], dtype=mstype.float32, | |||
| group=NCCL_WORLD_COMM_GROUP) | |||
| def construct(self): | |||
| out = self.recv() | |||
| return out | |||
| def test_send_recv(): | |||
| if rank < size / 2: | |||
| send_net = SendNet() | |||
| output = send_net() | |||
| else: | |||
| expect_output = np.ones([3, 3, 3, 3]).astype(np.float32) * 0.01 * (rank-size//2 + 1) | |||
| recv_net = RecvNet() | |||
| output = recv_net() | |||
| diff = abs(output.asnumpy() - expect_output) | |||
| error = np.ones(shape=output.shape) * 1.0e-5 | |||
| assert np.all(diff < error) | |||
| assert expect_output.shape == output.shape | |||