From 920820382fad60a0312ec41ce1a96720a5f43fde Mon Sep 17 00:00:00 2001
From: TFbunny
Date: Fri, 21 Aug 2020 21:28:56 -0400
Subject: [PATCH] fix GPU-ArgMaxWithValue

---
 .../gpu/cuda_impl/argmaxwithvalue_impl.cu     | 51 +++++++++----------
 1 file changed, 23 insertions(+), 28 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu
index 46a8a75af9..66a73aca50 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu
@@ -18,39 +18,34 @@
 #include "runtime/device/gpu/cuda_common.h"
 #include "include/cuda_fp16.h"
 template <typename T, typename S>
-__global__ void ArgmaxWithValue(const T* input, const int bound, int outerSize, int innerSize, S* index,
-                                T* output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (outerSize); pos += blockDim.x * gridDim.x) {
-    int inputOutterOffset = pos * innerSize * bound;
-    int outputOutterOffset = pos * innerSize;
-    for (int j = 0; j < innerSize; j++) {
-      auto outputInnerOffset = outputOutterOffset + j;
-      S idx = 0;
-      T maxData = input[j + inputOutterOffset];
-      for (S c = 0; c < bound; c++) {
-        int offset = j + c * innerSize;
-        auto inputData = input[inputOutterOffset + offset];
-        idx = inputData > maxData ? c : idx;
-        maxData = inputData > maxData ? inputData : maxData;
-      }
-      output[outputInnerOffset] = maxData;
-      index[outputInnerOffset] = idx;
-    }
+__global__ void ArgmaxWithValue(const T *input, const int bound, int outerSize, int innerSize, S *index, T *output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outerSize * innerSize; pos += gridDim.x * blockDim.x) {
+    int x = pos / innerSize % outerSize;
+    int y = pos % innerSize;
+    S idx = 0;
+    int InputOffset = x * bound * innerSize + 0 * innerSize + y;
+    T maxData = input[InputOffset];
+    for (int i = 0; i < bound; i++) {
+      InputOffset = x * bound * innerSize + i * innerSize + y;
+      auto inputData = input[InputOffset];
+      idx = inputData > maxData ? i : idx;
+      maxData = inputData > maxData ? inputData : maxData;
+    }
+    output[pos] = maxData;
+    index[pos] = idx;
   }
   return;
 }
 
 template <typename T, typename S>
-void CalArgmaxWithValue(const T* input, const int bound_, const int outerSize_, const int innerSize_,
-                        S* index, T* output, cudaStream_t cuda_stream) {
-  ArgmaxWithValue<<<GET_BLOCKS(outerSize_), GET_THREADS, 0, cuda_stream>>>(input, bound_, outerSize_, innerSize_,
-                                                                           index, output);
+void CalArgmaxWithValue(const T *input, const int bound_, const int outerSize_, const int innerSize_, S *index,
+                        T *output, cudaStream_t cuda_stream) {
+  ArgmaxWithValue<<<GET_BLOCKS(outerSize_), GET_THREADS, 0, cuda_stream>>>(input, bound_, outerSize_, innerSize_, index,
+                                                                           output);
   return;
 }
 
-template void CalArgmaxWithValue<float, int>(const float* input, const int bound_, const int outerSize_,
-                                             const int innerSize_, int* index, float* output,
-                                             cudaStream_t cuda_stream);
-template void CalArgmaxWithValue<half, int>(const half* input, const int bound_, const int outerSize_,
-                                            const int innerSize_, int* index, half* output,
-                                            cudaStream_t cuda_stream);
+template void CalArgmaxWithValue<float, int>(const float *input, const int bound_, const int outerSize_,
+                                             const int innerSize_, int *index, float *output, cudaStream_t cuda_stream);
+template void CalArgmaxWithValue<half, int>(const half *input, const int bound_, const int outerSize_,
+                                            const int innerSize_, int *index, half *output, cudaStream_t cuda_stream);