
added categorical distribution

tags/v1.1.0
Xun Deng committed 5 years ago (commit 877b561e77)
5 changed files with 828 additions and 105 deletions
  1. +12 -0 mindspore/nn/probability/distribution/_utils/utils.py
  2. +293 -105 mindspore/nn/probability/distribution/categorical.py
  3. +1 -0 mindspore/nn/probability/distribution/distribution.py
  4. +273 -0 tests/st/probability/distribution/test_categorical.py
  5. +249 -0 tests/ut/python/nn/probability/distribution/test_categorical.py

+12 -0 mindspore/nn/probability/distribution/_utils/utils.py

@@ -158,6 +158,18 @@ def check_prob(p):
    if not comp.all():
        raise ValueError('Probabilities should be less than one')

def check_sum_equal_one(probs):
    """
    Used in categorical distribution. Check that the probabilities of all categories sum to one.
    """
    prob_sum = np.sum(probs.asnumpy(), axis=-1)
    comp = np.equal(np.ones(prob_sum.shape), prob_sum)
    if not comp.all():
        raise ValueError('Probabilities for each category should sum to one for Categorical distribution.')

def check_rank(probs):
    """
    Used in categorical distribution. Check that rank >= 1.
    """
    if probs.asnumpy().ndim == 0:
        raise ValueError('probs for Categorical distribution must have rank >= 1.')

def logits_to_probs(logits, is_binary=False):
"""


+293 -105 mindspore/nn/probability/distribution/categorical.py

@@ -13,108 +13,150 @@
# limitations under the License.
# ============================================================================
"""Categorical Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_prob, check_sum_equal_one, check_type, check_rank,\
    check_distribution_name, raise_not_implemented_util
from ._utils.custom_ops import exp_generic, log_generic, broadcast_to


class Categorical(Distribution):
    """
    Create a categorical distribution parameterized by event probabilities.

    Args:
        probs (Tensor, list, numpy.ndarray, Parameter): Event probabilities.
        seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
        dtype (mindspore.dtype): The type of the distribution. Default: mstype.int32.
        name (str): The name of the distribution. Default: Categorical.

    Note:
        `probs` must have rank at least 1, and its values must be proper probabilities
        that sum to 1 along the last axis.

    Examples:
        >>> # To initialize a Categorical distribution of probs [0.5, 0.5]
        >>> import mindspore.nn.probability.distribution as msd
        >>> b = msd.Categorical(probs = [0.5, 0.5], dtype=mstype.int32)
        >>>
        >>> # To use a Categorical distribution in a network
        >>> class net(Cell):
        >>>     def __init__(self):
        >>>         super(net, self).__init__()
        >>>         self.ca = msd.Categorical(probs=[0.2, 0.8], dtype=mstype.int32)
        >>>         self.ca1 = msd.Categorical(dtype=mstype.int32)
        >>>
        >>>     # All the following calls in construct are valid
        >>>     def construct(self, value, probs_b, probs_a):
        >>>
        >>>         # Private interfaces of probability functions corresponding to public interfaces, including
        >>>         # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, are the same as follows.
        >>>         # Args:
        >>>         #     value (Tensor): the value to be evaluated.
        >>>         #     probs (Tensor): event probabilities. Default: self.probs.
        >>>
        >>>         # Examples of `prob`.
        >>>         # Similar calls can be made to other probability functions
        >>>         # by replacing `prob` by the name of the function.
        >>>         ans = self.ca.prob(value)
        >>>         # Evaluate `prob` with respect to distribution b.
        >>>         ans = self.ca.prob(value, probs_b)
        >>>         # `probs` must be passed in during function calls.
        >>>         ans = self.ca1.prob(value, probs_a)
        >>>
        >>>         # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments.
        >>>         # Args:
        >>>         #     probs (Tensor): event probabilities. Default: self.probs.
        >>>
        >>>         # Examples of `mean`. `sd`, `var`, and `entropy` are similar.
        >>>         ans = self.ca.mean() # return 0.8
        >>>         ans = self.ca.mean(probs_b)
        >>>         # `probs` must be passed in during function calls.
        >>>         ans = self.ca1.mean(probs_a)
        >>>
        >>>         # Interfaces of `kl_loss` and `cross_entropy` are the same as follows:
        >>>         # Args:
        >>>         #     dist (str): the name of the distribution. Only 'Categorical' is supported.
        >>>         #     probs_b (Tensor): event probabilities of distribution b.
        >>>         #     probs (Tensor): event probabilities of distribution a. Default: self.probs.
        >>>
        >>>         # Examples of `kl_loss`. `cross_entropy` is similar.
        >>>         ans = self.ca.kl_loss('Categorical', probs_b)
        >>>         ans = self.ca.kl_loss('Categorical', probs_b, probs_a)
        >>>         # An additional `probs` must be passed in.
        >>>         ans = self.ca1.kl_loss('Categorical', probs_b, probs_a)
        >>>
        >>>         # Examples of `sample`.
        >>>         # Args:
        >>>         #     shape (tuple): the shape of the sample. Default: ().
        >>>         #     probs (Tensor): event probabilities. Default: self.probs.
        >>>         ans = self.ca.sample()
        >>>         ans = self.ca.sample((2,3))
        >>>         ans = self.ca.sample((2,3), probs_b)
        >>>         ans = self.ca1.sample((2,3), probs_a)
    """

    def __init__(self,
                 probs=None,
                 seed=None,
                 dtype=mstype.int32,
                 name="Categorical"):
        param = dict(locals())
        param['param_dict'] = {'probs': probs}
        valid_dtype = mstype.int_type
        check_type(dtype, valid_dtype, "Categorical")
        super(Categorical, self).__init__(seed, dtype, name, param)

        self._probs = self._add_parameter(probs, 'probs')
        if self.probs is not None:
            check_rank(self.probs)
            check_prob(self.probs)
            check_sum_equal_one(self.probs)

            # update is_scalar_batch and broadcast_shape
            # drop one dimension
            if self.probs.shape[:-1] == ():
                self._is_scalar_batch = True
                self._broadcast_shape = self._broadcast_shape[:-1]

        self.argmax = P.Argmax()
        self.broadcast = broadcast_to
        self.cast = P.Cast()
        self.clip_by_value = C.clip_by_value
        self.concat = P.Concat(-1)
        self.cumsum = P.CumSum()
        self.dtypeop = P.DType()
        self.exp = exp_generic
        self.expand_dim = P.ExpandDims()
        self.fill = P.Fill()
        self.floor = P.Floor()
        self.gather = P.GatherNd()
        self.less = P.Less()
        self.log = log_generic
        self.log_softmax = P.LogSoftmax()
        self.logicor = P.LogicalOr()
        self.multinomial = P.Multinomial(seed=self.seed)
        self.reshape = P.Reshape()
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.select = P.Select()
        self.shape = P.Shape()
        self.softmax = P.Softmax()
        self.squeeze = P.Squeeze()
        self.square = P.Square()
        self.transpose = P.Transpose()

        self.index_type = mstype.int32

    def extend_repr(self):
        if self.is_scalar_batch:
            str_info = f'probs = {self.probs}'
        else:
            str_info = f'batch_shape = {self._broadcast_shape}'
        return str_info

    @property
    def probs(self):
@@ -123,68 +165,214 @@ class Categorical(Distribution):
        """
        return self._probs

    def _mean(self, probs=None):
        r"""
        .. math::
            E[X] = \sum_{i=0}^{num_classes-1} i*p_i
        """
        probs = self._check_param_type(probs)
        num_classes = self.shape(probs)[-1]
        index = nn.Range(0., num_classes, 1.)()
        return self.reduce_sum(index * probs, -1)

    def _mode(self, probs=None):
        probs = self._check_param_type(probs)
        mode = self.cast(self.argmax(probs), self.dtype)
        return self.squeeze(mode)

    def _var(self, probs=None):
        r"""
        .. math::
            VAR(X) = E[X^{2}] - (E[X])^{2}
        """
        probs = self._check_param_type(probs)
        num_classes = self.shape(probs)[-1]
        index = nn.Range(0., num_classes, 1.)()
        return self.reduce_sum(self.square(index) * probs, -1) -\
               self.square(self.reduce_sum(index * probs, -1))

    def _entropy(self, probs=None):
        r"""
        Evaluate entropy.

        .. math::
            H(X) = -\sum(logits * probs)
        """
        probs = self._check_param_type(probs)
        logits = self.log(probs)
        return self.squeeze(-self.reduce_sum(logits * probs, -1))

    def _kl_loss(self, dist, probs_b, probs=None):
        """
        Evaluate KL divergence between Categorical distributions.

        Args:
            dist (str): The type of the distributions. Should be "Categorical" in this case.
            probs_b (Tensor): Event probabilities of distribution b.
            probs (Tensor): Event probabilities of distribution a. Default: self.probs.
        """
        check_distribution_name(dist, 'Categorical')
        probs_b = self._check_value(probs_b, 'probs_b')
        probs_b = self.cast(probs_b, self.parameter_type)
        probs_a = self._check_param_type(probs)
        logits_a = self.log(probs_a)
        logits_b = self.log(probs_b)
        return self.squeeze(-self.reduce_sum(
            self.softmax(logits_a) * (self.log_softmax(logits_a) - (self.log_softmax(logits_b))), -1))

    def _cross_entropy(self, dist, probs_b, probs=None):
        """
        Evaluate cross entropy between Categorical distributions.

        Args:
            dist (str): The type of the distributions. Should be "Categorical" in this case.
            probs_b (Tensor): Event probabilities of distribution b.
            probs (Tensor): Event probabilities of distribution a. Default: self.probs.
        """
        check_distribution_name(dist, 'Categorical')
        return self._entropy(probs) + self._kl_loss(dist, probs_b, probs)

    def _log_prob(self, value, probs=None):
        r"""
        Evaluate log probability.

        Args:
            value (Tensor): The value to be evaluated.
            probs (Tensor): Event probabilities. Default: self.probs.
        """
        value = self._check_value(value, 'value')
        value = self.cast(value, self.parameter_type)
        probs = self._check_param_type(probs)
        logits = self.log(probs)

        # handle the case when value is of shape () and probs is a scalar batch
        drop_dim = False
        if self.shape(value) == () and self.shape(probs)[:-1] == ():
            drop_dim = True
            # manually add one more dimension: () -> (1,)
            # drop this dimension before return
            value = self.expand_dim(value, -1)

        value = self.expand_dim(value, -1)
        broadcast_shape_tensor = logits * value
        broadcast_shape = self.shape(broadcast_shape_tensor)
        # broadcast_shape (N, C)
        num_classes = broadcast_shape[-1]
        label_shape = broadcast_shape[:-1]

        # broadcasting logits and value
        # logit_pmf shape (num of labels, C)
        logits = self.broadcast(logits, broadcast_shape_tensor)
        value = self.broadcast(value, broadcast_shape_tensor)[..., :1]

        # flatten value to shape (number of labels, 1)
        # clip value to be in range from 0 to num_classes -1 and cast into int32
        value = self.reshape(value, (-1, 1))
        out_of_bound = self.squeeze(self.logicor(\
            self.less(value, 0.0), self.less(num_classes-1, value)))
        value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
        value_clipped = self.cast(value_clipped, self.index_type)

        # create index from 0 ... NumOfLabels
        index = self.reshape(nn.Range(0, self.shape(value)[0], 1)(), (-1, 1))
        index = self.concat((index, value_clipped))

        # index into logit_pmf, fill in out_of_bound places with -inf
        # reshape into label shape N
        logits_pmf = self.gather(self.reshape(logits, (-1, num_classes)), index)
        neg_inf = self.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf), -np.inf)
        logits_pmf = self.select(out_of_bound, neg_inf, logits_pmf)
        ans = self.reshape(logits_pmf, label_shape)
        if drop_dim:
            return self.squeeze(ans)
        return ans

    def _cdf(self, value, probs=None):
        r"""
        Cumulative distribution function (cdf) of Categorical distributions.

        Args:
            value (Tensor): The value to be evaluated.
            probs (Tensor): Event probabilities. Default: self.probs.
        """
        value = self._check_value(value, 'value')
        value = self.cast(value, self.parameter_type)
        value = self.floor(value)
        probs = self._check_param_type(probs)

        # handle the case when value is of shape () and probs is a scalar batch
        drop_dim = False
        if self.shape(value) == () and self.shape(probs)[:-1] == ():
            drop_dim = True
            # manually add one more dimension: () -> (1,)
            # drop this dimension before return
            value = self.expand_dim(value, -1)

        value = self.expand_dim(value, -1)
        broadcast_shape_tensor = probs * value
        broadcast_shape = self.shape(broadcast_shape_tensor)
        # broadcast_shape (N, C)
        num_classes = broadcast_shape[-1]
        label_shape = broadcast_shape[:-1]

        probs = self.broadcast(probs, broadcast_shape_tensor)
        value = self.broadcast(value, broadcast_shape_tensor)[..., :1]

        # flatten value to shape (number of labels, 1)
        value = self.reshape(value, (-1, 1))

        # clip value to be in range from 0 to num_classes -1 and cast into int32
        less_than_zero = self.squeeze(self.less(value, 0.0))
        value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
        value_clipped = self.cast(value_clipped, self.index_type)

        index = self.reshape(nn.Range(0, self.shape(value)[0], 1)(), (-1, 1))
        index = self.concat((index, value_clipped))

        # reshape probs and fill less_than_zero places with 0
        probs = self.reshape(probs, (-1, num_classes))
        cdf = self.gather(self.cumsum(probs, 1), index)
        zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
        cdf = self.select(less_than_zero, zeros, cdf)
        cdf = self.reshape(cdf, label_shape)

        if drop_dim:
            return self.squeeze(cdf)
        return cdf

    def _sample(self, shape=(), probs=None):
        """
        Sampling.

        Args:
            shape (tuple): The shape of the sample. Default: ().
            probs (Tensor): Event probabilities. Default: self.probs.

        Returns:
            Tensor, with shape `shape + batch_shape`.
        """
        if self.device_target == 'Ascend':
            raise_not_implemented_util('On d backend, sample', self.name)
        shape = self.checktuple(shape, 'shape')
        probs = self._check_param_type(probs)
        num_classes = self.shape(probs)[-1]
        batch_shape = self.shape(probs)[:-1]

        sample_shape = shape + batch_shape
        drop_dim = False
        if sample_shape == ():
            drop_dim = True
            sample_shape = (1,)

        probs_2d = self.reshape(probs, (-1, num_classes))
        sample_tensor = self.fill(self.dtype, shape, 1.0)
        sample_tensor = self.reshape(sample_tensor, (-1, 1))
        num_sample = self.shape(sample_tensor)[0]
        samples = self.multinomial(probs_2d, num_sample)
        samples = self.squeeze(self.transpose(samples, (1, 0)))
        samples = self.cast(self.reshape(samples, sample_shape), self.dtype)
        if drop_dim:
            return self.squeeze(samples)
        return samples
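
For a quick sanity check of the closed forms implemented above, a NumPy-only sketch (plain arrays standing in for MindSpore tensors; probs_a and probs_b are arbitrary example values):

import numpy as np

probs_a = np.array([0.2, 0.8])
probs_b = np.array([0.5, 0.5])
index = np.arange(len(probs_a), dtype=np.float64)

mean = np.sum(index * probs_a)                              # E[X] = sum_i i*p_i -> 0.8
var = np.sum(index**2 * probs_a) - mean**2                  # E[X^2] - (E[X])^2 -> 0.16
entropy = -np.sum(probs_a * np.log(probs_a))                # H(X) = -sum_i p_i log p_i
kl = np.sum(probs_a * (np.log(probs_a) - np.log(probs_b)))  # KL(a || b)
cross_entropy = entropy + kl                                # matches _cross_entropy above
print(mean, var, entropy, kl, cross_entropy)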

+1 -0 mindspore/nn/probability/distribution/distribution.py

@@ -96,6 +96,7 @@ class Distribution(Cell):
        self._set_cross_entropy()

        self.context_mode = context.get_context('mode')
        self.device_target = context.get_context('device_target')
        self.checktuple = CheckTuple()
        self.checktensor = CheckTensor()
        self.broadcast = broadcast_to
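
This one-line addition is what lets Categorical._sample above detect the Ascend backend. A minimal sketch of the same pattern, assuming only that context.get_context('device_target') returns one of 'CPU', 'GPU', or 'Ascend' (the class here is illustrative, not MindSpore API):

import mindspore.context as context

class BackendAware:
    def __init__(self):
        # Query the configured backend once at construction time.
        self.device_target = context.get_context('device_target')

    def sample(self):
        if self.device_target == 'Ascend':
            raise NotImplementedError('sample is not supported on the Ascend backend.')
        # ... backend-supported sampling goes here ...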


+273 -0 tests/st/probability/distribution/test_categorical.py

@@ -0,0 +1,273 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for cat distribution"""
import numpy as np
import pytest
from scipy import stats
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

class Prob(nn.Cell):
"""
Test class: probability of categorical distribution.
"""
def __init__(self):
super(Prob, self).__init__()
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

def construct(self, x_):
return self.c.prob(x_)

def test_pmf():
"""
Test pmf.
"""
expect_pmf = [0.7, 0.3, 0.7, 0.3, 0.3]
pmf = Prob()
x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
output = pmf(x_)
tol = 1e-6
assert (np.abs(output.asnumpy() - expect_pmf) < tol).all()


class LogProb(nn.Cell):
"""
Test class: log probability of categorical distribution.
"""
def __init__(self):
super(LogProb, self).__init__()
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

def construct(self, x_):
return self.c.log_prob(x_)

def test_log_likelihood():
"""
Test log_pmf.
"""
expect_logpmf = np.log([0.7, 0.3, 0.7, 0.3, 0.3])
logprob = LogProb()
x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
output = logprob(x_)
tol = 1e-6
assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all()

class KL(nn.Cell):
"""
Test class: kl_loss between categorical distributions.
"""
def __init__(self):
super(KL, self).__init__()
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

def construct(self, x_):
return self.c.kl_loss('Categorical', x_)

def test_kl_loss():
"""
Test kl_loss.
"""
kl_loss = KL()
output = kl_loss(Tensor([0.7, 0.3], dtype=dtype.float32))
tol = 1e-6
assert (np.abs(output.asnumpy()) < tol).all()

class Sampling(nn.Cell):
"""
Test class: sampling of categorical distribution.
"""
def __init__(self):
super(Sampling, self).__init__()
self.c = msd.Categorical([0.2, 0.1, 0.7], dtype=dtype.int32)
self.shape = (2, 3)

def construct(self):
return self.c.sample(self.shape)

def test_sample():
"""
Test sample.
"""
with pytest.raises(NotImplementedError):
sample = Sampling()
sample()

class Basics(nn.Cell):
"""
Test class: mean/var/mode of categorical distribution.
"""
def __init__(self):
super(Basics, self).__init__()
self.c = msd.Categorical([0.2, 0.1, 0.7], dtype=dtype.int32)

def construct(self):
return self.c.mean(), self.c.var(), self.c.mode()

def test_basics():
"""
Test mean/variance/mode.
"""
basics = Basics()
mean, var, mode = basics()
expect_mean = 0 * 0.2 + 1 * 0.1 + 2 * 0.7
expect_var = 0 * 0.2 + 1 * 0.1 + 4 * 0.7 - (expect_mean * expect_mean)
expect_mode = 2
tol = 1e-6
assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
assert (np.abs(var.asnumpy() - expect_var) < tol).all()
assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()


class CDF(nn.Cell):
"""
Test class: cdf of categorical distributions.
"""
def __init__(self):
super(CDF, self).__init__()
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

def construct(self, x_):
return self.c.cdf(x_)

def test_cdf():
"""
Test cdf.
"""
expect_cdf = [0.7, 0.7, 1, 0.7, 1]
x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32)
cdf = CDF()
output = cdf(x_)
tol = 1e-6
assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()

class LogCDF(nn.Cell):
"""
Test class: log cdf of categorical distributions.
"""
def __init__(self):
super(LogCDF, self).__init__()
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

def construct(self, x_):
return self.c.log_cdf(x_)

def test_logcdf():
"""
Test log_cdf.
"""
expect_logcdf = np.log([0.7, 0.7, 1, 0.7, 1])
x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32)
logcdf = LogCDF()
output = logcdf(x_)
tol = 1e-6
assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all()


class SF(nn.Cell):
"""
Test class: survival function of categorical distributions.
"""
def __init__(self):
super(SF, self).__init__()
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

def construct(self, x_):
return self.c.survival_function(x_)

def test_survival():
"""
Test survival function.
"""
expect_survival = [0.3, 0., 0., 0.3, 0.3]
x_ = Tensor(np.array([0, 1, 1, 0, 0]).astype(np.int32), dtype=dtype.float32)
sf = SF()
output = sf(x_)
tol = 1e-6
assert (np.abs(output.asnumpy() - expect_survival) < tol).all()


class LogSF(nn.Cell):
"""
Test class: log survival function of categorical distributions.
"""
def __init__(self):
super(LogSF, self).__init__()
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

def construct(self, x_):
return self.c.log_survival(x_)

def test_log_survival():
"""
Test log survival function.
"""
expect_logsurvival = np.log([1., 0.3, 0.3, 0.3, 0.3])
x_ = Tensor(np.array([-0.1, 0, 0, 0.5, 0.5]).astype(np.float32), dtype=dtype.float32)
log_sf = LogSF()
output = log_sf(x_)
tol = 1e-6
assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all()

class EntropyH(nn.Cell):
"""
Test class: entropy of categorical distributions.
"""
def __init__(self):
super(EntropyH, self).__init__()
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

def construct(self):
return self.c.entropy()

def test_entropy():
"""
Test entropy.
"""
cat_benchmark = stats.multinomial(n=1, p=[0.7, 0.3])
expect_entropy = cat_benchmark.entropy().astype(np.float32)
entropy = EntropyH()
output = entropy()
tol = 1e-6
assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()

class CrossEntropy(nn.Cell):
"""
Test class: cross entropy between categorical distributions.
"""
def __init__(self):
super(CrossEntropy, self).__init__()
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

def construct(self, x_):
entropy = self.c.entropy()
kl_loss = self.c.kl_loss('Categorical', x_)
h_sum_kl = entropy + kl_loss
cross_entropy = self.c.cross_entropy('Categorical', x_)
return h_sum_kl - cross_entropy

def test_cross_entropy():
"""
Test cross_entropy.
"""
cross_entropy = CrossEntropy()
prob = Tensor([0.7, 0.3], dtype=dtype.float32)
diff = cross_entropy(prob)
tol = 1e-6
assert (np.abs(diff.asnumpy()) < tol).all()
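
The expected values hard-coded in the cases above follow directly from the distribution's probs; a NumPy sketch reproducing the pmf, cdf, and survival references (illustrative, mirroring the test inputs):

import numpy as np

probs = np.array([0.7, 0.3])
cdf = np.cumsum(probs)                 # [0.7, 1.0]

x = np.array([0, 1, 0, 1, 1])
expect_pmf = probs[x]                  # [0.7, 0.3, 0.7, 0.3, 0.3]

x = np.array([0, 0, 1, 0, 1])
expect_cdf = cdf[x]                    # [0.7, 0.7, 1.0, 0.7, 1.0]

x = np.array([0, 1, 1, 0, 0])
expect_survival = 1.0 - cdf[x]         # [0.3, 0.0, 0.0, 0.3, 0.3]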

+249 -0 tests/ut/python/nn/probability/distribution/test_categorical.py

@@ -0,0 +1,249 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test nn.probability.distribution.Categorical.
"""
import numpy as np
import pytest

import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype
from mindspore import Tensor


def test_arguments():
"""
Args passing during initialization.
"""
c = msd.Categorical()
assert isinstance(c, msd.Distribution)
c = msd.Categorical([0.1, 0.9], dtype=dtype.int32)
assert isinstance(c, msd.Distribution)


def test_type():
with pytest.raises(TypeError):
msd.Categorical([0.1], dtype=dtype.bool_)


def test_name():
with pytest.raises(TypeError):
msd.Categorical([0.1], name=1.0)


def test_seed():
with pytest.raises(TypeError):
msd.Categorical([0.1], seed='seed')


def test_prob():
"""
Invalid probability.
"""
with pytest.raises(ValueError):
msd.Categorical([-0.1], dtype=dtype.int32)
with pytest.raises(ValueError):
msd.Categorical([1.1], dtype=dtype.int32)
with pytest.raises(ValueError):
msd.Categorical([0.0], dtype=dtype.int32)
with pytest.raises(ValueError):
msd.Categorical([1.0], dtype=dtype.int32)

def test_categorical_sum():
"""
Invalid probabilities: the categories do not sum to one.
"""
with pytest.raises(ValueError):
msd.Categorical([[0.1, 0.2], [0.4, 0.6]], dtype=dtype.int32)
with pytest.raises(ValueError):
msd.Categorical([[0.5, 0.7], [0.6, 0.6]], dtype=dtype.int32)

def test_rank():
"""
Invalid probs: rank less than 1.
"""
with pytest.raises(ValueError):
msd.Categorical(0.2, dtype=dtype.int32)
with pytest.raises(ValueError):
msd.Categorical(np.array(0.3).astype(np.float32), dtype=dtype.int32)
with pytest.raises(ValueError):
msd.Categorical(Tensor(np.array(0.3).astype(np.float32)), dtype=dtype.int32)

class CategoricalProb(nn.Cell):
"""
Categorical distribution: initialize with probs.
"""

def __init__(self):
super(CategoricalProb, self).__init__()
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

def construct(self, value):
prob = self.c.prob(value)
log_prob = self.c.log_prob(value)
cdf = self.c.cdf(value)
log_cdf = self.c.log_cdf(value)
sf = self.c.survival_function(value)
log_sf = self.c.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf


def test_categorical_prob():
"""
Test probability functions: passing value through construct.
"""
net = CategoricalProb()
value = Tensor([0, 1, 0, 1, 0], dtype=dtype.float32)
ans = net(value)
assert isinstance(ans, Tensor)


class CategoricalProb1(nn.Cell):
"""
Categorical distribution: initialize without probs.
"""

def __init__(self):
super(CategoricalProb1, self).__init__()
self.c = msd.Categorical(dtype=dtype.int32)

def construct(self, value, probs):
prob = self.c.prob(value, probs)
log_prob = self.c.log_prob(value, probs)
cdf = self.c.cdf(value, probs)
log_cdf = self.c.log_cdf(value, probs)
sf = self.c.survival_function(value, probs)
log_sf = self.c.log_survival(value, probs)
return prob + log_prob + cdf + log_cdf + sf + log_sf


def test_categorical_prob1():
"""
Test probability functions: passing value/probs through construct.
"""
net = CategoricalProb1()
value = Tensor([0, 1, 0, 1, 0], dtype=dtype.float32)
probs = Tensor([0.3, 0.7], dtype=dtype.float32)
ans = net(value, probs)
assert isinstance(ans, Tensor)


class CategoricalKl(nn.Cell):
"""
Test class: kl_loss between Categorical distributions.
"""

def __init__(self):
super(CategoricalKl, self).__init__()
self.c1 = msd.Categorical([0.2, 0.2, 0.6], dtype=dtype.int32)
self.c2 = msd.Categorical(dtype=dtype.int32)

def construct(self, probs_b, probs_a):
kl1 = self.c1.kl_loss('Categorical', probs_b)
kl2 = self.c2.kl_loss('Categorical', probs_b, probs_a)
return kl1 + kl2


def test_kl():
"""
Test kl_loss function.
"""
cat_net = CategoricalKl()
probs_b = Tensor([0.3, 0.1, 0.6], dtype=dtype.float32)
probs_a = Tensor([0.7, 0.2, 0.1], dtype=dtype.float32)
ans = cat_net(probs_b, probs_a)
assert isinstance(ans, Tensor)


class CategoricalCrossEntropy(nn.Cell):
"""
Test class: cross_entropy of Categorical distribution.
"""

def __init__(self):
super(CategoricalCrossEntropy, self).__init__()
self.c1 = msd.Categorical([0.1, 0.7, 0.2], dtype=dtype.int32)
self.c2 = msd.Categorical(dtype=dtype.int32)

def construct(self, probs_b, probs_a):
h1 = self.c1.cross_entropy('Categorical', probs_b)
h2 = self.c2.cross_entropy('Categorical', probs_b, probs_a)
return h1 + h2


def test_cross_entropy():
"""
Test cross_entropy between Categorical distributions.
"""
net = CategoricalCrossEntropy()
probs_b = Tensor([0.3, 0.1, 0.6], dtype=dtype.float32)
probs_a = Tensor([0.7, 0.2, 0.1], dtype=dtype.float32)
ans = net(probs_b, probs_a)
assert isinstance(ans, Tensor)


class CategoricalConstruct(nn.Cell):
"""
Categorical distribution: going through construct.
"""

def __init__(self):
super(CategoricalConstruct, self).__init__()
self.c = msd.Categorical([0.1, 0.8, 0.1], dtype=dtype.int32)
self.c1 = msd.Categorical(dtype=dtype.int32)

def construct(self, value, probs):
prob = self.c('prob', value)
prob1 = self.c('prob', value, probs)
prob2 = self.c1('prob', value, probs)
return prob + prob1 + prob2

def test_categorical_construct():
"""
Test probability function going through construct.
"""
net = CategoricalConstruct()
value = Tensor([0, 1, 2, 0, 0], dtype=dtype.float32)
probs = Tensor([0.5, 0.4, 0.1], dtype=dtype.float32)
ans = net(value, probs)
assert isinstance(ans, Tensor)


class CategoricalBasics(nn.Cell):
"""
Test class: basic mean/var/mode/entropy function.
"""

def __init__(self):
super(CategoricalBasics, self).__init__()
self.c = msd.Categorical([0.2, 0.7, 0.1], dtype=dtype.int32)
self.c1 = msd.Categorical(dtype=dtype.int32)

def construct(self, probs):
basics1 = self.c.mean() + self.c.var() + self.c.mode() + self.c.entropy()
basics2 = self.c1.mean(probs) + self.c1.var(probs) +\
self.c1.mode(probs) + self.c1.entropy(probs)
return basics1 + basics2


def test_basics():
"""
Test basics functionality of Categorical distribution.
"""
net = CategoricalBasics()
probs = Tensor([0.7, 0.2, 0.1], dtype=dtype.float32)
ans = net(probs)
assert isinstance(ans, Tensor)
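
For reference, the kl_loss and cross_entropy quantities exercised above can be cross-checked with scipy, where stats.entropy(pk, qk) computes KL(pk || qk); a sketch using the probabilities from test_kl and test_cross_entropy:

import numpy as np
from scipy.stats import entropy

probs_a = np.array([0.7, 0.2, 0.1])
probs_b = np.array([0.3, 0.1, 0.6])

kl = entropy(probs_a, probs_b)          # KL(a || b)
cross_entropy = entropy(probs_a) + kl   # H(a) + KL(a || b), as in _cross_entropy
print(kl, cross_entropy)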
