| @@ -18,6 +18,8 @@ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_ | #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_ | ||||
| #include <vector> | #include <vector> | ||||
| #include <chrono> | |||||
| #include <random> | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | #include "backend/kernel_compiler/gpu/gpu_kernel.h" | ||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | ||||
| #include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh" | #include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh" | ||||
| @@ -27,7 +29,8 @@ namespace kernel { | |||||
| template <typename T, typename S> | template <typename T, typename S> | ||||
| class RandomChoiceWithMaskGpuKernel : public GpuKernel { | class RandomChoiceWithMaskGpuKernel : public GpuKernel { | ||||
| public: | public: | ||||
| RandomChoiceWithMaskGpuKernel() : input_shape_size_(0), seedc_(0), input_size_(1), count_(0), ceil_power2_(0) {} | |||||
| RandomChoiceWithMaskGpuKernel() | |||||
| : input_shape_size_(0), seed_(0), seed2_(0), input_size_(1), count_(0), ceil_power2_(0) {} | |||||
| ~RandomChoiceWithMaskGpuKernel() override = default; | ~RandomChoiceWithMaskGpuKernel() override = default; | ||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | ||||
| @@ -39,6 +42,14 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel { | |||||
| T *input = GetDeviceAddress<T>(inputs, 0); | T *input = GetDeviceAddress<T>(inputs, 0); | ||||
| S *output_index = GetDeviceAddress<S>(outputs, 0); | S *output_index = GetDeviceAddress<S>(outputs, 0); | ||||
| T *output_mask = GetDeviceAddress<T>(outputs, 1); | T *output_mask = GetDeviceAddress<T>(outputs, 1); | ||||
| int seedc = 0; | |||||
| if (seed2_ != 0) { | |||||
| seedc = seed2_; | |||||
| } else if (seed_ != 0) { | |||||
| seedc = seed_; | |||||
| } else { | |||||
| seedc = generator_(); | |||||
| } | |||||
| if (count_ > kSmallK || input_shape_size_ > 1) { | if (count_ > kSmallK || input_shape_size_ > 1) { | ||||
| S *index_buff = GetDeviceAddress<S>(workspaces, 0); | S *index_buff = GetDeviceAddress<S>(workspaces, 0); | ||||
| S *mask_buff = GetDeviceAddress<S>(workspaces, 1); | S *mask_buff = GetDeviceAddress<S>(workspaces, 1); | ||||
| @@ -48,17 +59,18 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel { | |||||
| void *States = GetDeviceAddress<void *>(workspaces, 5); | void *States = GetDeviceAddress<void *>(workspaces, 5); | ||||
| curandState *devStates = reinterpret_cast<curandState *>(States); | curandState *devStates = reinterpret_cast<curandState *>(States); | ||||
| CalRandomChoiceWithMask(input_size_, input_shape_size_, input_shape_5D_[0], input_shape_5D_[1], | CalRandomChoiceWithMask(input_size_, input_shape_size_, input_shape_5D_[0], input_shape_5D_[1], | ||||
| input_shape_5D_[2], input_shape_5D_[3], input_shape_5D_[4], seedc_, count_, input, | |||||
| input_shape_5D_[2], input_shape_5D_[3], input_shape_5D_[4], seedc, count_, input, | |||||
| output_index, output_mask, index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff, | output_index, output_mask, index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff, | ||||
| devStates, reinterpret_cast<cudaStream_t>(stream_ptr)); | devStates, reinterpret_cast<cudaStream_t>(stream_ptr)); | ||||
| } else { | } else { | ||||
| CalRandomChoiceWithMaskSmall<float, S, T>(input_size_, seedc_, count_, input, output_index, output_mask, | |||||
| CalRandomChoiceWithMaskSmall<float, S, T>(input_size_, seedc, count_, input, output_index, output_mask, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | reinterpret_cast<cudaStream_t>(stream_ptr)); | ||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| uint32_t time_interval = std::chrono::system_clock::now().time_since_epoch().count(); | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | ||||
| if (input_num != 1) { | if (input_num != 1) { | ||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but RandomChoiceWithMask needs 1 input."; | MS_LOG(ERROR) << "Input number is " << input_num << ", but RandomChoiceWithMask needs 1 input."; | ||||
| @@ -84,15 +96,10 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel { | |||||
| while (input_shape_5D_.size() != MAX_DIMENSION) { | while (input_shape_5D_.size() != MAX_DIMENSION) { | ||||
| input_shape_5D_.insert(input_shape_5D_.begin(), 1); | input_shape_5D_.insert(input_shape_5D_.begin(), 1); | ||||
| } | } | ||||
| // init seedc_ | |||||
| int seed = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed")); | |||||
| int seed2 = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed2")); | |||||
| if (seed2 != 0) | |||||
| seedc_ = seed2; | |||||
| else if (seed != 0) | |||||
| seedc_ = seed; | |||||
| else | |||||
| seedc_ = time(NULL); | |||||
| // init seedc | |||||
| seed_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed")); | |||||
| seed2_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed2")); | |||||
| generator_.seed(time_interval); | |||||
| // init memory | // init memory | ||||
| for (size_t i = 0; i < input_shape.size(); i++) { | for (size_t i = 0; i < input_shape.size(); i++) { | ||||
| input_size_ *= input_shape[i]; | input_size_ *= input_shape[i]; | ||||
| @@ -125,10 +132,12 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel { | |||||
| private: | private: | ||||
| const int kSmallK = 2048; | const int kSmallK = 2048; | ||||
| int input_shape_size_; | int input_shape_size_; | ||||
| int seedc_; | |||||
| int seed_; | |||||
| int seed2_; | |||||
| int input_size_; | int input_size_; | ||||
| int count_; | int count_; | ||||
| int ceil_power2_; | int ceil_power2_; | ||||
| std::mt19937 generator_; | |||||
| std::vector<int> input_shape_5D_; | std::vector<int> input_shape_5D_; | ||||
| std::vector<size_t> input_size_list_; | std::vector<size_t> input_size_list_; | ||||
| std::vector<size_t> output_size_list_; | std::vector<size_t> output_size_list_; | ||||