| @@ -18,6 +18,8 @@ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include <chrono> | |||
| #include <random> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh" | |||
| @@ -27,7 +29,8 @@ namespace kernel { | |||
| template <typename T, typename S> | |||
| class RandomChoiceWithMaskGpuKernel : public GpuKernel { | |||
| public: | |||
| RandomChoiceWithMaskGpuKernel() : input_shape_size_(0), seedc_(0), input_size_(1), count_(0), ceil_power2_(0) {} | |||
| RandomChoiceWithMaskGpuKernel() | |||
| : input_shape_size_(0), seed_(0), seed2_(0), input_size_(1), count_(0), ceil_power2_(0) {} | |||
| ~RandomChoiceWithMaskGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| @@ -39,6 +42,14 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel { | |||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||
| S *output_index = GetDeviceAddress<S>(outputs, 0); | |||
| T *output_mask = GetDeviceAddress<T>(outputs, 1); | |||
| int seedc = 0; | |||
| if (seed2_ != 0) { | |||
| seedc = seed2_; | |||
| } else if (seed_ != 0) { | |||
| seedc = seed_; | |||
| } else { | |||
| seedc = generator_(); | |||
| } | |||
| if (count_ > kSmallK || input_shape_size_ > 1) { | |||
| S *index_buff = GetDeviceAddress<S>(workspaces, 0); | |||
| S *mask_buff = GetDeviceAddress<S>(workspaces, 1); | |||
| @@ -48,17 +59,18 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel { | |||
| void *States = GetDeviceAddress<void *>(workspaces, 5); | |||
| curandState *devStates = reinterpret_cast<curandState *>(States); | |||
| CalRandomChoiceWithMask(input_size_, input_shape_size_, input_shape_5D_[0], input_shape_5D_[1], | |||
| input_shape_5D_[2], input_shape_5D_[3], input_shape_5D_[4], seedc_, count_, input, | |||
| input_shape_5D_[2], input_shape_5D_[3], input_shape_5D_[4], seedc, count_, input, | |||
| output_index, output_mask, index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff, | |||
| devStates, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } else { | |||
| CalRandomChoiceWithMaskSmall<float, S, T>(input_size_, seedc_, count_, input, output_index, output_mask, | |||
| CalRandomChoiceWithMaskSmall<float, S, T>(input_size_, seedc, count_, input, output_index, output_mask, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| uint32_t time_interval = std::chrono::system_clock::now().time_since_epoch().count(); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 1) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but RandomChoiceWithMask needs 1 input."; | |||
| @@ -84,15 +96,10 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel { | |||
| while (input_shape_5D_.size() != MAX_DIMENSION) { | |||
| input_shape_5D_.insert(input_shape_5D_.begin(), 1); | |||
| } | |||
| // init seedc_ | |||
| int seed = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed")); | |||
| int seed2 = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed2")); | |||
| if (seed2 != 0) | |||
| seedc_ = seed2; | |||
| else if (seed != 0) | |||
| seedc_ = seed; | |||
| else | |||
| seedc_ = time(NULL); | |||
| // init seedc | |||
| seed_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed")); | |||
| seed2_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed2")); | |||
| generator_.seed(time_interval); | |||
| // init memory | |||
| for (size_t i = 0; i < input_shape.size(); i++) { | |||
| input_size_ *= input_shape[i]; | |||
| @@ -125,10 +132,12 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel { | |||
| private: | |||
| const int kSmallK = 2048; | |||
| int input_shape_size_; | |||
| int seedc_; | |||
| int seed_; | |||
| int seed2_; | |||
| int input_size_; | |||
| int count_; | |||
| int ceil_power2_; | |||
| std::mt19937 generator_; | |||
| std::vector<int> input_shape_5D_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||