Merge pull request !3786 from baihuawei/multinomialtags/v0.7.0-beta
| @@ -19,6 +19,9 @@ from mindspore.ops import _utils as utils | |||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import composite as C | |||||
| import mindspore.nn as nn | |||||
| def cast_to_tensor(t, dtype=mstype.float32): | def cast_to_tensor(t, dtype=mstype.float32): | ||||
| """ | """ | ||||
| @@ -196,3 +199,55 @@ def check_prob(p): | |||||
| comp = np.greater(p.asnumpy(), np.ones(p.shape)) | comp = np.greater(p.asnumpy(), np.ones(p.shape)) | ||||
| if comp.any(): | if comp.any(): | ||||
| raise ValueError('Probabilities should be less than or equal to one') | raise ValueError('Probabilities should be less than or equal to one') | ||||
| def logits_to_probs(logits, is_binary=False): | |||||
| """ | |||||
| converts logits into probabilities. | |||||
| Args: | |||||
| logits (Tensor) | |||||
| is_binary (bool) | |||||
| """ | |||||
| if is_binary: | |||||
| return nn.sigmoid()(logits) | |||||
| return nn.softmax(axis=-1)(logits) | |||||
| def clamp_probs(probs): | |||||
| """ | |||||
| clamp probs boundary | |||||
| Args: | |||||
| probs (Tensor) | |||||
| """ | |||||
| eps = P.Eps()(probs) | |||||
| return C.clip_by_value(probs, eps, 1-eps) | |||||
| def probs_to_logits(probs, is_binary=False): | |||||
| """ | |||||
| converts probabilities into logits. | |||||
| Args: | |||||
| probs (Tensor) | |||||
| is_binary (bool) | |||||
| """ | |||||
| ps_clamped = clamp_probs(probs) | |||||
| if is_binary: | |||||
| return P.Log()(ps_clamped) - P.Log()(1-ps_clamped) | |||||
| return P.Log()(ps_clamped) | |||||
| def check_tensor_type(name, inputs, valid_type): | |||||
| """ | |||||
| Check if inputs is proper. | |||||
| Args: | |||||
| inputs: Tensor to be checked. | |||||
| name: inputs name | |||||
| Raises: | |||||
| ValueError: if inputs is not a proper Tensor. | |||||
| """ | |||||
| if not isinstance(inputs, Tensor): | |||||
| raise TypeError(f"{name} should be a Tensor") | |||||
| inputs = P.DType()(inputs) | |||||
| if inputs not in valid_type: | |||||
| raise TypeError(f"{name} dtype is invalid") | |||||
| @@ -27,7 +27,7 @@ from .clip_ops import clip_by_value | |||||
| from .multitype_ops.add_impl import hyper_add | from .multitype_ops.add_impl import hyper_add | ||||
| from .multitype_ops.ones_like_impl import ones_like | from .multitype_ops.ones_like_impl import ones_like | ||||
| from .multitype_ops.zeros_like_impl import zeros_like | from .multitype_ops.zeros_like_impl import zeros_like | ||||
| from .random_ops import set_seed, normal | |||||
| from .random_ops import set_seed, normal, multinomial | |||||
| __all__ = [ | __all__ = [ | ||||
| @@ -50,4 +50,5 @@ __all__ = [ | |||||
| 'zip_operation', | 'zip_operation', | ||||
| 'set_seed', | 'set_seed', | ||||
| 'normal', | 'normal', | ||||
| 'multinomial', | |||||
| 'clip_by_value',] | 'clip_by_value',] | ||||
| @@ -20,6 +20,9 @@ from .. import functional as F | |||||
| from ..primitive import constexpr | from ..primitive import constexpr | ||||
| from .multitype_ops import _constexpr_utils as const_utils | from .multitype_ops import _constexpr_utils as const_utils | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ...common.tensor import Tensor | |||||
| from ..._checkparam import Validator as validator | |||||
| from ..._checkparam import Rel | |||||
| # set graph-level RNG seed | # set graph-level RNG seed | ||||
| _GRAPH_SEED = 0 | _GRAPH_SEED = 0 | ||||
| @@ -68,3 +71,51 @@ def normal(shape, mean, stddev, seed=0): | |||||
| rnd = stdnormal(shape) | rnd = stdnormal(shape) | ||||
| value = rnd * stddev + mean | value = rnd * stddev + mean | ||||
| return value | return value | ||||
| def multinomial(inputs, num_sample=None, replacement=True, seed=0): | |||||
| r""" | |||||
| Returns a tensor sampled from the multinomial probability distribution located in the corresponding | |||||
| row of tensor input. | |||||
| Note: | |||||
| The rows of input do not need to sum to one (in which case we use the values as weights), | |||||
| but must be non-negative, finite and have a non-zero sum. | |||||
| Args: | |||||
| seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers. | |||||
| Default: 0. | |||||
| Inputs: | |||||
| - **input** (Tensor) - the input tensor containing probabilities, must be 1 or 2 dims. | |||||
| - **num_samples** (int) - number of samples to draw, default None. | |||||
| - **replacement** (bool, optional) - whether to draw with replacement or not, default True. | |||||
| Outputs: | |||||
| Tensor. have the same rows with input, each row has num_samples sampled indices. | |||||
| Examples: | |||||
| >>> input = Tensor([0, 9, 4, 0], mstype.float32) | |||||
| >>> output = C.multinomial(input, 2, True) | |||||
| """ | |||||
| shape = P.Shape() | |||||
| reshape = P.Reshape() | |||||
| validator.check_value_type('replacement', replacement, (bool,), None) | |||||
| validator.check_value_type('num_sample', num_sample, (int,), None) | |||||
| validator.check_integer("num_sample", num_sample, 0, Rel.GT, None) | |||||
| if inputs.dim() != 1 and inputs.dim() != 2: | |||||
| raise ValueError("inputs dim must be 1d or 2d") | |||||
| if not replacement: | |||||
| if shape(inputs)[-1] < num_sample: | |||||
| raise ValueError("num_sample must be less than shape(input)[-1] without replacement") | |||||
| n_dist = 1 | |||||
| if len(shape(inputs)) > 1: | |||||
| n_dist = shape(inputs)[-2] | |||||
| a = Tensor(0.0, mstype.float32) | |||||
| b = Tensor(1.0, mstype.float32) | |||||
| uniform = P.UniformReal(seed=seed)((n_dist * num_sample,), a, b) | |||||
| if n_dist != 1: | |||||
| uniform = reshape(uniform, (n_dist, num_sample)) | |||||
| vals = P.RealDiv()(P.Log()(uniform), inputs + 1e-6) | |||||
| _, indices = P.TopK()(vals, num_sample) | |||||
| return indices | |||||
| return P.Multinomial(seed=seed)(inputs, num_sample) | |||||
| @@ -57,7 +57,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A | |||||
| Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) | Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) | ||||
| from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, | from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, | ||||
| RandomCategorical, Laplace) | |||||
| RandomCategorical, Laplace, Multinomial) | |||||
| from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, ApplyMomentum, BatchNorm, | from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, ApplyMomentum, BatchNorm, | ||||
| BiasAdd, Conv2D, | BiasAdd, Conv2D, | ||||
| DepthwiseConv2dNative, | DepthwiseConv2dNative, | ||||
| @@ -184,6 +184,7 @@ __all__ = [ | |||||
| 'Tanh', | 'Tanh', | ||||
| 'RandomChoiceWithMask', | 'RandomChoiceWithMask', | ||||
| 'StandardNormal', | 'StandardNormal', | ||||
| 'Multinomial', | |||||
| 'Gamma', | 'Gamma', | ||||
| 'Poisson', | 'Poisson', | ||||
| 'UniformInt', | 'UniformInt', | ||||
| @@ -409,6 +409,7 @@ class RandomCategorical(PrimitiveWithInfer): | |||||
| >>> net = Net(8) | >>> net = Net(8) | ||||
| >>> output = net(Tensor(x)) | >>> output = net(Tensor(x)) | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, dtype=mstype.int64): | def __init__(self, dtype=mstype.int64): | ||||
| """Init RandomCategorical""" | """Init RandomCategorical""" | ||||
| @@ -436,3 +437,54 @@ class RandomCategorical(PrimitiveWithInfer): | |||||
| return {'shape': (x_shape), | return {'shape': (x_shape), | ||||
| 'dtype': (self.dtype), | 'dtype': (self.dtype), | ||||
| 'value': None} | 'value': None} | ||||
| class Multinomial(PrimitiveWithInfer): | |||||
| r""" | |||||
| Returns a tensor sampled from the multinomial probability distribution located in the corresponding | |||||
| row of tensor input. | |||||
| Note: | |||||
| The rows of input do not need to sum to one (in which case we use the values as weights), | |||||
| but must be non-negative, finite and have a non-zero sum. | |||||
| Args: | |||||
| seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers. | |||||
| Default: 0. | |||||
| Inputs: | |||||
| - **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2 dims. | |||||
| - **num_samples** (int) - number of samples to draw. | |||||
| Outputs: | |||||
| Tensor. have the same rows with input, each row has num_samples sampled indices. | |||||
| Examples: | |||||
| >>> input = Tensor([0., 9., 4., 0.], mstype.float32) | |||||
| >>> multinomial = P.Multinomial(seed=10) | |||||
| >>> output = multinomial(input, 2) | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, seed=0): | |||||
| """init""" | |||||
| validator.check_value_type("seed", seed, [int], self.name) | |||||
| self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output']) | |||||
| def __infer__(self, inputs, num_samples): | |||||
| input_shape = inputs["shape"] | |||||
| if len(input_shape) != 1 and len(input_shape) != 2: | |||||
| raise ValueError("input dim must be 1 or 2") | |||||
| validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name) | |||||
| num_samples_value = num_samples["value"] | |||||
| if num_samples_value is None: | |||||
| raise ValueError(f"For {self.name}, shape nust be const") | |||||
| validator.check_value_type("num_samples", num_samples_value, [int], self.name) | |||||
| validator.check_integer("num_samples", num_samples_value, 0, Rel.GT, None) | |||||
| y_shape = (num_samples_value,) | |||||
| if len(input_shape) == 2: | |||||
| y_shape = (input_shape[0], num_samples_value) | |||||
| out = { | |||||
| "shape": y_shape, | |||||
| "dtype": mstype.int32, | |||||
| "value": None} | |||||
| return out | |||||