Merge pull request !527 from limingqi107/mastertags/v0.2.0-alpha
| @@ -111,7 +111,8 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | MS_EXCEPTION_IF_NULL(context_ptr); | ||||
| MS_EXCEPTION_IF_NULL(mem_manager_); | MS_EXCEPTION_IF_NULL(mem_manager_); | ||||
| mem_manager_->ResetDynamicMemory(); | mem_manager_->ResetDynamicMemory(); | ||||
| AssignStaticMemory(graph); | |||||
| AssignStaticMemoryInput(graph); | |||||
| AssignStaticMemoryValueNode(graph); | |||||
| bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); | bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); | ||||
| if (is_enable_dynamic_mem) { | if (is_enable_dynamic_mem) { | ||||
| // Use the dynamic memory pool. | // Use the dynamic memory pool. | ||||
| @@ -181,7 +182,7 @@ void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph | |||||
| bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { | bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| auto graph_id = graph->graph_id(); | auto graph_id = graph->graph_id(); | ||||
| // The inputs and outputs memory of communication kernel are special, so separate processing. | |||||
| // The inputs and outputs memory of communication kernel need be continuous, so separate processing. | |||||
| AllocCommunicationOpDynamicRes(graph); | AllocCommunicationOpDynamicRes(graph); | ||||
| auto &kernels = graph->execution_order(); | auto &kernels = graph->execution_order(); | ||||
| @@ -229,15 +230,12 @@ void GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod | |||||
| for (size_t i = 0; i < output_sizes.size(); ++i) { | for (size_t i = 0; i < output_sizes.size(); ++i) { | ||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); | auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); | ||||
| MS_EXCEPTION_IF_NULL(device_address); | MS_EXCEPTION_IF_NULL(device_address); | ||||
| auto device_ptr = device_address->ptr_; | |||||
| if (device_ptr == nullptr) { | |||||
| device_ptr = mem_manager_->MallocMemFromMemPool(output_sizes[i]); | |||||
| MS_EXCEPTION_IF_NULL(device_ptr); | |||||
| device_address->ptr_ = device_ptr; | |||||
| if (device_address->ptr_ == nullptr) { | |||||
| mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]); | |||||
| } | } | ||||
| kernel::AddressPtr output = std::make_shared<kernel::Address>(); | kernel::AddressPtr output = std::make_shared<kernel::Address>(); | ||||
| MS_EXCEPTION_IF_NULL(output); | MS_EXCEPTION_IF_NULL(output); | ||||
| output->addr = device_ptr; | |||||
| output->addr = device_address->ptr_; | |||||
| output->size = output_sizes[i]; | output->size = output_sizes[i]; | ||||
| kernel_outputs->push_back(output); | kernel_outputs->push_back(output); | ||||
| } | } | ||||
| @@ -267,7 +265,6 @@ void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph | |||||
| if (kernel_name == kAllReduceOpName) { | if (kernel_name == kAllReduceOpName) { | ||||
| AllocCommunicationOpInputDynamicRes(kernel); | AllocCommunicationOpInputDynamicRes(kernel); | ||||
| AllocCommunicationOpOutputDynamicRes(kernel); | AllocCommunicationOpOutputDynamicRes(kernel); | ||||
| return; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -275,48 +272,30 @@ void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph | |||||
| void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel) { | void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel) { | ||||
| MS_EXCEPTION_IF_NULL(kernel); | MS_EXCEPTION_IF_NULL(kernel); | ||||
| MS_EXCEPTION_IF_NULL(mem_manager_); | MS_EXCEPTION_IF_NULL(mem_manager_); | ||||
| // The reference count of communication kernel input is not 0. | |||||
| if (communication_op_input_ref_count_ != 0) { | |||||
| MS_LOG(ERROR) << "The reference count of communication kernel input is not 0."; | |||||
| return; | |||||
| } | |||||
| size_t total = 0; | |||||
| std::vector<std::pair<mindspore::device::DeviceAddress *, size_t>> addr_size; | |||||
| size_t total_size = 0; | |||||
| std::vector<size_t> size_list; | |||||
| DeviceAddressPtrList addr_list; | |||||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { | for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { | ||||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); | auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); | ||||
| MS_EXCEPTION_IF_NULL(device_address); | MS_EXCEPTION_IF_NULL(device_address); | ||||
| // The inputs of communication kernel are not released. | // The inputs of communication kernel are not released. | ||||
| if ((i == 0) && (device_address->ptr_ != nullptr)) { | |||||
| MS_LOG(ERROR) << "The inputs of communication kernel are not released."; | |||||
| return; | |||||
| if (device_address->ptr_ != nullptr) { | |||||
| MS_LOG(INFO) << "The inputs of communication kernel are not released."; | |||||
| mem_manager_->FreeMemFromMemPool(device_address); | |||||
| } | } | ||||
| auto output_size = device_address->size_; | |||||
| total += output_size; | |||||
| addr_size.emplace_back(device_address.get(), output_size); | |||||
| } | |||||
| auto device_mem_ptr = mem_manager_->MallocMemFromMemPool(total); | |||||
| MS_EXCEPTION_IF_NULL(device_mem_ptr); | |||||
| for (const auto &iter : addr_size) { | |||||
| MS_EXCEPTION_IF_NULL(iter.first); | |||||
| iter.first->set_ptr(device_mem_ptr); | |||||
| communication_op_input_ref_count_++; | |||||
| device_mem_ptr = AddressOffset(device_mem_ptr, iter.second); | |||||
| total_size += device_address->size_; | |||||
| size_list.emplace_back(device_address->size_); | |||||
| addr_list.emplace_back(device_address); | |||||
| } | } | ||||
| mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list); | |||||
| } | } | ||||
| void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel) { | void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel) { | ||||
| MS_EXCEPTION_IF_NULL(kernel); | MS_EXCEPTION_IF_NULL(kernel); | ||||
| MS_EXCEPTION_IF_NULL(mem_manager_); | MS_EXCEPTION_IF_NULL(mem_manager_); | ||||
| // The reference count of communication kernel output is not 0. | |||||
| if (communication_op_output_ref_count_ != 0) { | |||||
| MS_LOG(ERROR) << "The reference count of communication kernel output is not 0."; | |||||
| return; | |||||
| } | |||||
| size_t total = 0; | |||||
| std::vector<std::pair<mindspore::device::DeviceAddress *, size_t>> addr_size; | |||||
| size_t total_size = 0; | |||||
| std::vector<size_t> size_list; | |||||
| DeviceAddressPtrList addr_list; | |||||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | ||||
| MS_EXCEPTION_IF_NULL(kernel_mod); | MS_EXCEPTION_IF_NULL(kernel_mod); | ||||
| auto output_sizes = kernel_mod->GetOutputSizeList(); | auto output_sizes = kernel_mod->GetOutputSizeList(); | ||||
| @@ -324,22 +303,15 @@ void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::Anf | |||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); | auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); | ||||
| MS_EXCEPTION_IF_NULL(device_address); | MS_EXCEPTION_IF_NULL(device_address); | ||||
| // The outputs of communication kernel are not released. | // The outputs of communication kernel are not released. | ||||
| if ((i == 0) && (device_address->ptr_ != nullptr)) { | |||||
| MS_LOG(ERROR) << "The outputs of communication kernel are not released."; | |||||
| return; | |||||
| if (device_address->ptr_ != nullptr) { | |||||
| MS_LOG(INFO) << "The outputs of communication kernel are not released."; | |||||
| mem_manager_->FreeMemFromMemPool(device_address); | |||||
| } | } | ||||
| total += output_sizes[i]; | |||||
| addr_size.emplace_back(device_address.get(), output_sizes[i]); | |||||
| } | |||||
| auto device_mem_ptr = mem_manager_->MallocMemFromMemPool(total); | |||||
| MS_EXCEPTION_IF_NULL(device_mem_ptr); | |||||
| for (const auto &iter : addr_size) { | |||||
| MS_EXCEPTION_IF_NULL(iter.first); | |||||
| iter.first->set_ptr(device_mem_ptr); | |||||
| communication_op_output_ref_count_++; | |||||
| device_mem_ptr = AddressOffset(device_mem_ptr, iter.second); | |||||
| total_size += output_sizes[i]; | |||||
| size_list.emplace_back(output_sizes[i]); | |||||
| addr_list.emplace_back(device_address); | |||||
| } | } | ||||
| mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list); | |||||
| } | } | ||||
| void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, | void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, | ||||
| @@ -362,14 +334,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, | |||||
| } | } | ||||
| kernel_ref_count_ptr->ref_count_dynamic_use_--; | kernel_ref_count_ptr->ref_count_dynamic_use_--; | ||||
| if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { | if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { | ||||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); | |||||
| mem_manager_->FreeMemFromMemPool(device_address); | |||||
| // Reset the reference count. | // Reset the reference count. | ||||
| kernel_ref_count_ptr->ref_count_dynamic_use_ = kernel_ref_count_ptr->ref_count_; | kernel_ref_count_ptr->ref_count_dynamic_use_ = kernel_ref_count_ptr->ref_count_; | ||||
| bool is_communication_op = false; | |||||
| FreeCommunicationOpDynamicRes(kernel, i, &is_communication_op); | |||||
| if (!is_communication_op) { | |||||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); | |||||
| mem_manager_->FreeMemFromMemPool(device_address); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| // Free the output of kernel, if output has no reference. | // Free the output of kernel, if output has no reference. | ||||
| @@ -393,40 +361,6 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr &kernel, size_t input_idx, | |||||
| bool *is_communication_op) { | |||||
| MS_EXCEPTION_IF_NULL(kernel); | |||||
| MS_EXCEPTION_IF_NULL(mem_manager_); | |||||
| // The inputs memory of communication kernel is one piece memory, need release together. | |||||
| if (AnfAlgo::GetCNodeName(kernel) == kAllReduceOpName) { | |||||
| communication_op_input_ref_count_--; | |||||
| if (communication_op_input_ref_count_ == 0) { | |||||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0); | |||||
| mem_manager_->FreeMemFromMemPool(device_address); | |||||
| } | |||||
| *is_communication_op = true; | |||||
| return; | |||||
| } | |||||
| auto cnode = kernel->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (input_idx + 1 >= cnode->inputs().size()) { | |||||
| MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << cnode->inputs().size() - 1 | |||||
| << "."; | |||||
| } | |||||
| auto input_node = cnode->input(input_idx + 1); | |||||
| auto kernel_input = AnfAlgo::VisitKernel(input_node, 0); | |||||
| // The outputs memory of communication kernel is one piece memory, need release together. | |||||
| if (AnfAlgo::GetCNodeName(kernel_input.first) == kAllReduceOpName) { | |||||
| communication_op_output_ref_count_--; | |||||
| if (communication_op_output_ref_count_ == 0) { | |||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel_input.first, 0); | |||||
| mem_manager_->FreeMemFromMemPool(device_address); | |||||
| } | |||||
| *is_communication_op = true; | |||||
| } | |||||
| } | |||||
| } // namespace gpu | } // namespace gpu | ||||
| } // namespace device | } // namespace device | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -60,9 +60,6 @@ class GPUKernelRuntime : public KernelRuntime { | |||||
| void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel); | void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel); | ||||
| void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces, | void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces, | ||||
| uint32_t graph_id); | uint32_t graph_id); | ||||
| void FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr &kernel, size_t input_idx, bool *is_communication_op); | |||||
| size_t communication_op_input_ref_count_{0}; | |||||
| size_t communication_op_output_ref_count_{0}; | |||||
| std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_; | std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_; | ||||
| }; | }; | ||||
| MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime); | MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime); | ||||
| @@ -29,6 +29,10 @@ void GPUMemoryManager::FreeMemFromMemPool(void *device_ptr) { | |||||
| GPUMemoryAllocator::GetInstance().FreeTensorMem(device_ptr); | GPUMemoryAllocator::GetInstance().FreeTensorMem(device_ptr); | ||||
| } | } | ||||
| std::vector<void *> GPUMemoryManager::MallocContinuousMemFromMemPool(size_t total_size, std::vector<size_t> size_list) { | |||||
| return GPUMemoryAllocator::GetInstance().AllocContinuousTensorMem(total_size, size_list); | |||||
| } | |||||
| void GPUMemoryManager::MallocDeviceMemory() { | void GPUMemoryManager::MallocDeviceMemory() { | ||||
| auto context_ptr = MsContext::GetInstance(); | auto context_ptr = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(context_ptr); | MS_EXCEPTION_IF_NULL(context_ptr); | ||||
| @@ -16,6 +16,7 @@ | |||||
| #ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ | #ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ | ||||
| #define MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ | #define MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ | ||||
| #include <vector> | |||||
| #include "device/memory_manager.h" | #include "device/memory_manager.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| @@ -30,6 +31,7 @@ class GPUMemoryManager : public MemoryManager { | |||||
| void *MallocMemFromMemPool(size_t size) override; | void *MallocMemFromMemPool(size_t size) override; | ||||
| void FreeMemFromMemPool(void *device_ptr) override; | void FreeMemFromMemPool(void *device_ptr) override; | ||||
| std::vector<void *> MallocContinuousMemFromMemPool(size_t total_size, std::vector<size_t> size_list); | |||||
| protected: | protected: | ||||
| uint8_t *MallocStaticMem(size_t size, bool communication_mem) override; | uint8_t *MallocStaticMem(size_t size, bool communication_mem) override; | ||||
| @@ -67,6 +67,7 @@ class KernelRuntime { | |||||
| TypeId type_id) = 0; | TypeId type_id) = 0; | ||||
| virtual bool SyncStream() = 0; | virtual bool SyncStream() = 0; | ||||
| void AssignStaticMemory(session::KernelGraph *graph); | void AssignStaticMemory(session::KernelGraph *graph); | ||||
| void AssignStaticMemoryValueNode(session::KernelGraph *graph); | |||||
| void AssignDynamicMemory(session::KernelGraph *graph); | void AssignDynamicMemory(session::KernelGraph *graph); | ||||
| void ReuseAssignDynamicMemory(session::KernelGraph *graph); | void ReuseAssignDynamicMemory(session::KernelGraph *graph); | ||||
| void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index); | void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index); | ||||
| @@ -81,7 +82,6 @@ class KernelRuntime { | |||||
| private: | private: | ||||
| void AssignStaticMemoryOutput(const session::KernelGraph *graph); | void AssignStaticMemoryOutput(const session::KernelGraph *graph); | ||||
| void AssignStaticMemoryValueNode(session::KernelGraph *graph); | |||||
| void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel, | void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel, | ||||
| AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); | AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); | ||||
| bool LaunchKernelMod(const session::KernelGraph &graph); | bool LaunchKernelMod(const session::KernelGraph &graph); | ||||
| @@ -167,5 +167,28 @@ void MemoryManager::FreeMemFromMemPool(void *device_ptr) { | |||||
| MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null."; | MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null."; | ||||
| } | } | ||||
| } | } | ||||
| void MemoryManager::MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size, | |||||
| std::vector<size_t> size_list) { | |||||
| auto device_ptr_list = MallocContinuousMemFromMemPool(total_size, size_list); | |||||
| if (addr_list.size() != device_ptr_list.size()) { | |||||
| MS_LOG(EXCEPTION) << "The size of device list is not equal to the size of address list."; | |||||
| } | |||||
| for (size_t i = 0; i < addr_list.size(); i++) { | |||||
| MS_EXCEPTION_IF_NULL(device_ptr_list[i]); | |||||
| MS_EXCEPTION_IF_NULL(addr_list[i]); | |||||
| addr_list[i]->ptr_ = device_ptr_list[i]; | |||||
| addr_list[i]->from_mem_pool_ = true; | |||||
| } | |||||
| } | |||||
| std::vector<void *> MemoryManager::MallocContinuousMemFromMemPool(size_t total_size, std::vector<size_t> size_list) { | |||||
| if (total_size == 0) { | |||||
| MS_LOG(ERROR) << "MallocContinuousMemFromMemPool total_size is 0."; | |||||
| } | |||||
| std::vector<void *> device_ptr_list; | |||||
| device_ptr_list.emplace_back(nullptr); | |||||
| return device_ptr_list; | |||||
| } | |||||
| } // namespace device | } // namespace device | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,6 +17,7 @@ | |||||
| #ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ | #ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ | ||||
| #define MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ | #define MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | |||||
| #include "pre_activate/mem_reuse/mem_reuse.h" | #include "pre_activate/mem_reuse/mem_reuse.h" | ||||
| #include "pre_activate/mem_reuse/mem_reuse_allocator.h" | #include "pre_activate/mem_reuse/mem_reuse_allocator.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -49,6 +50,9 @@ class MemoryManager { | |||||
| virtual void *MallocMemFromMemPool(size_t size); | virtual void *MallocMemFromMemPool(size_t size); | ||||
| virtual void FreeMemFromMemPool(const DeviceAddressPtr address); | virtual void FreeMemFromMemPool(const DeviceAddressPtr address); | ||||
| virtual void FreeMemFromMemPool(void *device_ptr); | virtual void FreeMemFromMemPool(void *device_ptr); | ||||
| virtual void MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size, | |||||
| std::vector<size_t> size_list); | |||||
| virtual std::vector<void *> MallocContinuousMemFromMemPool(size_t total_size, std::vector<size_t> size_list); | |||||
| size_t GetCommonAlignSize(size_t input_size) const; | size_t GetCommonAlignSize(size_t input_size) const; | ||||
| size_t GetCommunicationAlignSize(size_t input_size) const; | size_t GetCommunicationAlignSize(size_t input_size) const; | ||||
| @@ -44,7 +44,7 @@ class TransposeGpuFwdKernel : public GpuKernel { | |||||
| "cudaMemcpyAsync input_shape failed"); | "cudaMemcpyAsync input_shape failed"); | ||||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice, | CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice, | ||||
| reinterpret_cast<cudaStream_t>(stream_ptr)), | reinterpret_cast<cudaStream_t>(stream_ptr)), | ||||
| "cudaMemcphalfyAsync input_axis failed"); | |||||
| "cudaMemcpyAsync input_axis failed"); | |||||
| int size = SizeToInt(input_size_ / sizeof(T)); | int size = SizeToInt(input_size_ / sizeof(T)); | ||||
| CalTranspose(size, input, input_shape, input_axis, SizeToInt(shape_size_), output, | CalTranspose(size, input, input_shape, input_axis, SizeToInt(shape_size_), output, | ||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | reinterpret_cast<cudaStream_t>(stream_ptr)); | ||||
| @@ -60,6 +60,14 @@ __global__ void SquareKernel(T *input, T *output, size_t count) { | |||||
| return; | return; | ||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| __global__ void ZeroslikeKernel(T *output, size_t count) { | |||||
| T zero = 0.0; | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||||
| output[i] = zero; | |||||
| } | |||||
| return; | |||||
| } | |||||
| template <typename T> | |||||
| void Exponential(T *input, T *output, size_t count, cudaStream_t cuda_stream) { | void Exponential(T *input, T *output, size_t count, cudaStream_t cuda_stream) { | ||||
| ExponentialKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | ExponentialKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | ||||
| return; | return; | ||||
| @@ -84,13 +92,21 @@ void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream) { | |||||
| SquareKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | SquareKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count); | ||||
| return; | return; | ||||
| } | } | ||||
| template <typename T> | |||||
| void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream) { | |||||
| ZeroslikeKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(output, count); | |||||
| return; | |||||
| } | |||||
| template void Exponential<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | template void Exponential<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Logarithm<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | template void Logarithm<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Negative<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | template void Negative<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Reciprocal<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | template void Reciprocal<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Square<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | template void Square<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Zeroslike<float>(float *output, size_t count, cudaStream_t cuda_stream); | |||||
| template void Exponential<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | template void Exponential<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Logarithm<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | template void Logarithm<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Negative<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | template void Negative<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Reciprocal<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | template void Reciprocal<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Square<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | template void Square<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream); | ||||
| template void Zeroslike<half>(half *output, size_t count, cudaStream_t cuda_stream); | |||||
| @@ -28,4 +28,7 @@ template <typename T> | |||||
| void Reciprocal(T *input, T *output, size_t count, cudaStream_t cuda_stream); | void Reciprocal(T *input, T *output, size_t count, cudaStream_t cuda_stream); | ||||
| template <typename T> | template <typename T> | ||||
| void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream); | void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream); | ||||
| template <typename T> | |||||
| void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ | #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ | ||||
| @@ -81,6 +81,7 @@ class UnaryOpGpuKernel : public GpuKernel { | |||||
| break; | break; | ||||
| } | } | ||||
| case UNARY_OP_ZEROSLIKE: { | case UNARY_OP_ZEROSLIKE: { | ||||
| Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | return true; | ||||
| } | } | ||||
| default: { | default: { | ||||
| @@ -36,6 +36,37 @@ DeviceMemPtr DynamicMemPoolBestFit::AllocTensorMem(size_t size) { | |||||
| return device_addr; | return device_addr; | ||||
| } | } | ||||
| std::vector<DeviceMemPtr> DynamicMemPoolBestFit::AllocContinuousTensorMem(size_t total_size, | |||||
| std::vector<size_t> size_list) { | |||||
| // Pre-alloc the one whole piece memory. | |||||
| auto device_addr = AllocTensorMem(total_size); | |||||
| MS_EXCEPTION_IF_NULL(device_addr); | |||||
| // Remove the pre-alloc memory. | |||||
| auto mem_block = FindMemBlock(device_addr); | |||||
| MS_EXCEPTION_IF_NULL(mem_block); | |||||
| auto iter = mem_block->block_all_mem_buf_map_.find(device_addr); | |||||
| if (iter == mem_block->block_all_mem_buf_map_.end()) { | |||||
| MS_LOG(EXCEPTION) << "Can't find the device address[" << device_addr << "]."; | |||||
| } | |||||
| auto mem_buf = iter->second; | |||||
| MS_EXCEPTION_IF_NULL(mem_buf); | |||||
| auto rest_size = mem_buf->size_ - total_size; | |||||
| (void)mem_block->block_all_mem_buf_map_.erase(iter); | |||||
| // Split the pre-alloc memory into continuous memory by the size list. | |||||
| DynamicMemBufPtr continuous_mem_buf; | |||||
| std::vector<DeviceMemPtr> device_addr_list; | |||||
| auto buf_addr = device_addr; | |||||
| for (size_t i = 0; i < size_list.size(); i++) { | |||||
| continuous_mem_buf = std::make_shared<DynamicMemBuf>(buf_addr, kMemBufUsed, size_list[i]); | |||||
| (void)mem_block->block_all_mem_buf_map_.emplace(buf_addr, continuous_mem_buf); | |||||
| device_addr_list.emplace_back(buf_addr); | |||||
| buf_addr = AddressOffset(buf_addr, size_list[i]); | |||||
| } | |||||
| // Update the size of the last memory buf. | |||||
| continuous_mem_buf->size_ += rest_size; | |||||
| return device_addr_list; | |||||
| } | |||||
| size_t DynamicMemPoolBestFit::AlignMemorySize(size_t size) const { | size_t DynamicMemPoolBestFit::AlignMemorySize(size_t size) const { | ||||
| if (size == 0) { | if (size == 0) { | ||||
| return DYNAMIC_MEM_ALIGN_SIZE; | return DYNAMIC_MEM_ALIGN_SIZE; | ||||
| @@ -79,6 +79,8 @@ class DynamicMemPoolBestFit { | |||||
| virtual ~DynamicMemPoolBestFit(); | virtual ~DynamicMemPoolBestFit(); | ||||
| // The main program entry of memory alloc. | // The main program entry of memory alloc. | ||||
| DeviceMemPtr AllocTensorMem(size_t size); | DeviceMemPtr AllocTensorMem(size_t size); | ||||
| // The main program entry of continuous memory alloc. | |||||
| std::vector<DeviceMemPtr> AllocContinuousTensorMem(size_t total_size, std::vector<size_t> size_list); | |||||
| // The main program entry of memory free. | // The main program entry of memory free. | ||||
| void FreeTensorMem(const DeviceMemPtr device_addr); | void FreeTensorMem(const DeviceMemPtr device_addr); | ||||
| // Release the real device memory. | // Release the real device memory. | ||||
| @@ -162,10 +162,6 @@ void MemReuseUtil::SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr | |||||
| if (iter == kernel_def_ptr->inputs_.end()) { | if (iter == kernel_def_ptr->inputs_.end()) { | ||||
| kernel_def_ptr->inputs_[key].push_back(ref_ptr); | kernel_def_ptr->inputs_[key].push_back(ref_ptr); | ||||
| } else { | } else { | ||||
| if (std::any_of(iter->second.begin(), iter->second.end(), | |||||
| [ref_ptr](const KernelRefCountPtr &it) { return (it.get() == ref_ptr.get()); })) { | |||||
| break; | |||||
| } | |||||
| iter->second.push_back(ref_ptr); | iter->second.push_back(ref_ptr); | ||||
| } | } | ||||
| } | } | ||||
| @@ -185,10 +181,6 @@ void MemReuseUtil::SetOutputMap(const CNodePtr &kernel, KernelDef *kernel_def_pt | |||||
| if (iter == kernel_def_ptr->outputs_.end()) { | if (iter == kernel_def_ptr->outputs_.end()) { | ||||
| kernel_def_ptr->outputs_[key].push_back(kernel_ref); | kernel_def_ptr->outputs_[key].push_back(kernel_ref); | ||||
| } else { | } else { | ||||
| if (std::any_of(iter->second.begin(), iter->second.end(), | |||||
| [kernel_ref](const KernelRefCountPtr &it) { return (it == kernel_ref); })) { | |||||
| break; | |||||
| } | |||||
| iter->second.push_back(kernel_ref); | iter->second.push_back(kernel_ref); | ||||
| } | } | ||||
| } | } | ||||
| @@ -20,7 +20,7 @@ import mindspore.context as context | |||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size | from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU', enable_dynamic_memory=False) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||||
| init('nccl') | init('nccl') | ||||
| rank = get_rank() | rank = get_rank() | ||||