From 5dd4933328d543e16993d2e191b24eeafb6f5eea Mon Sep 17 00:00:00 2001 From: peixu_ren Date: Thu, 13 Aug 2020 23:19:48 -0400 Subject: [PATCH] Refactor uniform ops in GPU context --- .../gpu/cuda_impl/random_op_impl.cu | 65 +++++++++++++++---- .../gpu/cuda_impl/random_op_impl.cuh | 9 ++- .../gpu/math/random_op_gpu_kernel.cc | 10 +-- .../gpu/math/random_op_gpu_kernel.h | 42 ++++++------ 4 files changed, 83 insertions(+), 43 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu index 19a1273cb3..4c88916515 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu @@ -19,19 +19,26 @@ template __global__ void NormalKernel(int seed, curandState *globalState, T *output, size_t count) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { curand_init(seed, i, 0, &globalState[i]); - output[i] = curand_normal(&globalState[i]); + output[i] = (T)curand_normal(&globalState[i]); } return; } template -__global__ void UniformKernel(int seed, curandState *globalState, T *input1, size_t input_size_1, - T *input2, size_t input_size_2, T *output, size_t count) { +__global__ void UniformIntKernel(int seed, curandState *globalState, T *input1, size_t input_size_1, + T *input2, size_t input_size_2, T *output, size_t count) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { - input1[i] = (input_size_1 == 1 ? input1[0] : input1[i]); - input2[i] = (input_size_2 == 1 ? input2[0] : input2[i]); curand_init(seed, i, 0, &globalState[i]); - output[i] = curand_uniform(&globalState[i]) * (input2[i] - input1[i]) + input1[i]; + output[i] = (T)(curand_uniform(&globalState[i])) * (input2[0] - input1[0]) + input1[0]; + } + return; +} + +template +__global__ void UniformRealKernel(int seed, curandState *globalState, T *output, size_t count) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + curand_init(seed, i, 0, &globalState[i]); + output[i] = (T)curand_uniform(&globalState[i]); } return; } @@ -51,16 +58,46 @@ void StandardNormal(int seed, int seed2, curandState *globalState, T *output, si } template -void UniformReal(int seed, curandState *globalState, T *input1, size_t input_size_1, - T *input2, size_t input_size_2, T *output, size_t count, cudaStream_t cuda_stream) { - seed = (seed == 0 ? time(NULL):seed); - UniformKernel<<>> - (seed, globalState, input1, input_size_1, input2, input_size_2, output, count); +void UniformInt(int seed, int seed2, curandState *globalState, T *input1, size_t input_size_1, + T *input2, size_t input_size_2, T *output, size_t count, cudaStream_t cuda_stream) { + int RNG_seed = 0; + if (seed2 != 0) { + RNG_seed = seed2; + } else if (seed != 0) { + RNG_seed = seed; + } else { + RNG_seed = time(NULL); + } + UniformIntKernel<<>> + (RNG_seed, globalState, input1, input_size_1, input2, input_size_2, output, count); + return; +} + +template +void UniformReal(int seed, int seed2, curandState *globalState, T *output, size_t count, cudaStream_t cuda_stream) { + int RNG_seed = 0; + if (seed2 != 0) { + RNG_seed = seed2; + } else if (seed != 0) { + RNG_seed = seed; + } else { + RNG_seed = time(NULL); + } + UniformRealKernel<<>>(RNG_seed, globalState, output, count); return; } template void StandardNormal(int seed, int seed2, curandState *globalState, float *output, size_t count, cudaStream_t cuda_stream); -template void UniformReal(int seed, curandState *globalState, float *input1, size_t input_size_1, - float *input2, size_t input_size_2, float *output, size_t count, - cudaStream_t cuda_stream); +template void StandardNormal(int seed, int seed2, curandState *globalState, + int *output, size_t count, cudaStream_t cuda_stream); +template void UniformInt(int seed, int seed2, curandState *globalState, float *input1, size_t input_size_1, + float *input2, size_t input_size_2, float *output, size_t count, + cudaStream_t cuda_stream); +template void UniformInt(int seed, int seed2, curandState *globalState, int *input1, size_t input_size_1, + int *input2, size_t input_size_2, int *output, size_t count, + cudaStream_t cuda_stream); +template void UniformReal(int seed, int seed2, curandState *globalState, + float *output, size_t count, cudaStream_t cuda_stream); +template void UniformReal(int seed, int seed2, curandState *globalState, + int *output, size_t count, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh index f5699cee0a..9c51304bb6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh @@ -24,7 +24,10 @@ template void StandardNormal(int seed, int seed2, curandState *globalState, T *output, size_t count, cudaStream_t cuda_stream); template -void UniformReal(int seed, curandState *globalState, - T *input1, size_t input_size_1, T *input2, size_t input_size_2, - T *output, size_t count, cudaStream_t cuda_stream); +void UniformInt(int seed, int seed2, curandState *globalState, + T *input1, size_t input_size_1, T *input2, size_t input_size_2, + T *output, size_t count, cudaStream_t cuda_stream); +template +void UniformReal(int seed, int seed2, curandState *globalState, + T *output, size_t count, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc index 8dfd4eef08..e674bef5ad 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc @@ -20,12 +20,14 @@ namespace mindspore { namespace kernel { MS_REG_GPU_KERNEL_ONE(StandardNormal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), RandomOpGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(UniformReal, +MS_REG_GPU_KERNEL_ONE(UniformInt, KernelAttr() .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + RandomOpGpuKernel, int) +MS_REG_GPU_KERNEL_ONE(UniformReal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), RandomOpGpuKernel, float) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h index 98a421c922..43954a762c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h @@ -28,16 +28,17 @@ namespace mindspore { namespace kernel { -enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_UNIFORM_REAL, RANDOM_OP_INVALID_TYPE = 255 }; +enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_UNIFORM_INT, RANDOM_OP_UNIFORM_REAL, RANDOM_OP_INVALID_TYPE = 255 }; + +const std::map kRandomOpTypeMap = { + {"StandardNormal", RANDOM_OP_NORMAL}, {"UniformInt", RANDOM_OP_UNIFORM_INT}, {"UniformReal", RANDOM_OP_UNIFORM_REAL}}; -const std::map kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL}, - {"UniformReal", RANDOM_OP_UNIFORM_REAL}}; template class RandomOpGpuKernel : public GpuKernel { public: RandomOpGpuKernel() : random_op_type_(RANDOM_OP_INVALID_TYPE), - input_size_0_(sizeof(int)), + input_size_0_(sizeof(0)), input_size_1_(sizeof(T)), input_size_2_(sizeof(T)), output_size_(sizeof(T)), @@ -62,11 +63,16 @@ class RandomOpGpuKernel : public GpuKernel { reinterpret_cast(stream_ptr)); break; } - case RANDOM_OP_UNIFORM_REAL: { + case RANDOM_OP_UNIFORM_INT: { T *input_addr_1 = GetDeviceAddress(inputs, 1); T *input_addr_2 = GetDeviceAddress(inputs, 2); - UniformReal(seed_, devStates, input_addr_1, inputs[1]->size / sizeof(T), input_addr_2, - inputs[2]->size / sizeof(T), output_addr, outputs[0]->size / sizeof(T), + UniformInt(seed_, seed2_, devStates, input_addr_1, inputs[1]->size / sizeof(T), input_addr_2, + inputs[2]->size / sizeof(T), output_addr, outputs[0]->size / sizeof(T), + reinterpret_cast(stream_ptr)); + break; + } + case RANDOM_OP_UNIFORM_REAL: { + UniformReal(seed_, seed2_, devStates, output_addr, outputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); break; } @@ -86,11 +92,11 @@ class RandomOpGpuKernel : public GpuKernel { random_op_type_ = iter->second; } size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (random_op_type_ == RANDOM_OP_NORMAL && input_num != 1) { + if ((random_op_type_ == RANDOM_OP_NORMAL || random_op_type_ == RANDOM_OP_UNIFORM_REAL) && input_num != 1) { MS_LOG(ERROR) << "Input number is " << input_num << ", but random op needs 1 input."; return false; } - if (random_op_type_ == RANDOM_OP_UNIFORM_REAL && input_num != 3) { + if (random_op_type_ == RANDOM_OP_UNIFORM_INT && input_num != 3) { MS_LOG(ERROR) << "Input number is " << input_num << ", but random op needs 3 inputs."; return false; } @@ -104,15 +110,9 @@ class RandomOpGpuKernel : public GpuKernel { input_size_0_ += input_shape_0[i]; } input_size_0_ *= sizeof(int); - if (random_op_type_ == RANDOM_OP_UNIFORM_REAL) { - auto input_shape_1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - for (size_t i = 0; i < input_shape_1.size(); i++) { - input_size_1_ *= input_shape_1[i]; - } - auto input_shape_2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - for (size_t i = 0; i < input_shape_2.size(); i++) { - input_size_2_ *= input_shape_2[i]; - } + if (random_op_type_ == RANDOM_OP_UNIFORM_INT) { + input_size_1_ *= 1; + input_size_2_ *= 1; } auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); for (size_t i = 0; i < output_shape.size(); i++) { @@ -120,9 +120,7 @@ class RandomOpGpuKernel : public GpuKernel { workspace_size_ *= output_shape[i]; } seed_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed")); - if (random_op_type_ == RANDOM_OP_NORMAL) { - seed2_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2")); - } + seed2_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2")); InitSizeLists(); return true; } @@ -130,7 +128,7 @@ class RandomOpGpuKernel : public GpuKernel { protected: void InitSizeLists() override { input_size_list_.push_back(input_size_0_); - if (random_op_type_ == RANDOM_OP_UNIFORM_REAL) { + if (random_op_type_ == RANDOM_OP_UNIFORM_INT) { input_size_list_.push_back(input_size_1_); input_size_list_.push_back(input_size_2_); }