| @@ -102,13 +102,15 @@ __global__ void MultinomialKernel(int seed, T *input, int num_sample, curandStat | |||
| } | |||
| template <typename T> | |||
| void Multinomial(int seed, T *input, int num_sample, curandState *globalState, int *output, size_t distributions, | |||
| size_t categories, cudaStream_t cuda_stream) { | |||
| void Multinomial(int seed, int seed2, T *input, int num_sample, curandState *globalState, int *output, | |||
| size_t distributions, size_t categories, cudaStream_t cuda_stream) { | |||
| int RNG_seed = 0; | |||
| if (seed != 0) { | |||
| std::random_device rd; | |||
| if (seed2 != 0) { | |||
| RNG_seed = seed2; | |||
| } else if (seed != 0) { | |||
| RNG_seed = seed; | |||
| } else { | |||
| std::random_device rd; | |||
| RNG_seed = static_cast<int>(rd()); | |||
| } | |||
| int count = distributions * num_sample; | |||
| @@ -117,8 +119,8 @@ void Multinomial(int seed, T *input, int num_sample, curandState *globalState, i | |||
| return; | |||
| } | |||
| template void Multinomial<float>(int seed, float *input, int num_sample, curandState *globalState, int *output, | |||
| size_t distributions, size_t categories, cudaStream_t cuda_stream); | |||
| template void Multinomial<float>(int seed, int seed2, float *input, int num_sample, curandState *globalState, | |||
| int *output, size_t distributions, size_t categories, cudaStream_t cuda_stream); | |||
| template void CheckNonNeg<float>(const size_t size, const float *input, float *output, cudaStream_t cuda_stream); | |||
| template void CheckZero<float>(const size_t distributions, const size_t categories, const float *input, float *output, | |||
| cudaStream_t cuda_stream); | |||
| @@ -20,8 +20,8 @@ | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| void Multinomial(int seed, T *input, int num_sample, curandState *globalState, int *output, size_t distributions, | |||
| size_t categories, cudaStream_t cuda_stream); | |||
| void Multinomial(int seed, int seed2, T *input, int num_sample, curandState *globalState, int *output, | |||
| size_t distributions, size_t categories, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t stream); | |||
| template <typename T> | |||
| @@ -32,7 +32,13 @@ namespace kernel { | |||
| template <typename T> | |||
| class MultinomialGpuKernel : public GpuKernel { | |||
| public: | |||
| MultinomialGpuKernel() : input_size_0_(0), output_size_(0), distributions_(0), workspace_size_(sizeof(curandState)) {} | |||
| MultinomialGpuKernel() | |||
| : input_size_0_(0), | |||
| output_size_(0), | |||
| distributions_(0), | |||
| workspace_size_(sizeof(curandState)), | |||
| seed_(0), | |||
| seed2_(0) {} | |||
| ~MultinomialGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| @@ -52,7 +58,7 @@ class MultinomialGpuKernel : public GpuKernel { | |||
| IntToSize(categories), 1, false, false, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| NormInput(cum_sum_input, IntToSize(distributions_), IntToSize(categories), | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| Multinomial(seed_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_), | |||
| Multinomial(seed_, seed2_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_), | |||
| IntToSize(categories), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| @@ -87,6 +93,7 @@ class MultinomialGpuKernel : public GpuKernel { | |||
| } | |||
| workspace_size_ = output_size_; | |||
| seed_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed")); | |||
| seed2_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2")); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| @@ -106,6 +113,7 @@ class MultinomialGpuKernel : public GpuKernel { | |||
| size_t distributions_; | |||
| size_t workspace_size_; | |||
| int seed_; | |||
| int seed2_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| @@ -33,7 +33,7 @@ class Categorical(Distribution): | |||
| Args: | |||
| probs (Tensor, list, numpy.ndarray, Parameter): Event probabilities. | |||
| seed (int): The global seed is used in sampling. Global seed is used if it is None. Default: None. | |||
| dtype (mindspore.dtype): The type of the distribution. Default: mstype.int32. | |||
| dtype (mindspore.dtype): The type of the event samples. Default: mstype.int32. | |||
| name (str): The name of the distribution. Default: Categorical. | |||
| Note: | |||
| @@ -109,6 +109,11 @@ def is_same_type(inst, type_): | |||
| """ | |||
| return inst == type_ | |||
| @constexpr | |||
| def check_valid_dim(dim, name): | |||
| if dim not in (1, 2): | |||
| raise ValueError( | |||
| f"For {name}, inputs dim must be 1d or 2d") | |||
| @constexpr | |||
| def check_valid_type(data_type, value_type, name): | |||
| @@ -205,7 +205,7 @@ def poisson(shape, mean, seed=None): | |||
| value = random_poisson(shape, mean) | |||
| return value | |||
| def multinomial(inputs, num_sample, replacement=True, seed=0): | |||
| def multinomial(inputs, num_sample, replacement=True, seed=None): | |||
| r""" | |||
| Returns a tensor sampled from the multinomial probability distribution located in the corresponding | |||
| row of the input tensor. | |||
| @@ -232,18 +232,18 @@ def multinomial(inputs, num_sample, replacement=True, seed=0): | |||
| """ | |||
| shape = P.Shape() | |||
| reshape = P.Reshape() | |||
| if inputs.dim() != 1 and inputs.dim() != 2: | |||
| const_utils.raise_value_error("inputs dim must be 1d or 2d") | |||
| const_utils.check_valid_dim(len(shape(inputs)), "multinomial") | |||
| seed1, seed2 = _get_seed(seed, "multinomial") | |||
| if not replacement: | |||
| if shape(inputs)[-1] < num_sample: | |||
| const_utils.raise_value_error("num_sample must be less than shape(input)[-1] without replacement") | |||
| n_dist = 1 | |||
| if len(shape(inputs)) > 1: | |||
| n_dist = shape(inputs)[-2] | |||
| random_uniform = P.UniformReal(seed=seed)((n_dist * shape(inputs)[-1],)) | |||
| random_uniform = P.UniformReal(seed1, seed2)((n_dist * shape(inputs)[-1],)) | |||
| if n_dist != 1: | |||
| random_uniform = reshape(random_uniform, (n_dist, shape(inputs)[-1])) | |||
| vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6) | |||
| _, indices = P.TopK()(vals, num_sample) | |||
| return indices | |||
| return P.Multinomial(seed=seed)(inputs, num_sample) | |||
| return P.Multinomial(seed1, seed2)(inputs, num_sample) | |||
| @@ -433,8 +433,8 @@ class Multinomial(PrimitiveWithInfer): | |||
| The rows of input do not need to sum to one (in which case we use the values as weights), | |||
| but must be non-negative, finite and have a non-zero sum. | |||
| Args: | |||
| seed (int): Seed data is used as entropy source for Random number engines to generate pseudo-random numbers. | |||
| Must be non-negative. Default: 0. | |||
| seed (int): Random seed, must be non-negative. Default: 0. | |||
| seed2 (int): Random seed2, must be non-negative. Default: 0. | |||
| Inputs: | |||
| - **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2 | |||
| dimensions. | |||
| @@ -450,10 +450,10 @@ class Multinomial(PrimitiveWithInfer): | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, seed=0): | |||
| def __init__(self, seed=0, seed2=0): | |||
| """init""" | |||
| Validator.check_value_type("seed", seed, [int], self.name) | |||
| Validator.check_non_negative_int(seed, "seed", self.name) | |||
| Validator.check_non_negative_int(seed2, "seed2", self.name) | |||
| self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output']) | |||
| def __infer__(self, inputs, num_samples): | |||
| @@ -17,9 +17,20 @@ import numpy as np | |||
| import pytest | |||
| from mindspore.ops import composite as C | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| context.set_context(device_target='GPU') | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| class Net(nn.Cell): | |||
| def __init__(self, sample, replacement, seed=0): | |||
| super(Net, self).__init__() | |||
| self.sample = sample | |||
| self.replacement = replacement | |||
| self.seed = seed | |||
| def construct(self, x): | |||
| return C.multinomial(x, self.sample, self.replacement, self.seed) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @@ -27,9 +38,12 @@ context.set_context(device_target='GPU') | |||
| def test_multinomial(): | |||
| x0 = Tensor(np.array([0.9, 0.2]).astype(np.float32)) | |||
| x1 = Tensor(np.array([[0.9, 0.2], [0.9, 0.2]]).astype(np.float32)) | |||
| out0 = C.multinomial(x0, 1, True) | |||
| out1 = C.multinomial(x0, 2, True) | |||
| out2 = C.multinomial(x1, 6, True) | |||
| net0 = Net(1, True, 20) | |||
| net1 = Net(2, True, 20) | |||
| net2 = Net(6, True, 20) | |||
| out0 = net0(x0) | |||
| out1 = net1(x0) | |||
| out2 = net2(x1) | |||
| assert out0.asnumpy().shape == (1,) | |||
| assert out1.asnumpy().shape == (2,) | |||
| assert out2.asnumpy().shape == (2, 6) | |||