|
|
|
@@ -28,17 +28,22 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace kernel { |
|
|
|
// Random operations this kernel can dispatch to. RANDOM_OP_INVALID_TYPE is the
// value a kernel holds until Init() has resolved the operator name.
enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_UNIFORM_REAL, RANDOM_OP_INVALID_TYPE = 255 };

// Lookup table from operator name (as returned by AnfAlgo::GetCNodeName) to the
// RandomOptype that selects the matching branch in the kernel.
const std::map<std::string, RandomOptype> kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL},
                                                              {"UniformReal", RANDOM_OP_UNIFORM_REAL}};
|
|
|
template <typename T> |
|
|
|
class RandomOpGpuKernel : public GpuKernel { |
|
|
|
public: |
|
|
|
RandomOpGpuKernel() |
|
|
|
: random_op_type_(RANDOM_OP_INVALID_TYPE), |
|
|
|
input_size_0_(0), |
|
|
|
input_size_0_(sizeof(int)), |
|
|
|
input_size_1_(sizeof(T)), |
|
|
|
input_size_2_(sizeof(T)), |
|
|
|
output_size_(sizeof(T)), |
|
|
|
workspace_size_(sizeof(curandState)) {} |
|
|
|
workspace_size_(sizeof(curandState)), |
|
|
|
seed_(0), |
|
|
|
seed2_(0) {} |
|
|
|
// Defaulted destructor; overrides the base class's virtual destructor.
~RandomOpGpuKernel() override = default; |
|
|
|
|
|
|
|
// Returns a read-only reference to input_size_list_, the list of input sizes
// that InitSizeLists() fills in.
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } |
|
|
|
@@ -57,12 +62,21 @@ class RandomOpGpuKernel : public GpuKernel { |
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr)); |
|
|
|
break; |
|
|
|
} |
|
|
|
case RANDOM_OP_UNIFORM_REAL: { |
|
|
|
T *input_addr_1 = GetDeviceAddress<T>(inputs, 1); |
|
|
|
T *input_addr_2 = GetDeviceAddress<T>(inputs, 2); |
|
|
|
UniformReal(seed_, devStates, input_addr_1, inputs[1]->size / sizeof(T), input_addr_2, |
|
|
|
inputs[2]->size / sizeof(T), output_addr, outputs[0]->size / sizeof(T), |
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr)); |
|
|
|
break; |
|
|
|
} |
|
|
|
default: { |
|
|
|
MS_LOG(EXCEPTION) << "Random operation " << random_op_type_ << " is not supported."; |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool Init(const CNodePtr &kernel_node) override { |
|
|
|
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); |
|
|
|
auto iter = kRandomOpTypeMap.find(kernel_name); |
|
|
|
@@ -72,10 +86,14 @@ class RandomOpGpuKernel : public GpuKernel { |
|
|
|
random_op_type_ = iter->second; |
|
|
|
} |
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); |
|
|
|
if (input_num != 1) { |
|
|
|
if (random_op_type_ == RANDOM_OP_NORMAL && input_num != 1) { |
|
|
|
MS_LOG(ERROR) << "Input number is " << input_num << ", but random op needs 1 input."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (random_op_type_ == RANDOM_OP_UNIFORM_REAL && input_num != 3) { |
|
|
|
MS_LOG(ERROR) << "Input number is " << input_num << ", but random op needs 3 inputs."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); |
|
|
|
if (output_num != 1) { |
|
|
|
MS_LOG(ERROR) << "Output number is " << output_num << ", but random op needs 1 output."; |
|
|
|
@@ -86,13 +104,25 @@ class RandomOpGpuKernel : public GpuKernel { |
|
|
|
input_size_0_ += input_shape_0[i]; |
|
|
|
} |
|
|
|
input_size_0_ *= sizeof(int); |
|
|
|
if (random_op_type_ == RANDOM_OP_UNIFORM_REAL) { |
|
|
|
auto input_shape_1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); |
|
|
|
for (size_t i = 0; i < input_shape_1.size(); i++) { |
|
|
|
input_size_1_ *= input_shape_1[i]; |
|
|
|
} |
|
|
|
auto input_shape_2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); |
|
|
|
for (size_t i = 0; i < input_shape_2.size(); i++) { |
|
|
|
input_size_2_ *= input_shape_2[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); |
|
|
|
for (size_t i = 0; i < output_shape.size(); i++) { |
|
|
|
output_size_ *= output_shape[i]; |
|
|
|
workspace_size_ *= output_shape[i]; |
|
|
|
} |
|
|
|
seed_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed")); |
|
|
|
seed2_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2")); |
|
|
|
if (random_op_type_ == RANDOM_OP_NORMAL) { |
|
|
|
seed2_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2")); |
|
|
|
} |
|
|
|
InitSizeLists(); |
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -100,6 +130,10 @@ class RandomOpGpuKernel : public GpuKernel { |
|
|
|
protected: |
|
|
|
void InitSizeLists() override { |
|
|
|
input_size_list_.push_back(input_size_0_); |
|
|
|
if (random_op_type_ == RANDOM_OP_UNIFORM_REAL) { |
|
|
|
input_size_list_.push_back(input_size_1_); |
|
|
|
input_size_list_.push_back(input_size_2_); |
|
|
|
} |
|
|
|
output_size_list_.push_back(output_size_); |
|
|
|
workspace_size_list_.push_back(workspace_size_); |
|
|
|
} |
|
|
|
@@ -107,6 +141,8 @@ class RandomOpGpuKernel : public GpuKernel { |
|
|
|
private: |
|
|
|
RandomOptype random_op_type_; |
|
|
|
size_t input_size_0_; |
|
|
|
size_t input_size_1_; |
|
|
|
size_t input_size_2_; |
|
|
|
size_t output_size_; |
|
|
|
size_t workspace_size_; |
|
|
|
int seed_; |
|
|
|
|