|
|
|
@@ -70,8 +70,9 @@ class RandomOpGpuKernel : public GpuKernel { |
|
|
|
} |
|
|
|
|
|
|
|
curandState *devStates = nullptr; |
|
|
|
// Operator CudnnUniformReal does not need workspace memory. |
|
|
|
if (random_op_type_ != RANDOM_OP_CUDNN_UNIFORM_REAL) { |
|
|
|
// Operator StandardNormal and CudnnUniformReal use curand |
|
|
|
// so they do not need workspace memory. |
|
|
|
if (random_op_type_ >= RANDOM_OP_UNIFORM_INT && random_op_type_ <= RANDOM_OP_UNIFORM_REAL) { |
|
|
|
void *workspace_addr = GetDeviceAddress<void *>(workspace, 0); |
|
|
|
devStates = reinterpret_cast<curandState *>(workspace_addr); |
|
|
|
} |
|
|
|
@@ -79,8 +80,30 @@ class RandomOpGpuKernel : public GpuKernel { |
|
|
|
|
|
|
|
switch (random_op_type_) { |
|
|
|
case RANDOM_OP_NORMAL: { |
|
|
|
StandardNormal(seed_, seed2_, devStates, output_addr, outputs[0]->size / sizeof(T), |
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr)); |
|
|
|
float *mask_f = GetDeviceAddress<float>(outputs, 0); |
|
|
|
if (!states_init_) { |
|
|
|
int RNG_seed = 0; |
|
|
|
std::random_device rd; |
|
|
|
if (seed2_ != 0) { |
|
|
|
RNG_seed = seed2_; |
|
|
|
} else if (seed_ != 0) { |
|
|
|
RNG_seed = seed_; |
|
|
|
} else { |
|
|
|
RNG_seed = static_cast<int>(rd()); |
|
|
|
} |
|
|
|
CHECK_CURAND_RET_WITH_EXCEPT(curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_PHILOX4_32_10), |
|
|
|
"Failed to create generator"); |
|
|
|
CHECK_CURAND_RET_WITH_EXCEPT(curandSetPseudoRandomGeneratorSeed(mask_generator_, RNG_seed), |
|
|
|
"Failed to SetPseudoRandomGeneratorSeed"); |
|
|
|
MS_EXCEPTION_IF_NULL(mask_generator_); |
|
|
|
states_init_ = true; |
|
|
|
} |
|
|
|
CHECK_CURAND_RET_WITH_EXCEPT(curandSetStream(mask_generator_, reinterpret_cast<cudaStream_t>(stream_ptr)), |
|
|
|
"Failed to set stream for generator"); |
|
|
|
// curandGen only supports float or double for mask. |
|
|
|
CHECK_CURAND_RET_WITH_EXCEPT( |
|
|
|
curandGenerateNormal(mask_generator_, mask_f, outputs[0]->size / sizeof(float), 0.0, 1.0), |
|
|
|
"Failed to generate uniform"); |
|
|
|
break; |
|
|
|
} |
|
|
|
case RANDOM_OP_UNIFORM_INT: { |
|
|
|
@@ -103,7 +126,7 @@ class RandomOpGpuKernel : public GpuKernel { |
|
|
|
case RANDOM_OP_CUDNN_UNIFORM_REAL: { |
|
|
|
float *mask_f = GetDeviceAddress<float>(outputs, 0); |
|
|
|
if (!states_init_) { |
|
|
|
CHECK_CURAND_RET_WITH_EXCEPT(curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_DEFAULT), |
|
|
|
CHECK_CURAND_RET_WITH_EXCEPT(curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_PHILOX4_32_10), |
|
|
|
"Failed to create generator"); |
|
|
|
CHECK_CURAND_RET_WITH_EXCEPT(curandSetPseudoRandomGeneratorSeed(mask_generator_, seed_), |
|
|
|
"Failed to SetPseudoRandomGeneratorSeed"); |
|
|
|
|