Browse Source

!2749 GPU update argmaxwithvalue

Merge pull request !2749 from VectorSL/argmaxwithvalue
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
ac3c35c329
4 changed files with 32 additions and 27 deletions
  1. +1
    -1
      mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h
  2. +22
    -24
      mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu
  3. +2
    -2
      mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh
  4. +7
    -0
      mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu

+ 1
- 1
mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h View File

@@ -38,7 +38,7 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
T *input = GetDeviceAddress<T>(inputs, 0); T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 1); T *output = GetDeviceAddress<T>(outputs, 1);
S *index = GetDeviceAddress<S>(outputs, 0); S *index = GetDeviceAddress<S>(outputs, 0);
CalArgmaxWithValue(input_size_ / sizeof(T), input, bound_, outerSize_, innerSize_, index, output,
CalArgmaxWithValue(input, bound_, outerSize_, innerSize_, index, output,
reinterpret_cast<cudaStream_t>(stream_ptr)); reinterpret_cast<cudaStream_t>(stream_ptr));
return true; return true;
} }


+ 22
- 24
mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu View File

@@ -18,41 +18,39 @@
#include "device/gpu/cuda_common.h" #include "device/gpu/cuda_common.h"
#include "include/cuda_fp16.h" #include "include/cuda_fp16.h"
template <typename T, typename S> template <typename T, typename S>
__global__ void ArgmaxWithValue(size_t size, const T* input, const int bound, int outerSize, int innerSize,
S* index, T* output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
for (int i = 0; i < outerSize; i++) {
int inputOutterOffset = i * innerSize * bound;
int outputOutterOffset = i * innerSize;
for (int j = 0; j < innerSize; j++) {
auto outputInnerOffset = outputOutterOffset + j;
S idx = 0;
T maxData = input[j + inputOutterOffset];
for (S c = 0; c < bound; c++) {
int offset = j + c * innerSize;
auto inputData = input[inputOutterOffset + offset];
idx = inputData > maxData ? c : idx;
maxData = inputData > maxData ? inputData : maxData;
}
output[outputInnerOffset] = maxData;
index[outputInnerOffset] = idx;
__global__ void ArgmaxWithValue(const T* input, const int bound, int outerSize, int innerSize, S* index,
T* output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (outerSize); pos += blockDim.x * gridDim.x) {
int inputOutterOffset = pos * innerSize * bound;
int outputOutterOffset = pos * innerSize;
for (int j = 0; j < innerSize; j++) {
auto outputInnerOffset = outputOutterOffset + j;
S idx = 0;
T maxData = input[j + inputOutterOffset];
for (S c = 0; c < bound; c++) {
int offset = j + c * innerSize;
auto inputData = input[inputOutterOffset + offset];
idx = inputData > maxData ? c : idx;
maxData = inputData > maxData ? inputData : maxData;
}
output[outputInnerOffset] = maxData;
index[outputInnerOffset] = idx;
} }
}
} }
return; return;
} }


template <typename T, typename S> template <typename T, typename S>
void CalArgmaxWithValue(size_t size, const T* input, const int bound_, const int outerSize_, const int innerSize_,
void CalArgmaxWithValue(const T* input, const int bound_, const int outerSize_, const int innerSize_,
S* index, T* output, cudaStream_t cuda_stream) { S* index, T* output, cudaStream_t cuda_stream) {
ArgmaxWithValue<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, bound_, outerSize_, innerSize_,
index, output);
ArgmaxWithValue<<<GET_BLOCKS(outerSize_), GET_THREADS, 0, cuda_stream>>>(input, bound_, outerSize_, innerSize_,
index, output);
return; return;
} }


template void CalArgmaxWithValue<float, int>(size_t size, const float* input, const int bound_, const int outerSize_,
template void CalArgmaxWithValue<float, int>(const float* input, const int bound_, const int outerSize_,
const int innerSize_, int* index, float* output, const int innerSize_, int* index, float* output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void CalArgmaxWithValue<half, int>(size_t size, const half* input, const int bound_, const int outerSize_,
template void CalArgmaxWithValue<half, int>(const half* input, const int bound_, const int outerSize_,
const int innerSize_, int* index, half* output, const int innerSize_, int* index, half* output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);

+ 2
- 2
mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh View File

@@ -17,6 +17,6 @@
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_ #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_ #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
template <typename T, typename S> template <typename T, typename S>
void CalArgmaxWithValue(size_t size, const T* input, const int bound_, const int outerSize_, const int innerSize_,
S* index, T* output, cudaStream_t cuda_stream);
void CalArgmaxWithValue(const T *input, const int bound_, const int outerSize_, const int innerSize_, S *index,
T *output, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_ #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_

+ 7
- 0
mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu View File

@@ -36,6 +36,13 @@ __global__ void LogarithmKernel(T *input, T *output, size_t count) {
} }
return; return;
} }
template <>
__global__ void LogarithmKernel(half *input, half *output, size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = hlog(input[i]);
}
return;
}
template <typename T> template <typename T>
__global__ void NegativeKernel(T *input, T *output, size_t count) { __global__ void NegativeKernel(T *input, T *output, size_t count) {
T neg_one = -1; T neg_one = -1;


Loading…
Cancel
Save