|
|
|
@@ -18,6 +18,8 @@ |
|
|
|
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_ |
|
|
|
|
|
|
|
#include <vector> |
|
|
|
#include <chrono> |
|
|
|
#include <random> |
|
|
|
#include "backend/kernel_compiler/gpu/gpu_kernel.h" |
|
|
|
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" |
|
|
|
#include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh" |
|
|
|
@@ -27,7 +29,8 @@ namespace kernel { |
|
|
|
template <typename T, typename S> |
|
|
|
class RandomChoiceWithMaskGpuKernel : public GpuKernel { |
|
|
|
public: |
|
|
|
// NOTE(review): earlier one-line form of the default constructor, initialising the single seedc_ member. |
// The hunk line counts suggest it is replaced by the seed_/seed2_ form that follows — confirm before applying. |
RandomChoiceWithMaskGpuKernel() : input_shape_size_(0), seedc_(0), input_size_(1), count_(0), ceil_power2_(0) {} |
|
|
|
// Default constructor: zeroes input_shape_size_, seed_, seed2_, count_ and ceil_power2_, and starts |
// input_size_ at 1 because Init multiplies the input dimensions into it. |
RandomChoiceWithMaskGpuKernel() |
|
|
|
: input_shape_size_(0), seed_(0), seed2_(0), input_size_(1), count_(0), ceil_power2_(0) {} |
|
|
|
// Compiler-generated destructor, overriding the GpuKernel base destructor; every member releases itself. |
~RandomChoiceWithMaskGpuKernel() override = default; |
|
|
|
|
|
|
|
// Returns input_size_list_ by const reference (no copy); the list is filled outside the visible code. |
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } |
|
|
|
@@ -39,6 +42,14 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel { |
|
|
|
T *input = GetDeviceAddress<T>(inputs, 0); |
|
|
|
S *output_index = GetDeviceAddress<S>(outputs, 0); |
|
|
|
T *output_mask = GetDeviceAddress<T>(outputs, 1); |
|
|
|
int seedc = 0; |
|
|
|
if (seed2_ != 0) { |
|
|
|
seedc = seed2_; |
|
|
|
} else if (seed_ != 0) { |
|
|
|
seedc = seed_; |
|
|
|
} else { |
|
|
|
seedc = generator_(); |
|
|
|
} |
|
|
|
if (count_ > kSmallK || input_shape_size_ > 1) { |
|
|
|
S *index_buff = GetDeviceAddress<S>(workspaces, 0); |
|
|
|
S *mask_buff = GetDeviceAddress<S>(workspaces, 1); |
|
|
|
@@ -48,17 +59,18 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel { |
|
|
|
void *States = GetDeviceAddress<void *>(workspaces, 5); |
|
|
|
curandState *devStates = reinterpret_cast<curandState *>(States); |
|
|
|
CalRandomChoiceWithMask(input_size_, input_shape_size_, input_shape_5D_[0], input_shape_5D_[1], |
|
|
|
input_shape_5D_[2], input_shape_5D_[3], input_shape_5D_[4], seedc_, count_, input, |
|
|
|
input_shape_5D_[2], input_shape_5D_[3], input_shape_5D_[4], seedc, count_, input, |
|
|
|
output_index, output_mask, index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff, |
|
|
|
devStates, reinterpret_cast<cudaStream_t>(stream_ptr)); |
|
|
|
} else { |
|
|
|
CalRandomChoiceWithMaskSmall<float, S, T>(input_size_, seedc_, count_, input, output_index, output_mask, |
|
|
|
CalRandomChoiceWithMaskSmall<float, S, T>(input_size_, seedc, count_, input, output_index, output_mask, |
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr)); |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool Init(const CNodePtr &kernel_node) override { |
|
|
|
uint32_t time_interval = std::chrono::system_clock::now().time_since_epoch().count(); |
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); |
|
|
|
if (input_num != 1) { |
|
|
|
MS_LOG(ERROR) << "Input number is " << input_num << ", but RandomChoiceWithMask needs 1 input."; |
|
|
|
@@ -84,15 +96,10 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel { |
|
|
|
while (input_shape_5D_.size() != MAX_DIMENSION) { |
|
|
|
input_shape_5D_.insert(input_shape_5D_.begin(), 1); |
|
|
|
} |
|
|
|
// init seedc_ |
|
|
|
int seed = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed")); |
|
|
|
int seed2 = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed2")); |
|
|
|
if (seed2 != 0) |
|
|
|
seedc_ = seed2; |
|
|
|
else if (seed != 0) |
|
|
|
seedc_ = seed; |
|
|
|
else |
|
|
|
seedc_ = time(NULL); |
|
|
|
// store the "seed"/"seed2" attributes and seed the host-side fallback generator |
|
|
|
seed_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed")); |
|
|
|
seed2_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed2")); |
|
|
|
generator_.seed(time_interval); |
|
|
|
// init memory |
|
|
|
for (size_t i = 0; i < input_shape.size(); i++) { |
|
|
|
input_size_ *= input_shape[i]; |
|
|
|
@@ -125,10 +132,12 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel { |
|
|
|
private: |
|
|
|
// Threshold compared against count_ in the launch path: when count_ exceeds it, or input_shape_size_ > 1, |
// CalRandomChoiceWithMask is called; otherwise CalRandomChoiceWithMaskSmall is called. |
const int kSmallK = 2048; |
|
|
|
// Presumably the rank of the input tensor; only its comparison with 1 and its use as a kernel argument are visible. |
int input_shape_size_; |
|
|
|
// Single combined seed of the earlier revision: chosen once in Init as seed2, else seed, else time(NULL). |
// NOTE(review): appears to be superseded by seed_/seed2_/generator_ below — confirm. |
int seedc_; |
|
|
|
// Value of the node's "seed" attribute, narrowed from int64_t to int in Init. |
int seed_; |
|
|
|
// Value of the node's "seed2" attribute, narrowed to int in Init; when non-zero it wins over seed_. |
int seed2_; |
|
|
|
// Product of the input shape's dimensions, accumulated in Init starting from 1. |
int input_size_; |
|
|
|
// Passed to the Cal* routines and compared with kSmallK; assigned outside the visible code. |
int count_; |
|
|
|
// Zeroed by the constructor; its use is outside the visible code. |
int ceil_power2_; |
|
|
|
// Host-side Mersenne Twister seeded from the system clock in Init; a value is drawn from it |
// as the seed only when both seed_ and seed2_ are 0. |
std::mt19937 generator_; |
|
|
|
// Input shape left-padded with 1s until it holds MAX_DIMENSION entries; elements 0..4 are passed to the kernel. |
std::vector<int> input_shape_5D_; |
|
|
|
// Returned by GetInputSizeList(). |
std::vector<size_t> input_size_list_; |
|
|
|
// Presumably the output-side counterpart of input_size_list_; populated outside the visible code. |
std::vector<size_t> output_size_list_; |
|
|
|
|