Browse Source

!13201 fix topk device memory issue

From: @TFbunny
Reviewed-by: @tom__chen,@tom__chen,@robingrosman
Signed-off-by: @robingrosman
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
51b2b89078
3 changed files with 10 additions and 7 deletions
  1. +7
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h
  2. +2
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cu
  3. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh

+ 7
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h View File

@@ -42,8 +42,13 @@ class TopKGpuKernel : public GpuKernel {
T *output_addr = GetDeviceAddress<T>(outputs, 0);
S *indices = GetDeviceAddress<S>(outputs, 1);
const T init_k = std::numeric_limits<T>::lowest();

FastTopK(outer_size_, inner_size_, input_addr, k, output_addr, indices, init_k,
S k_cut = 0;
CHECK_CUDA_RET_WITH_EXCEPT(
kernel_node_,
cudaMemcpyAsync(&k_cut, k, sizeof(S), cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync k_cut failed");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed - TopK");
FastTopK(outer_size_, inner_size_, input_addr, k_cut, output_addr, indices, init_k,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}


+ 2
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cu View File

@@ -204,11 +204,9 @@ __global__ void TopKBlock(int outer_size, int inner_size, const T *input, T *out
}

template <typename T, typename S>
void FastTopK(const int outer_size, const int inner_size, const T *input, const S *k, T *output, S *output_index,
void FastTopK(const int outer_size, const int inner_size, const T *input, S k_cut, T *output, S *output_index,
const T init_K, cudaStream_t stream) {
int block_num_limit = outer_size < 128 ? outer_size : 128;
S k_cut = 0;
cudaMemcpy(&k_cut, k, sizeof(S), cudaMemcpyDeviceToHost);
if (k_cut > inner_size) k_cut = inner_size;

if (k_cut <= 32) {
@@ -223,5 +221,5 @@ void FastTopK(const int outer_size, const int inner_size, const T *input, const
}
}

template void FastTopK(const int outer_size, const int inner_size, const float *input, const int *k, float *output,
template void FastTopK(const int outer_size, const int inner_size, const float *input, int k_cut, float *output,
int *output_index, const float init_K, cudaStream_t stream);

+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh View File

@@ -21,7 +21,7 @@
#include "runtime/device/gpu/cuda_common.h"

template <typename T, typename S>
void FastTopK(const int outer, const int inner, const T *input_addr, const S *k, T *output, S *indices, const T initK,
void FastTopK(const int outer, const int inner, const T *input_addr, S k_cut, T *output, S *indices, const T initK,
cudaStream_t stream);

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TOPK_IMPL_CUH_

Loading…
Cancel
Save