| @@ -102,13 +102,15 @@ __global__ void MultinomialKernel(int seed, T *input, int num_sample, curandStat | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| void Multinomial(int seed, T *input, int num_sample, curandState *globalState, int *output, size_t distributions, | |||||
| size_t categories, cudaStream_t cuda_stream) { | |||||
| void Multinomial(int seed, int seed2, T *input, int num_sample, curandState *globalState, int *output, | |||||
| size_t distributions, size_t categories, cudaStream_t cuda_stream) { | |||||
| int RNG_seed = 0; | int RNG_seed = 0; | ||||
| if (seed != 0) { | |||||
| std::random_device rd; | |||||
| if (seed2 != 0) { | |||||
| RNG_seed = seed2; | |||||
| } else if (seed != 0) { | |||||
| RNG_seed = seed; | RNG_seed = seed; | ||||
| } else { | } else { | ||||
| std::random_device rd; | |||||
| RNG_seed = static_cast<int>(rd()); | RNG_seed = static_cast<int>(rd()); | ||||
| } | } | ||||
| int count = distributions * num_sample; | int count = distributions * num_sample; | ||||
| @@ -117,8 +119,8 @@ void Multinomial(int seed, T *input, int num_sample, curandState *globalState, i | |||||
| return; | return; | ||||
| } | } | ||||
| template void Multinomial<float>(int seed, float *input, int num_sample, curandState *globalState, int *output, | |||||
| size_t distributions, size_t categories, cudaStream_t cuda_stream); | |||||
| template void Multinomial<float>(int seed, int seed2, float *input, int num_sample, curandState *globalState, | |||||
| int *output, size_t distributions, size_t categories, cudaStream_t cuda_stream); | |||||
| template void CheckNonNeg<float>(const size_t size, const float *input, float *output, cudaStream_t cuda_stream); | template void CheckNonNeg<float>(const size_t size, const float *input, float *output, cudaStream_t cuda_stream); | ||||
| template void CheckZero<float>(const size_t distributions, const size_t categories, const float *input, float *output, | template void CheckZero<float>(const size_t distributions, const size_t categories, const float *input, float *output, | ||||
| cudaStream_t cuda_stream); | cudaStream_t cuda_stream); | ||||
| @@ -20,8 +20,8 @@ | |||||
| #include "runtime/device/gpu/cuda_common.h" | #include "runtime/device/gpu/cuda_common.h" | ||||
| template <typename T> | template <typename T> | ||||
| void Multinomial(int seed, T *input, int num_sample, curandState *globalState, int *output, size_t distributions, | |||||
| size_t categories, cudaStream_t cuda_stream); | |||||
| void Multinomial(int seed, int seed2, T *input, int num_sample, curandState *globalState, int *output, | |||||
| size_t distributions, size_t categories, cudaStream_t cuda_stream); | |||||
| template <typename T> | template <typename T> | ||||
| void CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t stream); | void CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t stream); | ||||
| template <typename T> | template <typename T> | ||||
| @@ -32,7 +32,13 @@ namespace kernel { | |||||
| template <typename T> | template <typename T> | ||||
| class MultinomialGpuKernel : public GpuKernel { | class MultinomialGpuKernel : public GpuKernel { | ||||
| public: | public: | ||||
| MultinomialGpuKernel() : input_size_0_(0), output_size_(0), distributions_(0), workspace_size_(sizeof(curandState)) {} | |||||
| MultinomialGpuKernel() | |||||
| : input_size_0_(0), | |||||
| output_size_(0), | |||||
| distributions_(0), | |||||
| workspace_size_(sizeof(curandState)), | |||||
| seed_(0), | |||||
| seed2_(0) {} | |||||
| ~MultinomialGpuKernel() override = default; | ~MultinomialGpuKernel() override = default; | ||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | ||||
| @@ -52,7 +58,7 @@ class MultinomialGpuKernel : public GpuKernel { | |||||
| IntToSize(categories), 1, false, false, reinterpret_cast<cudaStream_t>(stream_ptr)); | IntToSize(categories), 1, false, false, reinterpret_cast<cudaStream_t>(stream_ptr)); | ||||
| NormInput(cum_sum_input, IntToSize(distributions_), IntToSize(categories), | NormInput(cum_sum_input, IntToSize(distributions_), IntToSize(categories), | ||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | reinterpret_cast<cudaStream_t>(stream_ptr)); | ||||
| Multinomial(seed_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_), | |||||
| Multinomial(seed_, seed2_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_), | |||||
| IntToSize(categories), reinterpret_cast<cudaStream_t>(stream_ptr)); | IntToSize(categories), reinterpret_cast<cudaStream_t>(stream_ptr)); | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -87,6 +93,7 @@ class MultinomialGpuKernel : public GpuKernel { | |||||
| } | } | ||||
| workspace_size_ = output_size_; | workspace_size_ = output_size_; | ||||
| seed_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed")); | seed_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed")); | ||||
| seed2_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2")); | |||||
| InitSizeLists(); | InitSizeLists(); | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -106,6 +113,7 @@ class MultinomialGpuKernel : public GpuKernel { | |||||
| size_t distributions_; | size_t distributions_; | ||||
| size_t workspace_size_; | size_t workspace_size_; | ||||
| int seed_; | int seed_; | ||||
| int seed2_; | |||||
| std::vector<size_t> input_size_list_; | std::vector<size_t> input_size_list_; | ||||
| std::vector<size_t> output_size_list_; | std::vector<size_t> output_size_list_; | ||||
| std::vector<size_t> workspace_size_list_; | std::vector<size_t> workspace_size_list_; | ||||
| @@ -33,7 +33,7 @@ class Categorical(Distribution): | |||||
| Args: | Args: | ||||
| probs (Tensor, list, numpy.ndarray, Parameter): Event probabilities. | probs (Tensor, list, numpy.ndarray, Parameter): Event probabilities. | ||||
| seed (int): The seed used in sampling. The global seed is used if it is None. Default: None. | seed (int): The seed used in sampling. The global seed is used if it is None. Default: None. | ||||
| dtype (mindspore.dtype): The type of the distribution. Default: mstype.int32. | |||||
| dtype (mindspore.dtype): The type of the event samples. Default: mstype.int32. | |||||
| name (str): The name of the distribution. Default: Categorical. | name (str): The name of the distribution. Default: Categorical. | ||||
| Note: | Note: | ||||
| @@ -109,6 +109,11 @@ def is_same_type(inst, type_): | |||||
| """ | """ | ||||
| return inst == type_ | return inst == type_ | ||||
@constexpr
def check_valid_dim(dim, name):
    """
    Check that the number of input dimensions is 1 or 2.

    Args:
        dim (int): Number of dimensions of the input being validated.
        name (str): Name of the calling operation, used as the error message prefix.

    Raises:
        ValueError: If `dim` is neither 1 nor 2.
    """
    if dim not in (1, 2):
        # Report the offending value so the caller can see what was actually passed.
        raise ValueError(
            f"For {name}, inputs dim must be 1d or 2d, but got {dim}d.")
| @constexpr | @constexpr | ||||
| def check_valid_type(data_type, value_type, name): | def check_valid_type(data_type, value_type, name): | ||||
| @@ -205,7 +205,7 @@ def poisson(shape, mean, seed=None): | |||||
| value = random_poisson(shape, mean) | value = random_poisson(shape, mean) | ||||
| return value | return value | ||||
| def multinomial(inputs, num_sample, replacement=True, seed=0): | |||||
| def multinomial(inputs, num_sample, replacement=True, seed=None): | |||||
| r""" | r""" | ||||
| Returns a tensor sampled from the multinomial probability distribution located in the corresponding | Returns a tensor sampled from the multinomial probability distribution located in the corresponding | ||||
| row of the input tensor. | row of the input tensor. | ||||
| @@ -232,18 +232,18 @@ def multinomial(inputs, num_sample, replacement=True, seed=0): | |||||
| """ | """ | ||||
| shape = P.Shape() | shape = P.Shape() | ||||
| reshape = P.Reshape() | reshape = P.Reshape() | ||||
| if inputs.dim() != 1 and inputs.dim() != 2: | |||||
| const_utils.raise_value_error("inputs dim must be 1d or 2d") | |||||
| const_utils.check_valid_dim(len(shape(inputs)), "multinomial") | |||||
| seed1, seed2 = _get_seed(seed, "multinomial") | |||||
| if not replacement: | if not replacement: | ||||
| if shape(inputs)[-1] < num_sample: | if shape(inputs)[-1] < num_sample: | ||||
| const_utils.raise_value_error("num_sample must be less than shape(input)[-1] without replacement") | const_utils.raise_value_error("num_sample must be less than shape(input)[-1] without replacement") | ||||
| n_dist = 1 | n_dist = 1 | ||||
| if len(shape(inputs)) > 1: | if len(shape(inputs)) > 1: | ||||
| n_dist = shape(inputs)[-2] | n_dist = shape(inputs)[-2] | ||||
| random_uniform = P.UniformReal(seed=seed)((n_dist * shape(inputs)[-1],)) | |||||
| random_uniform = P.UniformReal(seed1, seed2)((n_dist * shape(inputs)[-1],)) | |||||
| if n_dist != 1: | if n_dist != 1: | ||||
| random_uniform = reshape(random_uniform, (n_dist, shape(inputs)[-1])) | random_uniform = reshape(random_uniform, (n_dist, shape(inputs)[-1])) | ||||
| vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6) | vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6) | ||||
| _, indices = P.TopK()(vals, num_sample) | _, indices = P.TopK()(vals, num_sample) | ||||
| return indices | return indices | ||||
| return P.Multinomial(seed=seed)(inputs, num_sample) | |||||
| return P.Multinomial(seed1, seed2)(inputs, num_sample) | |||||
| @@ -433,8 +433,8 @@ class Multinomial(PrimitiveWithInfer): | |||||
| The rows of input do not need to sum to one (in which case we use the values as weights), | The rows of input do not need to sum to one (in which case we use the values as weights), | ||||
| but must be non-negative, finite and have a non-zero sum. | but must be non-negative, finite and have a non-zero sum. | ||||
| Args: | Args: | ||||
| seed (int): Seed data is used as entropy source for Random number engines to generate pseudo-random numbers. | |||||
| Must be non-negative. Default: 0. | |||||
| seed (int): Random seed, must be non-negative. Default: 0. | |||||
| seed2 (int): Random seed2, must be non-negative. Default: 0. | |||||
| Inputs: | Inputs: | ||||
| - **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2 | - **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2 | ||||
| dimensions. | dimensions. | ||||
| @@ -450,10 +450,10 @@ class Multinomial(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, seed=0): | |||||
| def __init__(self, seed=0, seed2=0): | |||||
| """init""" | """init""" | ||||
| Validator.check_value_type("seed", seed, [int], self.name) | |||||
| Validator.check_non_negative_int(seed, "seed", self.name) | Validator.check_non_negative_int(seed, "seed", self.name) | ||||
| Validator.check_non_negative_int(seed2, "seed2", self.name) | |||||
| self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output']) | self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output']) | ||||
| def __infer__(self, inputs, num_samples): | def __infer__(self, inputs, num_samples): | ||||
| @@ -17,9 +17,20 @@ import numpy as np | |||||
| import pytest | import pytest | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| import mindspore.context as context | import mindspore.context as context | ||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| context.set_context(device_target='GPU') | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||||
class Net(nn.Cell):
    """Cell wrapper that forwards its input to C.multinomial with fixed sampling arguments."""

    def __init__(self, sample, replacement, seed=0):
        super().__init__()
        # Sampling configuration captured once and reused on every call.
        self.seed = seed
        self.replacement = replacement
        self.sample = sample

    def construct(self, x):
        # Delegate to the composite op using the stored configuration.
        result = C.multinomial(x, self.sample, self.replacement, self.seed)
        return result
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_gpu_training | @pytest.mark.platform_x86_gpu_training | ||||
| @@ -27,9 +38,12 @@ context.set_context(device_target='GPU') | |||||
| def test_multinomial(): | def test_multinomial(): | ||||
| x0 = Tensor(np.array([0.9, 0.2]).astype(np.float32)) | x0 = Tensor(np.array([0.9, 0.2]).astype(np.float32)) | ||||
| x1 = Tensor(np.array([[0.9, 0.2], [0.9, 0.2]]).astype(np.float32)) | x1 = Tensor(np.array([[0.9, 0.2], [0.9, 0.2]]).astype(np.float32)) | ||||
| out0 = C.multinomial(x0, 1, True) | |||||
| out1 = C.multinomial(x0, 2, True) | |||||
| out2 = C.multinomial(x1, 6, True) | |||||
| net0 = Net(1, True, 20) | |||||
| net1 = Net(2, True, 20) | |||||
| net2 = Net(6, True, 20) | |||||
| out0 = net0(x0) | |||||
| out1 = net1(x0) | |||||
| out2 = net2(x1) | |||||
| assert out0.asnumpy().shape == (1,) | assert out0.asnumpy().shape == (1,) | ||||
| assert out1.asnumpy().shape == (2,) | assert out1.asnumpy().shape == (2,) | ||||
| assert out2.asnumpy().shape == (2, 6) | assert out2.asnumpy().shape == (2, 6) | ||||