Browse Source

update random normal seed logic

tags/v1.6.0
Zichun Ye 4 years ago
parent
commit
3acc7726d2
1 changed files with 11 additions and 16 deletions
  1. +11
    -16
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h

+ 11
- 16
mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h View File

@@ -81,23 +81,18 @@ class RandomOpGpuKernel : public GpuKernel {
switch (random_op_type_) {
case RANDOM_OP_NORMAL: {
float *mask_f = GetDeviceAddress<float>(outputs, 0);
if (!states_init_) {
int RNG_seed = 0;
std::random_device rd;
if (seed2_ != 0) {
RNG_seed = seed2_;
} else if (seed_ != 0) {
RNG_seed = seed_;
} else {
RNG_seed = static_cast<int>(rd());
}
CHECK_CURAND_RET_WITH_EXCEPT(curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_PHILOX4_32_10),
"Failed to create generator");
CHECK_CURAND_RET_WITH_EXCEPT(curandSetPseudoRandomGeneratorSeed(mask_generator_, RNG_seed),
"Failed to SetPseudoRandomGeneratorSeed");
MS_EXCEPTION_IF_NULL(mask_generator_);
states_init_ = true;
std::random_device rd;
int RNG_seed = static_cast<int>(rd());
if (seed2_ != 0) {
RNG_seed = seed2_;
} else if (seed_ != 0) {
RNG_seed = seed_;
}
CHECK_CURAND_RET_WITH_EXCEPT(curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_PHILOX4_32_10),
"Failed to create generator");
CHECK_CURAND_RET_WITH_EXCEPT(curandSetPseudoRandomGeneratorSeed(mask_generator_, RNG_seed),
"Failed to SetPseudoRandomGeneratorSeed");
MS_EXCEPTION_IF_NULL(mask_generator_);
CHECK_CURAND_RET_WITH_EXCEPT(curandSetStream(mask_generator_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"Failed to set stream for generator");
// curandGen only support float or double for mask.


Loading…
Cancel
Save