From 3ea3d9e5a433e47196b9dc087bd00a50bdcf1163 Mon Sep 17 00:00:00 2001
From: ZPaC
Date: Mon, 13 Apr 2020 21:04:33 +0800
Subject: [PATCH] 1. GPU supports multiple streams. 2. GPU communication stream and compute stream overlap.

---
 .../ccsrc/device/gpu/gpu_device_manager.cc    |  20 +-
 .../ccsrc/device/gpu/gpu_device_manager.h     |  16 +-
 .../ccsrc/device/gpu/gpu_stream_assign.cc     | 181 ++++++++++++++++++
 .../ccsrc/device/gpu/gpu_stream_assign.h      |  73 +++++++
 .../ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.h   |  25 +--
 mindspore/ccsrc/session/gpu_session.cc        |   8 +
 mindspore/ccsrc/session/gpu_session.h         |   2 +
 mindspore/ccsrc/utils/utils.h                 |   2 +
 8 files changed, 304 insertions(+), 23 deletions(-)
 create mode 100644 mindspore/ccsrc/device/gpu/gpu_stream_assign.cc
 create mode 100644 mindspore/ccsrc/device/gpu/gpu_stream_assign.h

diff --git a/mindspore/ccsrc/device/gpu/gpu_device_manager.cc b/mindspore/ccsrc/device/gpu/gpu_device_manager.cc
index 59c8fde5a2..b25ba2906b 100644
--- a/mindspore/ccsrc/device/gpu/gpu_device_manager.cc
+++ b/mindspore/ccsrc/device/gpu/gpu_device_manager.cc
@@ -25,7 +25,7 @@ namespace device {
 namespace gpu {
 void GPUDeviceManager::InitDevice() {
   CHECK_OP_RET_WITH_EXCEPT(CudaDriver::set_current_device(SizeToInt(cur_dev_id_)), "Failed to set current device id");
-  CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(&stream_), "Failed to create CUDA stream.");
+  CHECK_OP_RET_WITH_EXCEPT(CreateStream(&default_stream_), "Failed to create CUDA stream.");
   CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreate(&cudnn_handle_), "Failed to create cuDNN handle");
   CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetStream(cudnn_handle_, reinterpret_cast<cudaStream_t>(default_stream())),
                               "Failed to set stream for cuDNN handle.");
@@ -36,19 +36,27 @@ void GPUDeviceManager::InitDevice() {
 }
 
 void GPUDeviceManager::ReleaseDevice() {
-  if (stream_ != nullptr) {
-    CHECK_OP_RET_WITH_ERROR(CudaDriver::DestroyStream(stream_), "Failed to destroy cuda stream.");
+  for (DeviceStream stream : gpu_streams_) {
+    if (stream != nullptr) {
+      CHECK_OP_RET_WITH_ERROR(CudaDriver::DestroyStream(stream), "Failed to destroy CUDA stream.");
+    }
   }
   if (cudnn_handle_ != nullptr) {
-    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroy(cudnn_handle_), "Failed to destroy cudnn handle");
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroy(cudnn_handle_), "Failed to destroy cuDNN handle");
   }
   if (cublas_handle_ != nullptr) {
-    CHECK_CUBLAS_RET_WITH_ERROR(cublasDestroy(cublas_handle_), "Failed to destroy cublas handle.");
+    CHECK_CUBLAS_RET_WITH_ERROR(cublasDestroy(cublas_handle_), "Failed to destroy cuBLAS handle.");
   }
   CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator");
 }
 
-const DeviceStream& GPUDeviceManager::default_stream() const { return stream_; }
+bool GPUDeviceManager::CreateStream(DeviceStream* stream) {
+  CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream");
+  gpu_streams_.emplace_back(*stream);
+  return true;
+}
+
+const DeviceStream& GPUDeviceManager::default_stream() const { return default_stream_; }
 
 int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); }
diff --git a/mindspore/ccsrc/device/gpu/gpu_device_manager.h b/mindspore/ccsrc/device/gpu/gpu_device_manager.h
index 6bfaf85673..3b3d2aecb5 100644
--- a/mindspore/ccsrc/device/gpu/gpu_device_manager.h
+++ b/mindspore/ccsrc/device/gpu/gpu_device_manager.h
@@ -19,6 +19,7 @@
 #include <cudnn.h>
 #include <cublas_v2.h>
+#include <vector>
 #include <memory>
 #include "device/gpu/cuda_driver.h"
 #include "device/gpu/gpu_memory_allocator.h"
@@ -36,13 +37,15 @@ class GPUDeviceManager {
   uint32_t cur_device_id() const;
   bool is_device_id_init() const;
 
+  bool CreateStream(DeviceStream* stream);
+  bool SyncStream(const DeviceStream& stream) const;
   const DeviceStream& default_stream() const;
+
   const cudnnHandle_t& GetCudnnHandle() const;
   const cublasHandle_t& GetCublasHandle() const;
 
   bool CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const;
   bool CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const;
-  bool SyncStream(const DeviceStream& stream) const;
 
   static GPUDeviceManager& GetInstance() {
     static GPUDeviceManager instance;
@@ -55,13 +58,16 @@ class GPUDeviceManager {
   GPUDeviceManager(const GPUDeviceManager&) = delete;
   GPUDeviceManager& operator=(const GPUDeviceManager&) = delete;
 
-  // default cuda stream used for all the kernels.
-  DeviceStream stream_{nullptr};
+  // default CUDA stream used for all the kernels.
+  DeviceStream default_stream_{nullptr};
+
+  // all GPU CUDA streams, including default_stream_.
+  std::vector<DeviceStream> gpu_streams_;
 
-  // handle used for cudnn kernels.
+  // handle used for cuDNN kernels.
   cudnnHandle_t cudnn_handle_{nullptr};
 
-  // handle used for cublas kernels.
+  // handle used for cuBLAS kernels.
   cublasHandle_t cublas_handle_{nullptr};
 
   bool dev_id_init_;
diff --git a/mindspore/ccsrc/device/gpu/gpu_stream_assign.cc b/mindspore/ccsrc/device/gpu/gpu_stream_assign.cc
new file mode 100644
index 0000000000..39d5ca3fe6
--- /dev/null
+++ b/mindspore/ccsrc/device/gpu/gpu_stream_assign.cc
@@ -0,0 +1,181 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <set>
+#include <string>
+#include <vector>
+#include "device/gpu/gpu_common.h"
+#include "device/gpu/kernel_info_setter.h"
+#include "device/gpu/gpu_device_manager.h"
+#include "device/gpu/gpu_stream_assign.h"
+
+namespace mindspore {
+namespace device {
+namespace gpu {
+void AssignGpuStream(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  std::vector<CNodePtr> allreduce_cnodes;
+  auto execution_kernels = kernel_graph->execution_order();
+  for (auto kernel : execution_kernels) {
+    std::string kernel_name = AnfAlgo::GetCNodeName(kernel);
+    if (kernel_name == kAllReduceOpName) {
+      allreduce_cnodes.emplace_back(kernel);
+    }
+  }
+  if (allreduce_cnodes.size() > 1) {
+    DeviceStream comm_stream = nullptr;
+    GPUDeviceManager::GetInstance().CreateStream(&comm_stream);
+    std::transform(allreduce_cnodes.begin(), allreduce_cnodes.end(), allreduce_cnodes.begin(), [&](CNodePtr node) {
+      AnfAlgo::SetNodeAttr("stream_id", MakeValue(reinterpret_cast<uintptr_t>(comm_stream)), node);
+      return node;
+    });
+
+    std::vector<SendRecvPair> send_recv_pairs;
+    FindAllReduceStreamSwitchPos(kernel_graph, &send_recv_pairs);
+    InsertStreamSwitchNode(kernel_graph, send_recv_pairs);
+  }
+}
+
+void FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &kernel_graph,
+                                  std::vector<SendRecvPair> *send_recv_pairs) {
+  auto execution_kernels = kernel_graph->execution_order();
+  std::vector<CNodePtr>::iterator iter, iter_begin;
+  iter = iter_begin = execution_kernels.begin();
+  std::vector<CNodePtr>::iterator iter_end = execution_kernels.end();
+  for (; iter != execution_kernels.end(); ++iter) {
+    std::string kernel_name = AnfAlgo::GetCNodeName(*iter);
+    if (kernel_name == kAllReduceOpName) {
+      // Find the AllReduce node's last input node.
+      std::vector<CNodePtr>::iterator mock_send_node_iter =
+        FindSendNodePos(iter_begin, iter + 1, *iter, kAllReduceStreamSwitch);
+      if (mock_send_node_iter == iter + 1) {
+        MS_LOG(WARNING) << "Can't find send node place before AllReduce node.";
+        continue;
+      }
+      SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter,
+                            IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)};
+      send_recv_pairs->push_back(pair1);
+      // Find the first node that takes the AllReduce node as an input.
+      std::vector<CNodePtr>::iterator mock_recv_node_iter =
+        FindRecvNodePos(iter, iter_end, *iter, kAllReduceStreamSwitch);
+      if (mock_recv_node_iter == iter_end) {
+        MS_LOG(WARNING) << "Can't find recv node place after AllReduce node.";
+        continue;
+      }
+      SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1),
+                            IntToSize(mock_recv_node_iter - iter_begin)};
+      send_recv_pairs->push_back(pair2);
+    }
+  }
+}
+
+std::vector<CNodePtr>::iterator FindSendNodePos(std::vector<CNodePtr>::iterator begin,
+                                                std::vector<CNodePtr>::iterator end, const CNodePtr mock_recv_node,
+                                                StreamSwitchType stream_switch_type) {
+  MS_EXCEPTION_IF_NULL(mock_recv_node);
+  if (stream_switch_type == kAllReduceStreamSwitch) {
+    // Stop at end - 1 so that the peek at *(iter + 1) never reads past the range.
+    for (auto iter = begin; iter + 1 != end; iter++) {
+      if (*(iter + 1) == mock_recv_node) {
+        return iter;
+      }
+    }
+  }
+  return end;
+}
+
+std::vector<CNodePtr>::iterator FindRecvNodePos(std::vector<CNodePtr>::iterator begin,
+                                                std::vector<CNodePtr>::iterator end, const CNodePtr mock_send_node,
+                                                StreamSwitchType stream_switch_type) {
+  MS_EXCEPTION_IF_NULL(mock_send_node);
+  for (auto iter = begin; iter != end; iter++) {
+    auto node = *iter;
+    if (stream_switch_type == kAllReduceStreamSwitch) {
+      for (auto input : node->inputs()) {
+        if (mock_send_node == AnfAlgo::VisitKernel(input, 0).first) {
+          return iter;
+        }
+      }
+    }
+  }
+  return end;
+}
+
+void InsertStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_graph,
+                            const std::vector<SendRecvPair> &send_recv_pairs) {
+  std::set<StreamSwitchNode> ordered_stream_switch_nodes;
+  for (SendRecvPair pair : send_recv_pairs) {
+    StreamSwitchType stream_switch_type = pair.stream_switch_type;
+    CNodePtr mock_send_node = pair.mock_send_node;
+    CNodePtr mock_recv_node = pair.mock_recv_node;
+    size_t send_node_offset = pair.send_node_offset;
+    size_t recv_node_offset = pair.recv_node_offset;
+    CNodePtr send_node = nullptr;
+    CNodePtr recv_node = nullptr;
+    // Step 1: generate Send and Recv CNodes.
+    if (stream_switch_type == kAllReduceStreamSwitch) {
+      if (!GenSendRecvCNodesForAllReduce(kernel_graph, mock_send_node, mock_recv_node, &send_node, &recv_node)) {
+        MS_LOG(EXCEPTION) << "Generating CNodes for Send and Recv failed. Stream switch type: kAllReduceStreamSwitch";
+      }
+    }
+    // Step 2: sort the Send and Recv CNodes by offset.
+    ordered_stream_switch_nodes.insert({send_node_offset, send_node});
+    ordered_stream_switch_nodes.insert({recv_node_offset, recv_node});
+  }
+  // Step 3: insert the stream switch CNodes into the execution kernel list.
+  auto execution_kernels = kernel_graph->execution_order();
+  for (auto node = ordered_stream_switch_nodes.begin(); node != ordered_stream_switch_nodes.end(); node++) {
+    execution_kernels.insert(execution_kernels.begin() + node->offset, node->cnode);
+  }
+  kernel_graph->set_execution_order(execution_kernels);
+}
+
+bool GenSendRecvCNodesForAllReduce(const std::shared_ptr<session::KernelGraph> &kernel_graph,
+                                   const CNodePtr &mock_send_node, const CNodePtr &mock_recv_node, CNodePtr *send_node,
+                                   CNodePtr *recv_node) {
+  *send_node = CreateStreamSwitchNode(kernel_graph, kSendOpName);
+  MS_EXCEPTION_IF_NULL(*send_node);
+  *recv_node = CreateStreamSwitchNode(kernel_graph, kRecvOpName);
+  MS_EXCEPTION_IF_NULL(*recv_node);
+
+  cudaEvent_t event = nullptr;
+  CHECK_CUDA_RET_WITH_EXCEPT(cudaEventCreate(&event, cudaEventDisableTiming), "Creating CUDA event failed.");
+  AnfAlgo::SetNodeAttr("record_event", MakeValue(reinterpret_cast<uintptr_t>(event)), *send_node);
+  AnfAlgo::SetNodeAttr("wait_event", MakeValue(reinterpret_cast<uintptr_t>(event)), *recv_node);
+
+  uintptr_t send_stream = AnfAlgo::GetNodeAttr<uintptr_t>(mock_send_node, "stream_id");
+  AnfAlgo::SetNodeAttr("record_event_stream", MakeValue(send_stream), *send_node);
+  uintptr_t recv_stream = AnfAlgo::GetNodeAttr<uintptr_t>(mock_recv_node, "stream_id");
+  AnfAlgo::SetNodeAttr("wait_event_stream", MakeValue(recv_stream), *recv_node);
+  return true;
+}
+
+CNodePtr CreateStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &name) {
+  auto op = std::make_shared<Primitive>(name);
+  auto apply = std::make_shared<ValueNode>(op);
+  std::vector<AnfNodePtr> input_list = {apply};
+  CNodePtr node = kernel_graph->NewCNode(input_list);
+  MS_EXCEPTION_IF_NULL(node);
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
+  AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), node.get());
+  auto abstract_none = std::make_shared<abstract::AbstractNone>();
+  node->set_abstract(abstract_none);
+  SetKernelInfo(node);
+  return node;
+}
+}  // namespace gpu
+}  // namespace device
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/device/gpu/gpu_stream_assign.h b/mindspore/ccsrc/device/gpu/gpu_stream_assign.h
new file mode 100644
index 0000000000..c7d2fe40e2
--- /dev/null
+++ b/mindspore/ccsrc/device/gpu/gpu_stream_assign.h
@@ -0,0 +1,73 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_
+#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_
+
+#include <vector>
+#include <string>
+#include <memory>
+#include "session/kernel_graph.h"
+#include "session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace device {
+namespace gpu {
+enum StreamSwitchType { kAllReduceStreamSwitch, kStreamSwitchInvalidType = 255 };
+struct SendRecvPair {
+  StreamSwitchType stream_switch_type;
+  CNodePtr mock_send_node;
+  CNodePtr mock_recv_node;
+  size_t send_node_offset;
+  size_t recv_node_offset;
+};
+struct StreamSwitchNode {
+  size_t offset;
+  CNodePtr cnode;
+  bool operator<(const StreamSwitchNode &n) const {
+    if (offset < n.offset) {
+      return true;
+    } else if (offset == n.offset) {
+      return AnfAlgo::GetCNodeName(cnode) == kSendOpName;
+    } else {
+      return false;
+    }
+  }
+};
+void AssignGpuStream(const std::shared_ptr<session::KernelGraph> &kernel_graph);
+void FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &kernel_graph,
+                                  std::vector<SendRecvPair> *send_recv_pairs);
+// Find the Send node position according to the "mock" recv node.
+// A "mock" recv node is a GPU kernel node that comes right after a real Recv node, e.g. an AllReduce node.
+std::vector<CNodePtr>::iterator FindSendNodePos(std::vector<CNodePtr>::iterator begin,
+                                                std::vector<CNodePtr>::iterator end, const CNodePtr mock_recv_node,
+                                                StreamSwitchType stream_switch_type);
+// Find the Recv node position according to the "mock" send node.
+// A "mock" send node is a GPU kernel node that comes right before a real Send node, e.g. an AllReduce node.
+std::vector<CNodePtr>::iterator FindRecvNodePos(std::vector<CNodePtr>::iterator begin,
+                                                std::vector<CNodePtr>::iterator end, const CNodePtr mock_send_node,
+                                                StreamSwitchType stream_switch_type);
+void InsertStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_graph,
+                            const std::vector<SendRecvPair> &send_recv_pairs);
+bool GenSendRecvCNodesForAllReduce(const std::shared_ptr<session::KernelGraph> &kernel_graph,
+                                   const CNodePtr &mock_send_node, const CNodePtr &mock_recv_node, CNodePtr *send_node,
+                                   CNodePtr *recv_node);
+CNodePtr CreateStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &name);
+}  // namespace gpu
+}  // namespace device
+}  // namespace mindspore
+
+#endif
diff --git a/mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.h
index 54e4eb9213..cea56b9878 100644
--- a/mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.h
@@ -52,7 +52,8 @@ class NcclGpuKernel : public GpuKernel {
         nccl_reduce_type_(ncclSum),
         input_size_(0),
         output_size_(0),
-        collective_handle_(nullptr) {}
+        collective_handle_(nullptr),
+        comm_stream_(nullptr) {}
   ~NcclGpuKernel() override = default;
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -63,34 +64,33 @@ class NcclGpuKernel : public GpuKernel {
     T *input_addr = GetDeviceAddress<T>(inputs, 0);
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
 
+    cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
     switch (nccl_kernel_type_) {
       case NCCL_ALL_REDUCE: {
         auto all_reduce_funcptr =
           reinterpret_cast<AllReduce>(dlsym(const_cast<void *>(collective_handle_), "AllReduce"));
         MS_EXCEPTION_IF_NULL(all_reduce_funcptr);
-        CHECK_NCCL_RET_WITH_EXCEPT(
-          (*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), nccl_data_type_, nccl_reduce_type_,
-                                reinterpret_cast<cudaStream_t>(stream_ptr)),
-          "ncclAllReduce failed");
+        CHECK_NCCL_RET_WITH_EXCEPT((*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T),
+                                                         nccl_data_type_, nccl_reduce_type_, stream),
+                                   "ncclAllReduce failed");
         break;
       }
       case NCCL_ALL_GATHER: {
         auto all_gather_funcptr =
           reinterpret_cast<AllGather>(dlsym(const_cast<void *>(collective_handle_), "AllGather"));
         MS_EXCEPTION_IF_NULL(all_gather_funcptr);
-        CHECK_NCCL_RET_WITH_EXCEPT((*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T),
-                                                         nccl_data_type_, reinterpret_cast<cudaStream_t>(stream_ptr)),
-                                   "ncclAllGather failed");
+        CHECK_NCCL_RET_WITH_EXCEPT(
+          (*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream),
+          "ncclAllGather failed");
         break;
       }
       case NCCL_REDUCE_SCATTER: {
         auto reduce_scatter_funcptr =
           reinterpret_cast<ReduceScatter>(dlsym(const_cast<void *>(collective_handle_), "ReduceScatter"));
         MS_EXCEPTION_IF_NULL(reduce_scatter_funcptr);
-        CHECK_NCCL_RET_WITH_EXCEPT(
-          (*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), nccl_data_type_,
-                                    nccl_reduce_type_, reinterpret_cast<cudaStream_t>(stream_ptr)),
-          "ncclReduceScatter failed");
+        CHECK_NCCL_RET_WITH_EXCEPT((*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T),
+                                                             nccl_data_type_, nccl_reduce_type_, stream),
+                                   "ncclReduceScatter failed");
         break;
       }
       default: {
@@ -167,6 +167,7 @@ class NcclGpuKernel : public GpuKernel {
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
   const void *collective_handle_;
+  cudaStream_t comm_stream_;
 };
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/session/gpu_session.cc b/mindspore/ccsrc/session/gpu_session.cc
index bbcf2228cc..c0b2323e04 100644
--- a/mindspore/ccsrc/session/gpu_session.cc
+++ b/mindspore/ccsrc/session/gpu_session.cc
@@ -17,6 +17,7 @@
 #include "device/gpu/kernel_info_setter.h"
 #include "device/gpu/gpu_kernel_build.h"
 #include "device/gpu/gpu_kernel_runtime.h"
+#include "device/gpu/gpu_stream_assign.h"
 #include "pre_activate/common/optimizer.h"
 #include "pre_activate/common/pass_manager.h"
 #include "pre_activate/common/ir_fusion/allreduce_fusion.h"
@@ -55,6 +56,11 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   kernel_graph->SetExecOrderByDefault();
 }
 
+void GPUSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) {
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  device::gpu::AssignGpuStream(kernel_graph);
+}
+
 void GPUSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
   device::gpu::GpuBuild(kernel_graph);
 }
@@ -94,6 +100,8 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
   StartKernelRT();
   // AllReduce Optimize
   Optimize(graph);
+  // Assign CUDA streams
+  AssignStream(graph);
   // Build kernel if node is cnode
   BuildKernel(graph);
   // Set graph execution order before memory alloc, ensure that memory alloc is according to the reorder graph
diff --git a/mindspore/ccsrc/session/gpu_session.h b/mindspore/ccsrc/session/gpu_session.h
index e443c1e701..d81a6c58f9 100644
--- a/mindspore/ccsrc/session/gpu_session.h
+++ b/mindspore/ccsrc/session/gpu_session.h
@@ -49,6 +49,8 @@ class GPUSession : public SessionBasic {
   void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph);
 
+  void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph);
+
   void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
 
   void AllocateMemory(KernelGraph *kernel_graph) const;
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 39b4b7a160..646fb36871 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -112,6 +112,8 @@ constexpr auto kFusedMulAddNOpName = "FusedMulAddN";
 constexpr auto kFusedMulApplyMomentumOpName = "FusedMulApplyMomentum";
 constexpr auto kBiasAddOpName = "BiasAdd";
 constexpr auto kConfusionMulGradOpName = "ConfusionMulGrad";
+constexpr auto kSendOpName = "Send";
+constexpr auto kRecvOpName = "Recv";
 
 // attr key name
 constexpr auto kAttrInputNames = "input_names";
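
Note (reviewer, not part of the patch): this change only attaches the overlap plan to the graph as node attributes — the Send node carries "record_event"/"record_event_stream" and the Recv node carries "wait_event"/"wait_event_stream" — while the GPU kernels that actually consume those attributes land in a separate change. As a minimal standalone sketch of the intended runtime behavior, the CUDA program below uses the same record/wait-event pattern to overlap a mock AllReduce with independent compute work. All names in it are illustrative assumptions, not MindSpore identifiers, and the MockAllReduce kernel merely stands in for the ncclAllReduce call issued on the communication stream.

#include <cstdio>
#include <cuda_runtime.h>

__global__ void Compute(float *data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] += 1.0f;
}

// Stand-in for the ncclAllReduce call issued on the communication stream.
__global__ void MockAllReduce(float *grad, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) grad[i] *= 0.5f;
}

int main() {
  const int n = 1 << 20, threads = 256, blocks = (n + threads - 1) / threads;
  float *grad = nullptr, *next = nullptr;
  cudaMalloc(&grad, n * sizeof(float));
  cudaMalloc(&next, n * sizeof(float));

  cudaStream_t compute_stream, comm_stream;
  cudaStreamCreate(&compute_stream);
  cudaStreamCreate(&comm_stream);

  // Events created as in the patch: synchronization only, no timing.
  cudaEvent_t send_event, recv_event;
  cudaEventCreateWithFlags(&send_event, cudaEventDisableTiming);
  cudaEventCreateWithFlags(&recv_event, cudaEventDisableTiming);

  // Produce gradients on the compute stream.
  Compute<<<blocks, threads, 0, compute_stream>>>(grad, n);

  // Pair 1 (Send after the AllReduce's last input node): the comm stream
  // may not start the reduction before the gradients are ready.
  cudaEventRecord(send_event, compute_stream);      // "record_event" on "record_event_stream"
  cudaStreamWaitEvent(comm_stream, send_event, 0);  // "wait_event" on "wait_event_stream"
  MockAllReduce<<<blocks, threads, 0, comm_stream>>>(grad, n);

  // Independent work on another buffer overlaps with the communication.
  Compute<<<blocks, threads, 0, compute_stream>>>(next, n);

  // Pair 2 (Recv before the AllReduce's first consumer): the compute stream
  // may not consume the reduced gradients before the reduction is done.
  cudaEventRecord(recv_event, comm_stream);
  cudaStreamWaitEvent(compute_stream, recv_event, 0);
  Compute<<<blocks, threads, 0, compute_stream>>>(grad, n);

  cudaDeviceSynchronize();
  cudaEventDestroy(send_event);
  cudaEventDestroy(recv_event);
  cudaStreamDestroy(compute_stream);
  cudaStreamDestroy(comm_stream);
  cudaFree(grad);
  cudaFree(next);
  printf("overlap demo finished\n");
  return 0;
}

Built with e.g. "nvcc -o overlap_demo overlap_demo.cu", the Compute launch on the second buffer can run concurrently with the mock reduction — the compute/communication overlap the commit message describes — while the two event pairs keep the data dependencies intact: pair 1 orders gradient production before the reduction, and pair 2 orders the reduction before its first consumer.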