Browse Source

Fix an error: UniformInt now checks that uniform_a is strictly less than uniform_b

tags/v1.1.0
peixu_ren 5 years ago
parent
commit
05f44ab834
3 changed files with 21 additions and 8 deletions
  1. +13
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu
  2. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh
  3. +7
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h

+ 13
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu View File

@@ -24,13 +24,20 @@ __global__ void NormalKernel(int seed, curandState *globalState, T *output, size
return; return;
} }


__device__ bool dev_error_res = false;

template <typename T> template <typename T>
__global__ void UniformIntKernel(int seed, curandState *globalState, T *input1, size_t input_size_1, __global__ void UniformIntKernel(int seed, curandState *globalState, T *input1, size_t input_size_1,
T *input2, size_t input_size_2, T *output, size_t count) { T *input2, size_t input_size_2, T *output, size_t count) {
if (!(input1[0] < input2[0])) {
dev_error_res = false;
return;
}
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
curand_init(seed, i, 0, &globalState[i]); curand_init(seed, i, 0, &globalState[i]);
output[i] = (T)(curand_uniform(&globalState[i]) * (input2[0] - input1[0])) + input1[0]; output[i] = (T)(curand_uniform(&globalState[i]) * (input2[0] - input1[0])) + input1[0];
} }
dev_error_res = true;
return; return;
} }


@@ -59,7 +66,7 @@ void StandardNormal(int seed, int seed2, curandState *globalState, T *output, si
} }


template <typename T> template <typename T>
void UniformInt(int seed, int seed2, curandState *globalState, T *input1, size_t input_size_1,
bool UniformInt(int seed, int seed2, curandState *globalState, T *input1, size_t input_size_1,
T *input2, size_t input_size_2, T *output, size_t count, cudaStream_t cuda_stream) { T *input2, size_t input_size_2, T *output, size_t count, cudaStream_t cuda_stream) {
int RNG_seed = 0; int RNG_seed = 0;
std::random_device rd; std::random_device rd;
@@ -70,9 +77,11 @@ void UniformInt(int seed, int seed2, curandState *globalState, T *input1, size_t
} else { } else {
RNG_seed = static_cast<int>(rd()); RNG_seed = static_cast<int>(rd());
} }
bool host_error_res = false;
UniformIntKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>> UniformIntKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>
(RNG_seed, globalState, input1, input_size_1, input2, input_size_2, output, count); (RNG_seed, globalState, input1, input_size_1, input2, input_size_2, output, count);
return;
cudaMemcpyFromSymbol(&host_error_res, dev_error_res, sizeof(bool));
return host_error_res;
} }


template <typename T> template <typename T>
@@ -94,10 +103,10 @@ template void StandardNormal<float>(int seed, int seed2, curandState *globalStat
float *output, size_t count, cudaStream_t cuda_stream); float *output, size_t count, cudaStream_t cuda_stream);
template void StandardNormal<int>(int seed, int seed2, curandState *globalState, template void StandardNormal<int>(int seed, int seed2, curandState *globalState,
int *output, size_t count, cudaStream_t cuda_stream); int *output, size_t count, cudaStream_t cuda_stream);
template void UniformInt<float>(int seed, int seed2, curandState *globalState, float *input1, size_t input_size_1,
template bool UniformInt<float>(int seed, int seed2, curandState *globalState, float *input1, size_t input_size_1,
float *input2, size_t input_size_2, float *output, size_t count, float *input2, size_t input_size_2, float *output, size_t count,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void UniformInt<int>(int seed, int seed2, curandState *globalState, int *input1, size_t input_size_1,
template bool UniformInt<int>(int seed, int seed2, curandState *globalState, int *input1, size_t input_size_1,
int *input2, size_t input_size_2, int *output, size_t count, int *input2, size_t input_size_2, int *output, size_t count,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void UniformReal<float>(int seed, int seed2, curandState *globalState, template void UniformReal<float>(int seed, int seed2, curandState *globalState,


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh View File

@@ -25,7 +25,7 @@ template <typename T>
void StandardNormal(int seed, int seed2, curandState *globalState, void StandardNormal(int seed, int seed2, curandState *globalState,
T *output, size_t count, cudaStream_t cuda_stream); T *output, size_t count, cudaStream_t cuda_stream);
template <typename T> template <typename T>
void UniformInt(int seed, int seed2, curandState *globalState,
bool UniformInt(int seed, int seed2, curandState *globalState,
T *input1, size_t input_size_1, T *input2, size_t input_size_2, T *input1, size_t input_size_1, T *input2, size_t input_size_2,
T *output, size_t count, cudaStream_t cuda_stream); T *output, size_t count, cudaStream_t cuda_stream);
template <typename T> template <typename T>


+ 7
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h View File

@@ -75,9 +75,13 @@ class RandomOpGpuKernel : public GpuKernel {
case RANDOM_OP_UNIFORM_INT: { case RANDOM_OP_UNIFORM_INT: {
T *input_addr_1 = GetDeviceAddress<T>(inputs, 1); T *input_addr_1 = GetDeviceAddress<T>(inputs, 1);
T *input_addr_2 = GetDeviceAddress<T>(inputs, 2); T *input_addr_2 = GetDeviceAddress<T>(inputs, 2);
UniformInt(seed_, seed2_, devStates, input_addr_1, inputs[1]->size / sizeof(T), input_addr_2,
inputs[2]->size / sizeof(T), output_addr, outputs[0]->size / sizeof(T),
reinterpret_cast<cudaStream_t>(stream_ptr));
bool ret = UniformInt(seed_, seed2_, devStates, input_addr_1, inputs[1]->size / sizeof(T), input_addr_2,
inputs[2]->size / sizeof(T), output_addr, outputs[0]->size / sizeof(T),
reinterpret_cast<cudaStream_t>(stream_ptr));
if (!ret) {
MS_LOG(ERROR) << "For UniformInt op, `minval` should be strictly less than `maxval`";
return false;
}
break; break;
} }
case RANDOM_OP_UNIFORM_REAL: { case RANDOM_OP_UNIFORM_REAL: {


Loading…
Cancel
Save