# limitations under the License.
# ============================================================================
"""Categorical Distribution"""
|
|
|
|
import numpy as np
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
import mindspore.nn as nn
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
from .distribution import Distribution
|
|
|
|
from ._utils.utils import logits_to_probs, probs_to_logits, check_type, cast_to_tensor, \
|
|
|
|
raise_probs_logits_error
|
|
|
|
from ._utils.utils import check_prob, check_sum_equal_one, check_type, check_rank,\
|
|
|
|
check_distribution_name, raise_not_implemented_util
|
|
|
|
from ._utils.custom_ops import exp_generic, log_generic, broadcast_to
|
|
|
|
|
|
|
|
|
|
|
|
class Categorical(Distribution):
|
|
|
|
"""
|
|
|
|
Create a categorical distribution parameterized by either probabilities or logits (but not both).
|
|
|
|
Create a categorical distribution parameterized by event probabilities.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
probs (Tensor, list, numpy.ndarray, Parameter): Event probabilities.
|
|
|
|
logits (Tensor, list, numpy.ndarray, Parameter, float): Event log-odds.
|
|
|
|
seed (int): The global seed is used in sampling. Global seed is used if it is None. Default: None.
|
|
|
|
dtype (mindspore.dtype): The type of the distribution. Default: mstype.int32.
|
|
|
|
name (str): The name of the distribution. Default: Categorical.
|
|
|
|
|
|
|
|
Note:
|
|
|
|
`probs` must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1.
|
|
|
|
`probs` must have rank at least 1, values are proper probabilities and sum to 1.
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
>>> # To initialize a Categorical distribution of prob is [0.5, 0.5]
|
|
|
|
>>> # To initialize a Categorical distribution of probs [0.5, 0.5]
|
|
|
|
>>> import mindspore.nn.probability.distribution as msd
|
|
|
|
>>> b = msd.Categorical(probs = [0.5, 0.5], dtype=mstype.int32)
|
|
|
|
>>>
|
|
|
|
>>> # To use Categorical in a network
|
|
|
|
>>> # To use a Categorical distribution in a network
|
|
|
|
>>> class net(Cell):
|
|
|
|
>>> def __init__(self, probs):
|
|
|
|
>>> super(net, self).__init__():
|
|
|
|
>>> self.ca = msd.Categorical(probs=probs, dtype=mstype.int32)
|
|
|
|
>>> self.ca = msd.Categorical(probs=[0.2, 0.8], dtype=mstype.int32)
|
|
|
|
>>> self.ca1 = msd.Categorical(probs=[0.2, 0.8], dtype=mstype.int32)
|
|
|
|
>>>
|
|
|
|
>>> # All the following calls in construct are valid
|
|
|
|
>>> def construct(self, value):
|
|
|
|
>>>
|
|
|
|
>>> # Similar calls can be made to logits
|
|
|
|
>>> ans = self.ca.probs
|
|
|
|
>>> # value must be Tensor(mstype.float32, bool, mstype.int32)
|
|
|
|
>>> ans = self.ca.log_prob(value)
|
|
|
|
>>> # Private interfaces of probability functions corresponding to public interfaces, including
|
|
|
|
>>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, are the same as follows.
|
|
|
|
>>> # Args:
|
|
|
|
>>> # value (Tensor): the value to be evaluated.
|
|
|
|
>>> # probs (Tensor): event probabilities. Default: self.probs.
|
|
|
|
>>>
|
|
|
|
>>> # Examples of `prob`.
|
|
|
|
>>> # Similar calls can be made to other probability functions
|
|
|
|
>>> # by replacing `prob` by the name of the function.
|
|
|
|
>>> ans = self.ca.prob(value)
|
|
|
|
>>> # Evaluate `prob` with respect to distribution b.
|
|
|
|
>>> ans = self.ca.prob(value, probs_b)
|
|
|
|
>>> # `probs` must be passed in during function calls.
|
|
|
|
>>> ans = self.ca1.prob(value, probs_a)
|
|
|
|
>>>
|
|
|
|
>>> # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments.
|
|
|
|
>>> # Args:
|
|
|
|
>>> # probs (Tensor): event probabilities. Default: self.probs.
|
|
|
|
>>>
|
|
|
|
>>> # Usage of enumerate_support
|
|
|
|
>>> ans = self.ca.enumerate_support()
|
|
|
|
>>> # Examples of `mean`. `sd`, `var`, and `entropy` are similar.
|
|
|
|
>>> ans = self.ca.mean() # return 0.8
|
|
|
|
>>> ans = self.ca.mean(probs_b)
|
|
|
|
>>> # `probs` must be passed in during function calls.
|
|
|
|
>>> ans = self.ca1.mean(probs_a)
|
|
|
|
>>>
|
|
|
|
>>> # Usage of entropy
|
|
|
|
>>> ans = self.ca.entropy()
|
|
|
|
>>> # Interfaces of `kl_loss` and `cross_entropy` are the same as follows:
|
|
|
|
>>> # Args:
|
|
|
|
>>> # dist (str): the name of the distribution. Only 'Categorical' is supported.
|
|
|
|
>>> # probs_b (Tensor): event probabilities of distribution b.
|
|
|
|
>>> # probs (Tensor): event probabilities of distribution a. Default: self.probs.
|
|
|
|
>>>
|
|
|
|
>>> # Sample
|
|
|
|
>>> # Examples of kl_loss. `cross_entropy` is similar.
|
|
|
|
>>> ans = self.ca.kl_loss('Categorical', probs_b)
|
|
|
|
>>> ans = self.ca.kl_loss('Categorical', probs_b, probs_a)
|
|
|
|
>>> # An additional `probs` must be passed in.
|
|
|
|
>>> ans = self.ca1.kl_loss('Categorical', probs_b, probs_a)
|
|
|
|
>>>
|
|
|
|
>>> # Examples of `sample`.
|
|
|
|
>>> # Args:
|
|
|
|
>>> # shape (tuple): the shape of the sample. Default: ().
|
|
|
|
>>> # probs (Tensor): event probabilities. Default: self.probs.
|
|
|
|
>>> ans = self.ca.sample()
|
|
|
|
>>> ans = self.ca.sample((2,3))
|
|
|
|
>>> ans = self.ca.sample((2,))
|
|
|
|
>>> ans = self.b1.sample((2,3), probs_b)
|
|
|
|
>>> ans = self.b2.sample((2,3), probs_a)
|
|
|
|
"""
|
|
|
|
|
|
|
|
    def __init__(self,
                 probs=None,
                 seed=None,
                 dtype=mstype.int32,
                 name="Categorical"):
        param = dict(locals())
        param['param_dict'] = {'probs': probs}
        valid_dtype = mstype.int_type
        check_type(dtype, valid_dtype, "Categorical")
        super(Categorical, self).__init__(seed, dtype, name, param)

        self._probs = self._add_parameter(probs, 'probs')
        if self.probs is not None:
            check_rank(self.probs)
            check_prob(self.probs)
            check_sum_equal_one(self.probs)

            # update is_scalar_batch and broadcast_shape
            # drop one dimension
            if self.probs.shape[:-1] == ():
                self._is_scalar_batch = True
            self._broadcast_shape = self._broadcast_shape[:-1]
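
        # Primitive ops are bound once here and reused by the methods below;
        # exp and log are bound to the generic custom ops from _utils.custom_ops.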
        self.argmax = P.Argmax()
        self.broadcast = broadcast_to
        self.cast = P.Cast()
        self.clip_by_value = C.clip_by_value
        self.concat = P.Concat(-1)
        self.cumsum = P.CumSum()
        self.dtypeop = P.DType()
        self.exp = exp_generic
        self.expand_dim = P.ExpandDims()
        self.fill = P.Fill()
        self.floor = P.Floor()
        self.gather = P.GatherNd()
        self.less = P.Less()
        self.log = log_generic
        self.log_softmax = P.LogSoftmax()
        self.logicor = P.LogicalOr()
        self.multinomial = P.Multinomial(seed=self.seed)
        self.reshape = P.Reshape()
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.select = P.Select()
        self.shape = P.Shape()
        self.softmax = P.Softmax()
        self.squeeze = P.Squeeze()
        self.square = P.Square()
        self.transpose = P.Transpose()

        self.index_type = mstype.int32

    def extend_repr(self):
        if self.is_scalar_batch:
            str_info = f'probs = {self.probs}'
        else:
            str_info = f'batch_shape = {self._broadcast_shape}'
        return str_info

    @property
    def probs(self):
        """
        Return the probability.
        """
        return self._probs

    def _mean(self, probs=None):
        r"""
        .. math::
            E[X] = \sum_{i=0}^{num\_classes - 1} i \cdot p_{i}
        """
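        # The support is the integers 0, ..., num_classes - 1 (built with nn.Range);
        # the mean is their probability-weighted sum over the event axis.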
        probs = self._check_param_type(probs)
        num_classes = self.shape(probs)[-1]
        index = nn.Range(0., num_classes, 1.)()
        return self.reduce_sum(index * probs, -1)

    def _mode(self, probs=None):
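        # The mode is the category with the largest probability; P.Argmax reduces
        # over the last (event) axis and the result is cast to the distribution dtype.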
        probs = self._check_param_type(probs)
        mode = self.cast(self.argmax(probs), self.dtype)
        return self.squeeze(mode)

    def _var(self, probs=None):
        r"""
        .. math::
            VAR(X) = E[X^{2}] - (E[X])^{2}
        """
        probs = self._check_param_type(probs)
        num_classes = self.shape(probs)[-1]
        index = nn.Range(0., num_classes, 1.)()
        return self.reduce_sum(self.square(index) * probs, -1) -\
            self.square(self.reduce_sum(index * probs, -1))

    def _entropy(self, probs=None):
        r"""
        Evaluate entropy.

        .. math::
            H(X) = -\sum_{i} p_{i} \log p_{i}
        """
        probs = self._check_param_type(probs)
        logits = self.log(probs)
        return self.squeeze(-self.reduce_sum(logits * probs, -1))

    def _kl_loss(self, dist, probs_b, probs=None):
        """
        Evaluate the KL divergence between Categorical distributions.

        Args:
            dist (str): The type of the distributions. Should be "Categorical" in this case.
            probs_b (Tensor): Event probabilities of distribution b.
            probs (Tensor): Event probabilities of distribution a. Default: self.probs.
        """
        check_distribution_name(dist, 'Categorical')
        probs_b = self._check_value(probs_b, 'probs_b')
        probs_b = self.cast(probs_b, self.parameter_type)
        probs_a = self._check_param_type(probs)
        logits_a = self.log(probs_a)
        logits_b = self.log(probs_b)
        return self.squeeze(-self.reduce_sum(
            self.softmax(logits_a) * (self.log_softmax(logits_a) - (self.log_softmax(logits_b))), -1))

    def _cross_entropy(self, dist, probs_b, probs=None):
        """
        Evaluate the cross entropy between Categorical distributions.

        Args:
            dist (str): The type of the distributions. Should be "Categorical" in this case.
            probs_b (Tensor): Event probabilities of distribution b.
            probs (Tensor): Event probabilities of distribution a. Default: self.probs.
        """
        check_distribution_name(dist, 'Categorical')
        return self._entropy(probs) + self._kl_loss(dist, probs_b, probs)

    def _log_prob(self, value, probs=None):
        r"""
        Evaluate log probability.

        Args:
            value (Tensor): The value to be evaluated.
            probs (Tensor): Event probabilities. Default: self.probs.
        """
        value = self._check_value(value, 'value')
        value = self.cast(value, self.parameter_type)
        probs = self._check_param_type(probs)
        logits = self.log(probs)

        # handle the case when value is of shape () and probs is a scalar batch
        drop_dim = False
        if self.shape(value) == () and self.shape(probs)[:-1] == ():
            drop_dim = True
            # manually add one more dimension: () -> (1,)
            # drop this dimension before return
            value = self.expand_dim(value, -1)

        value = self.expand_dim(value, -1)

        broadcast_shape_tensor = logits * value
        broadcast_shape = self.shape(broadcast_shape_tensor)
        # broadcast_shape (N, C)
        num_classes = broadcast_shape[-1]
        label_shape = broadcast_shape[:-1]

        # broadcasting logits and value
        # logit_pmf shape (num of labels, C)
        logits = self.broadcast(logits, broadcast_shape_tensor)
        value = self.broadcast(value, broadcast_shape_tensor)[..., :1]

        # flatten value to shape (number of labels, 1)
        # clip value to be in range from 0 to num_classes - 1 and cast into int32
        value = self.reshape(value, (-1, 1))
        out_of_bound = self.squeeze(self.logicor(\
            self.less(value, 0.0), self.less(num_classes-1, value)))
        value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
        value_clipped = self.cast(value_clipped, self.index_type)
        # create index from 0 ... NumOfLabels
        index = self.reshape(nn.Range(0, self.shape(value)[0], 1)(), (-1, 1))
        index = self.concat((index, value_clipped))

        # index into logit_pmf, fill in out_of_bound places with -inf
        # reshape into label shape N
        logits_pmf = self.gather(self.reshape(logits, (-1, num_classes)), index)
        neg_inf = self.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf), -np.inf)
        logits_pmf = self.select(out_of_bound, neg_inf, logits_pmf)
        ans = self.reshape(logits_pmf, label_shape)
        if drop_dim:
            return self.squeeze(ans)
        return ans

    def _cdf(self, value, probs=None):
        r"""
        Cumulative distribution function (cdf) of Categorical distributions.

        Args:
            value (Tensor): The value to be evaluated.
            probs (Tensor): Event probabilities. Default: self.probs.
        """
        value = self._check_value(value, 'value')
        value = self.cast(value, self.parameter_type)
        value = self.floor(value)
        probs = self._check_param_type(probs)

        # handle the case when value is of shape () and probs is a scalar batch
        drop_dim = False
        if self.shape(value) == () and self.shape(probs)[:-1] == ():
            drop_dim = True
            # manually add one more dimension: () -> (1,)
            # drop this dimension before return
            value = self.expand_dim(value, -1)

        value = self.expand_dim(value, -1)

        broadcast_shape_tensor = probs * value
        broadcast_shape = self.shape(broadcast_shape_tensor)
        # broadcast_shape (N, C)
        num_classes = broadcast_shape[-1]
        label_shape = broadcast_shape[:-1]

        probs = self.broadcast(probs, broadcast_shape_tensor)
        value = self.broadcast(value, broadcast_shape_tensor)[..., :1]

        # flatten value to shape (number of labels, 1)
        value = self.reshape(value, (-1, 1))

        # drop one dimension to match cdf
        # clip value to be in range from 0 to num_classes - 1 and cast into int32
        less_than_zero = self.squeeze(self.less(value, 0.0))
        value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
        value_clipped = self.cast(value_clipped, self.index_type)

        index = self.reshape(nn.Range(0, self.shape(value)[0], 1)(), (-1, 1))
        index = self.concat((index, value_clipped))

        # reshape probs and fill less_than_zero places with 0
        probs = self.reshape(probs, (-1, num_classes))
        cdf = self.gather(self.cumsum(probs, 1), index)
        zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
        cdf = self.select(less_than_zero, zeros, cdf)
        cdf = self.reshape(cdf, label_shape)

        if drop_dim:
            return self.squeeze(cdf)
        return cdf

    def _sample(self, shape=(), probs=None):
        """
        Sampling.

        Args:
            shape (tuple): The shape of the sample. Default: ().
            probs (Tensor): Event probabilities. Default: self.probs.

        Returns:
            Tensor, with shape being `shape` + batch_shape of the distribution.
        """
        if self.device_target == 'Ascend':
            raise_not_implemented_util('On Ascend backend, sample', self.name)
        shape = self.checktuple(shape, 'shape')
        probs = self._check_param_type(probs)
        num_classes = self.shape(probs)[-1]
        batch_shape = self.shape(probs)[:-1]

        sample_shape = shape + batch_shape
        drop_dim = False
        if sample_shape == ():
            drop_dim = True
            sample_shape = (1,)

        probs_2d = self.reshape(probs, (-1, num_classes))
        sample_tensor = self.fill(self.dtype, shape, 1.0)
        sample_tensor = self.reshape(sample_tensor, (-1, 1))
        num_sample = self.shape(sample_tensor)[0]
        samples = self.multinomial(probs_2d, num_sample)
        samples = self.squeeze(self.transpose(samples, (1, 0)))
        samples = self.cast(self.reshape(samples, sample_shape), self.dtype)
        if drop_dim:
            return self.squeeze(samples)
        return samples