From 4d92c5b39efcbc7e61353e10add86c587ae82667 Mon Sep 17 00:00:00 2001
From: baihuawei
Date: Thu, 6 Aug 2020 16:49:27 +0800
Subject: [PATCH] add multinomial

---
 .../probability/distribution/_utils/utils.py | 55 +++++++++++++++++++
 mindspore/ops/composite/__init__.py          |  3 +-
 mindspore/ops/composite/random_ops.py        | 51 +++++++++++++++++
 mindspore/ops/operations/__init__.py         |  3 +-
 mindspore/ops/operations/random_ops.py       | 52 ++++++++++++++++++
 5 files changed, 162 insertions(+), 2 deletions(-)

diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py
index 3f7a92a31d..aeccfc2b8f 100644
--- a/mindspore/nn/probability/distribution/_utils/utils.py
+++ b/mindspore/nn/probability/distribution/_utils/utils.py
@@ -19,6 +19,9 @@ from mindspore.ops import _utils as utils
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
+from mindspore.ops import operations as P
+from mindspore.ops import composite as C
+import mindspore.nn as nn
 
 def cast_to_tensor(t, dtype=mstype.float32):
     """
@@ -196,3 +199,55 @@ def check_prob(p):
     comp = np.greater(p.asnumpy(), np.ones(p.shape))
     if comp.any():
         raise ValueError('Probabilities should be less than or equal to one')
+
+
+def logits_to_probs(logits, is_binary=False):
+    """
+    Converts logits into probabilities.
+    Args:
+        logits (Tensor)
+        is_binary (bool)
+    """
+    if is_binary:
+        return nn.Sigmoid()(logits)
+    return nn.Softmax(axis=-1)(logits)
+
+
+def clamp_probs(probs):
+    """
+    Clamps probabilities into the open interval (eps, 1 - eps).
+    Args:
+        probs (Tensor)
+    """
+    eps = P.Eps()(probs)
+    return C.clip_by_value(probs, eps, 1-eps)
+
+
+def probs_to_logits(probs, is_binary=False):
+    """
+    Converts probabilities into logits.
+    Args:
+        probs (Tensor)
+        is_binary (bool)
+    """
+    ps_clamped = clamp_probs(probs)
+    if is_binary:
+        return P.Log()(ps_clamped) - P.Log()(1-ps_clamped)
+    return P.Log()(ps_clamped)
+
+def check_tensor_type(name, inputs, valid_type):
+    """
+    Check whether inputs is a Tensor whose dtype is in valid_type.
+
+    Args:
+        name: name of the input, used in error messages.
+        inputs: Tensor to be checked.
+
+    Raises:
+        TypeError: if inputs is not a Tensor, or if its dtype is not in valid_type.
+    """
+    if not isinstance(inputs, Tensor):
+        raise TypeError(f"{name} should be a Tensor")
+    inputs = P.DType()(inputs)
+    if inputs not in valid_type:
+        raise TypeError(f"{name} dtype is invalid")
diff --git a/mindspore/ops/composite/__init__.py b/mindspore/ops/composite/__init__.py
index 530bf9e1b7..f60378279e 100644
--- a/mindspore/ops/composite/__init__.py
+++ b/mindspore/ops/composite/__init__.py
@@ -27,7 +27,7 @@ from .clip_ops import clip_by_value
 from .multitype_ops.add_impl import hyper_add
 from .multitype_ops.ones_like_impl import ones_like
 from .multitype_ops.zeros_like_impl import zeros_like
-from .random_ops import set_seed, normal
+from .random_ops import set_seed, normal, multinomial
 
 
 __all__ = [
@@ -50,4 +50,5 @@ __all__ = [
     'zip_operation',
     'set_seed',
     'normal',
+    'multinomial',
     'clip_by_value',]
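
For reference, the new helpers in _utils/utils.py can be exercised as in the sketch below. This is illustrative only and not part of the patch: the tensor values, the PyNative context, and the direct import from the private _utils.utils module are assumptions made for the example.

    import numpy as np
    from mindspore import context
    from mindspore.common import dtype as mstype
    from mindspore.common.tensor import Tensor
    from mindspore.nn.probability.distribution._utils.utils import (
        clamp_probs, logits_to_probs, probs_to_logits)

    context.set_context(mode=context.PYNATIVE_MODE)

    logits = Tensor(np.array([0.2, 1.5, -0.3]), mstype.float32)
    probs = logits_to_probs(logits)                   # softmax over the last axis
    binary = logits_to_probs(logits, is_binary=True)  # element-wise sigmoid
    back = probs_to_logits(probs)                     # log of the clamped probabilities
    safe = clamp_probs(Tensor(np.array([0.0, 1.0]), mstype.float32))  # pushed into (eps, 1 - eps)
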
diff --git a/mindspore/ops/composite/random_ops.py b/mindspore/ops/composite/random_ops.py
index 53fa58c4d3..b0c2a55d6f 100644
--- a/mindspore/ops/composite/random_ops.py
+++ b/mindspore/ops/composite/random_ops.py
@@ -20,6 +20,9 @@ from .. import functional as F
 from ..primitive import constexpr
 from .multitype_ops import _constexpr_utils as const_utils
 from ...common import dtype as mstype
+from ...common.tensor import Tensor
+from ..._checkparam import Validator as validator
+from ..._checkparam import Rel
 
 # set graph-level RNG seed
 _GRAPH_SEED = 0
@@ -68,3 +71,51 @@ def normal(shape, mean, stddev, seed=0):
     rnd = stdnormal(shape)
     value = rnd * stddev + mean
     return value
+
+
+def multinomial(inputs, num_sample=None, replacement=True, seed=0):
+    r"""
+    Returns a tensor sampled from the multinomial probability distribution located in the corresponding
+    row of the input tensor.
+
+    Note:
+        The rows of the input do not need to sum to one (in which case the values are used as weights),
+        but must be non-negative, finite and have a non-zero sum.
+    Args:
+        seed (int): Random seed, used as an entropy source for the random number engine that generates
+            pseudo-random numbers. Default: 0.
+
+    Inputs:
+        - **inputs** (Tensor) - The input tensor containing probabilities, must be 1 or 2 dimensions.
+        - **num_sample** (int) - Number of samples to draw. Default: None.
+        - **replacement** (bool, optional) - Whether to draw with replacement or not. Default: True.
+
+    Outputs:
+        Tensor, with the same number of rows as inputs; each row contains num_sample sampled indices.
+
+    Examples:
+        >>> input = Tensor([0, 9, 4, 0], mstype.float32)
+        >>> output = C.multinomial(input, 2, True)
+    """
+    shape = P.Shape()
+    reshape = P.Reshape()
+    validator.check_value_type('replacement', replacement, (bool,), None)
+    validator.check_value_type('num_sample', num_sample, (int,), None)
+    validator.check_integer("num_sample", num_sample, 0, Rel.GT, None)
+    if inputs.dim() != 1 and inputs.dim() != 2:
+        raise ValueError("inputs must be a 1-D or 2-D tensor")
+    if not replacement:
+        if shape(inputs)[-1] < num_sample:
+            raise ValueError("num_sample must not exceed shape(inputs)[-1] when sampling without replacement")
+        n_dist = 1
+        if len(shape(inputs)) > 1:
+            n_dist = shape(inputs)[-2]
+        a = Tensor(0.0, mstype.float32)
+        b = Tensor(1.0, mstype.float32)
+        uniform = P.UniformReal(seed=seed)((n_dist * shape(inputs)[-1],), a, b)
+        if n_dist != 1:
+            uniform = reshape(uniform, (n_dist, shape(inputs)[-1]))
+        vals = P.RealDiv()(P.Log()(uniform), inputs + 1e-6)
+        _, indices = P.TopK()(vals, num_sample)
+        return indices
+    return P.Multinomial(seed=seed)(inputs, num_sample)
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 59290c3234..d72b8bc9de 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -56,7 +56,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
                        Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan)
 from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
-                         RandomCategorical, Laplace)
+                         RandomCategorical, Laplace, Multinomial)
 from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, ApplyMomentum, BatchNorm,
                      BiasAdd, Conv2D, DepthwiseConv2dNative,
@@ -176,6 +176,7 @@ __all__ = [
     'Tanh',
     'RandomChoiceWithMask',
     'StandardNormal',
+    'Multinomial',
     'Gamma',
     'Poisson',
     'UniformInt',
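
Taken together, C.multinomial behaves roughly as sketched below. With replacement it dispatches to the new P.Multinomial primitive; without replacement it draws one uniform variate per category and keeps the top num_sample indices of log(u) / weight per row, which yields distinct indices with probability proportional to the weights. The sketch is illustrative only: the weight values, the PyNative context, and the availability of backend kernels for the random primitives are assumptions, not part of the patch.

    from mindspore import context
    from mindspore.common import dtype as mstype
    from mindspore.common.tensor import Tensor
    from mindspore.ops import composite as C

    context.set_context(mode=context.PYNATIVE_MODE)

    weights = Tensor([[1., 9., 4., 6.], [2., 2., 2., 2.]], mstype.float32)

    # Without replacement: top-k over log(U)/weights, so at most
    # shape(weights)[-1] distinct column indices per row.
    no_repl = C.multinomial(weights, num_sample=3, replacement=False, seed=5)

    # With replacement: forwards to the P.Multinomial primitive added below.
    with_repl = C.multinomial(weights, num_sample=3, replacement=True, seed=5)

    print(no_repl.shape, with_repl.shape)  # (2, 3) (2, 3)
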
diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py
index 065c4eaf27..3b4bf9d6a9 100644
--- a/mindspore/ops/operations/random_ops.py
+++ b/mindspore/ops/operations/random_ops.py
@@ -408,6 +408,7 @@ class RandomCategorical(PrimitiveWithInfer):
         >>> net = Net(8)
         >>> output = net(Tensor(x))
     """
+
     @prim_attr_register
     def __init__(self, dtype=mstype.int64):
         """Init RandomCategorical"""
@@ -435,3 +436,54 @@ class RandomCategorical(PrimitiveWithInfer):
         return {'shape': (x_shape),
                 'dtype': (self.dtype),
                 'value': None}
+
+
+class Multinomial(PrimitiveWithInfer):
+    r"""
+    Returns a tensor sampled from the multinomial probability distribution located in the corresponding
+    row of the input tensor.
+
+    Note:
+        The rows of the input do not need to sum to one (in which case the values are used as weights),
+        but must be non-negative, finite and have a non-zero sum.
+    Args:
+        seed (int): Random seed, used as an entropy source for the random number engine that generates
+            pseudo-random numbers. Default: 0.
+
+    Inputs:
+        - **input** (Tensor[float32]) - The input tensor of cumulative probabilities; must be 1 or 2 dimensions.
+        - **num_samples** (int) - Number of samples to draw.
+
+    Outputs:
+        Tensor, with the same number of rows as the input; each row contains num_samples sampled indices.
+
+    Examples:
+        >>> input = Tensor([0., 9., 4., 0.], mstype.float32)
+        >>> multinomial = P.Multinomial(seed=10)
+        >>> output = multinomial(input, 2)
+    """
+
+    @prim_attr_register
+    def __init__(self, seed=0):
+        """Init Multinomial"""
+        validator.check_value_type("seed", seed, [int], self.name)
+        self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output'])
+
+    def __infer__(self, inputs, num_samples):
+        input_shape = inputs["shape"]
+        if len(input_shape) != 1 and len(input_shape) != 2:
+            raise ValueError("input must be a 1-D or 2-D tensor")
+        validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name)
+        num_samples_value = num_samples["value"]
+        if num_samples_value is None:
+            raise ValueError(f"For {self.name}, num_samples must be a const input")
+        validator.check_value_type("num_samples", num_samples_value, [int], self.name)
+        validator.check_integer("num_samples", num_samples_value, 0, Rel.GT, None)
+        y_shape = (num_samples_value,)
+        if len(input_shape) == 2:
+            y_shape = (input_shape[0], num_samples_value)
+        out = {
+            "shape": y_shape,
+            "dtype": mstype.int32,
+            "value": None}
+        return out
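
The primitive itself can then be invoked directly, mirroring the docstring example. A minimal sketch, assuming a device target (e.g. GPU) that registers a kernel for the new operator; the patch only adds the Python operator definition and its shape/dtype inference:

    from mindspore import context
    from mindspore.common import dtype as mstype
    from mindspore.common.tensor import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    probs = Tensor([0., 9., 4., 0.], mstype.float32)
    multinomial = P.Multinomial(seed=10)

    # The number of samples must be a constant: __infer__ reads its value to derive
    # the output shape, (2,) for this 1-D input and (rows, 2) for a 2-D input.
    indices = multinomial(probs, 2)
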