Browse Source

!4707 Fixed bugs in cast_to_tensor and added more type check into distribution classes

Merge pull request !4707 from XunDeng/pp_poc_v3
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
2159c952d7
6 changed files with 319 additions and 140 deletions
  1. +37
    -16
      mindspore/nn/probability/distribution/_utils/utils.py
  2. +49
    -21
      mindspore/nn/probability/distribution/bernoulli.py
  3. +42
    -17
      mindspore/nn/probability/distribution/exponential.py
  4. +47
    -34
      mindspore/nn/probability/distribution/geometric.py
  5. +75
    -26
      mindspore/nn/probability/distribution/normal.py
  6. +69
    -26
      mindspore/nn/probability/distribution/uniform.py

+ 37
- 16
mindspore/nn/probability/distribution/_utils/utils.py View File

@@ -15,6 +15,7 @@
"""Utitly functions to help distribution class.""" """Utitly functions to help distribution class."""
import numpy as np import numpy as np
from mindspore.ops import _utils as utils from mindspore.ops import _utils as utils
from mindspore.ops.primitive import constexpr
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
@@ -23,7 +24,7 @@ from mindspore.ops import composite as C
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability as msp import mindspore.nn.probability as msp


def cast_to_tensor(t, hint_dtype=mstype.float32):
def cast_to_tensor(t, hint_type=mstype.float32):
""" """
Cast an user input value into a Tensor of dtype. Cast an user input value into a Tensor of dtype.
If the input t is of type Parameter, t is directly returned as a Parameter. If the input t is of type Parameter, t is directly returned as a Parameter.
@@ -38,24 +39,27 @@ def cast_to_tensor(t, hint_dtype=mstype.float32):
Returns: Returns:
Tensor. Tensor.
""" """
if t is None:
raise ValueError(f'Input cannot be None in cast_to_tensor')
if isinstance(t, Parameter): if isinstance(t, Parameter):
return t return t
t_type = hint_type
if isinstance(t, Tensor): if isinstance(t, Tensor):
if t.dtype != hint_dtype:
raise TypeError(f"Input tensor should be type {hint_dtype}.")
#check if the Tensor in shape of Tensor(4) #check if the Tensor in shape of Tensor(4)
if t.dim() == 0: if t.dim() == 0:
value = t.asnumpy() value = t.asnumpy()
return Tensor([value], dtype=hint_dtype)
return Tensor([value], dtype=t_type)
#convert the type of tensor to dtype #convert the type of tensor to dtype
return t
return Tensor(t.asnumpy(), dtype=t_type)
if isinstance(t, (list, np.ndarray)): if isinstance(t, (list, np.ndarray)):
return Tensor(t, dtype=hint_dtype)
if np.isscalar(t):
return Tensor([t], dtype=hint_dtype)
raise RuntimeError("Input type is not supported.")

def convert_to_batch(t, batch_shape, hint_dtype):
return Tensor(t, dtype=t_type)
if isinstance(t, bool):
raise TypeError(f'Input cannot be Type Bool')
if isinstance(t, (int, float)):
return Tensor([t], dtype=t_type)
raise TypeError("Input type is not supported.")

def convert_to_batch(t, batch_shape, required_type):
""" """
Convert a Tensor to a given batch shape. Convert a Tensor to a given batch shape.


@@ -72,8 +76,8 @@ def convert_to_batch(t, batch_shape, hint_dtype):
""" """
if isinstance(t, Parameter): if isinstance(t, Parameter):
return t return t
t = cast_to_tensor(t, hint_dtype)
return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=hint_dtype)
t = cast_to_tensor(t, required_type)
return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type)


def check_scalar_from_param(params): def check_scalar_from_param(params):
""" """
@@ -91,7 +95,7 @@ def check_scalar_from_param(params):
return False return False
if isinstance(value, (str, type(params['dtype']))): if isinstance(value, (str, type(params['dtype']))):
continue continue
elif np.isscalar(value):
elif isinstance(value, (int, float)):
continue continue
else: else:
return False return False
@@ -119,10 +123,11 @@ def calc_broadcast_shape_from_param(params):
if isinstance(value, Parameter): if isinstance(value, Parameter):
value_t = value.default_input value_t = value.default_input
else: else:
value_t = cast_to_tensor(value, params['dtype'])
value_t = cast_to_tensor(value, mstype.float32)
broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name']) broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name'])
return tuple(broadcast_shape) return tuple(broadcast_shape)



def check_greater_equal_zero(value, name): def check_greater_equal_zero(value, name):
""" """
Check if the given Tensor is greater zero. Check if the given Tensor is greater zero.
@@ -155,14 +160,17 @@ def check_greater_zero(value, name):
ValueError: if the input value is less than or equal to zero. ValueError: if the input value is less than or equal to zero.


""" """
if value is None:
raise ValueError(f'input value cannot be None in check_greater_zero')
if isinstance(value, Parameter): if isinstance(value, Parameter):
if isinstance(value.default_input, MetaTensor):
if not isinstance(value.default_input, Tensor):
return return
value = value.default_input value = value.default_input
comp = np.less(np.zeros(value.shape), value.asnumpy()) comp = np.less(np.zeros(value.shape), value.asnumpy())
if not comp.all(): if not comp.all():
raise ValueError(f'{name} should be greater than zero.') raise ValueError(f'{name} should be greater than zero.')



def check_greater(a, b, name_a, name_b): def check_greater(a, b, name_a, name_b):
""" """
Check if Tensor b is strictly greater than Tensor a. Check if Tensor b is strictly greater than Tensor a.
@@ -176,6 +184,8 @@ def check_greater(a, b, name_a, name_b):
Raises: Raises:
ValueError: if b is less than or equal to a ValueError: if b is less than or equal to a
""" """
if a is None or b is None:
raise ValueError(f'input value cannot be None in check_greater')
if isinstance(a, Parameter) or isinstance(b, Parameter): if isinstance(a, Parameter) or isinstance(b, Parameter):
return return
comp = np.less(a.asnumpy(), b.asnumpy()) comp = np.less(a.asnumpy(), b.asnumpy())
@@ -193,6 +203,8 @@ def check_prob(p):
Raises: Raises:
ValueError: if p is not a proper probability. ValueError: if p is not a proper probability.
""" """
if p is None:
raise ValueError(f'input value cannot be None in check_greater_zero')
if isinstance(p, Parameter): if isinstance(p, Parameter):
if not isinstance(p.default_input, Tensor): if not isinstance(p.default_input, Tensor):
return return
@@ -259,3 +271,12 @@ def check_tensor_type(name, inputs, valid_type):
def check_type(data_type, value_type, name): def check_type(data_type, value_type, name):
if not data_type in value_type: if not data_type in value_type:
raise TypeError(f"For {name}, valid type include {value_type}, {data_type} is invalid") raise TypeError(f"For {name}, valid type include {value_type}, {data_type} is invalid")

@constexpr
def raise_none_error(name):
raise ValueError(f"{name} should be specified. Value cannot be None")

@constexpr
def check_distribution_name(name, expected_name):
if name != expected_name:
raise ValueError(f"Distribution should be {expected_name}.")

+ 49
- 21
mindspore/nn/probability/distribution/bernoulli.py View File

@@ -17,7 +17,7 @@ from mindspore.common import dtype as mstype
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error


class Bernoulli(Distribution): class Bernoulli(Distribution):
""" """
@@ -99,8 +99,9 @@ class Bernoulli(Distribution):
valid_dtype = mstype.int_type + mstype.uint_type valid_dtype = mstype.int_type + mstype.uint_type
check_type(dtype, valid_dtype, "Bernoulli") check_type(dtype, valid_dtype, "Bernoulli")
super(Bernoulli, self).__init__(seed, dtype, name, param) super(Bernoulli, self).__init__(seed, dtype, name, param)
self.parameter_type = mstype.float32
if probs is not None: if probs is not None:
self._probs = cast_to_tensor(probs, hint_dtype=mstype.float32)
self._probs = cast_to_tensor(probs, mstype.float32)
check_prob(self.probs) check_prob(self.probs)
else: else:
self._probs = probs self._probs = probs
@@ -111,6 +112,7 @@ class Bernoulli(Distribution):
self.dtypeop = P.DType() self.dtypeop = P.DType()
self.erf = P.Erf() self.erf = P.Erf()
self.exp = P.Exp() self.exp = P.Exp()
self.floor = P.Floor()
self.fill = P.Fill() self.fill = P.Fill()
self.log = P.Log() self.log = P.Log()
self.less = P.Less() self.less = P.Less()
@@ -139,14 +141,19 @@ class Bernoulli(Distribution):
.. math:: .. math::
MEAN(B) = probs1 MEAN(B) = probs1
""" """
return self.probs if probs1 is None else probs1
probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
if probs1 is None:
raise_none_error("probs1")
return probs1


def _mode(self, probs1=None): def _mode(self, probs1=None):
r""" r"""
.. math:: .. math::
MODE(B) = 1 if probs1 > 0.5 else = 0 MODE(B) = 1 if probs1 > 0.5 else = 0
""" """
probs1 = self.probs if probs1 is None else probs1
probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
if probs1 is None:
raise_none_error("probs1")
prob_type = self.dtypeop(probs1) prob_type = self.dtypeop(probs1)
zeros = self.fill(prob_type, self.shape(probs1), 0.0) zeros = self.fill(prob_type, self.shape(probs1), 0.0)
ones = self.fill(prob_type, self.shape(probs1), 1.0) ones = self.fill(prob_type, self.shape(probs1), 1.0)
@@ -158,7 +165,9 @@ class Bernoulli(Distribution):
.. math:: .. math::
VAR(B) = probs1 * probs0 VAR(B) = probs1 * probs0
""" """
probs1 = self.probs if probs1 is None else probs1
probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
if probs1 is None:
raise_none_error("probs1")
probs0 = 1.0 - probs1 probs0 = 1.0 - probs1
return self.exp(self.log(probs0) + self.log(probs1)) return self.exp(self.log(probs0) + self.log(probs1))


@@ -167,7 +176,9 @@ class Bernoulli(Distribution):
.. math:: .. math::
H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1) H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
""" """
probs1 = self.probs if probs is None else probs
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
if probs1 is None:
raise_none_error("probs")
probs0 = 1 - probs1 probs0 = 1 - probs1
return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1))


@@ -180,9 +191,8 @@ class Bernoulli(Distribution):
probs1_b (Tensor): probs1 of distribution b. probs1_b (Tensor): probs1 of distribution b.
probs1_a (Tensor): probs1 of distribution a. Default: self.probs. probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
""" """
if dist == 'Bernoulli':
return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
return None
check_distribution_name(dist, 'Bernoulli')
return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)


def _log_prob(self, value, probs=None): def _log_prob(self, value, probs=None):
r""" r"""
@@ -196,7 +206,13 @@ class Bernoulli(Distribution):
pmf(k) = probs1 if k = 1; pmf(k) = probs1 if k = 1;
pmf(k) = probs0 if k = 0; pmf(k) = probs0 if k = 0;
""" """
probs1 = self.probs if probs is None else probs
if value is None:
raise_none_error("value")
value = self.cast(value, mstype.float32)
value = self.floor(value)
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
if probs1 is None:
raise_none_error("probs")
probs0 = 1.0 - probs1 probs0 = 1.0 - probs1
return self.log(probs1) * value + self.log(probs0) * (1.0 - value) return self.log(probs1) * value + self.log(probs0) * (1.0 - value)


@@ -213,7 +229,13 @@ class Bernoulli(Distribution):
cdf(k) = probs0 if 0 <= k <1; cdf(k) = probs0 if 0 <= k <1;
cdf(k) = 1 if k >=1; cdf(k) = 1 if k >=1;
""" """
probs1 = self.probs if probs is None else probs
if value is None:
raise_none_error("value")
value = self.cast(value, mstype.float32)
value = self.floor(value)
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
if probs1 is None:
raise_none_error("probs")
prob_type = self.dtypeop(probs1) prob_type = self.dtypeop(probs1)
value = value * self.fill(prob_type, self.shape(probs1), 1.0) value = value * self.fill(prob_type, self.shape(probs1), 1.0)
probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0) probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
@@ -230,19 +252,23 @@ class Bernoulli(Distribution):


Args: Args:
dist (str): type of the distributions. Should be "Bernoulli" in this case. dist (str): type of the distributions. Should be "Bernoulli" in this case.
probs1_b (Tensor): probs1 of distribution b.
probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
probs1_b (Tensor, Number): probs1 of distribution b.
probs1_a (Tensor, Number): probs1 of distribution a. Default: self.probs.


.. math:: .. math::
KL(a||b) = probs1_a * \log(\frac{probs1_a}{probs1_b}) + KL(a||b) = probs1_a * \log(\frac{probs1_a}{probs1_b}) +
probs0_a * \log(\frac{probs0_a}{probs0_b}) probs0_a * \log(\frac{probs0_a}{probs0_b})
""" """
if dist == 'Bernoulli':
probs1_a = self.probs if probs1_a is None else probs1_a
probs0_a = 1.0 - probs1_a
probs0_b = 1.0 - probs1_b
return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b)
return None
check_distribution_name(dist, 'Bernoulli')
if probs1_b is None:
raise_none_error("probs1_b")
probs1_b = self.cast(probs1_b, self.parameter_type)
probs1_a = self.cast(probs1_a, self.parameter_type) if probs1_a is not None else self.probs
if probs1_a is None:
raise_none_error("probs1_a")
probs0_a = 1.0 - probs1_a
probs0_b = 1.0 - probs1_b
return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b)


def _sample(self, shape=(), probs=None): def _sample(self, shape=(), probs=None):
""" """
@@ -250,12 +276,14 @@ class Bernoulli(Distribution):


Args: Args:
shape (tuple): shape of the sample. Default: (). shape (tuple): shape of the sample. Default: ().
probs (Tensor): probs1 of the samples. Default: self.probs.
probs (Tensor, Number): probs1 of the samples. Default: self.probs.


Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
probs1 = self.probs if probs is None else probs
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
if probs1 is None:
raise_none_error("probs")
l_zero = self.const(0.0) l_zero = self.const(0.0)
h_one = self.const(1.0) h_one = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one, self.seed) sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one, self.seed)


+ 42
- 17
mindspore/nn/probability/distribution/exponential.py View File

@@ -18,7 +18,8 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
raise_none_error


class Exponential(Distribution): class Exponential(Distribution):
""" """
@@ -100,8 +101,9 @@ class Exponential(Distribution):
valid_dtype = mstype.float_type valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, "Exponential") check_type(dtype, valid_dtype, "Exponential")
super(Exponential, self).__init__(seed, dtype, name, param) super(Exponential, self).__init__(seed, dtype, name, param)
self.parameter_type = dtype
if rate is not None: if rate is not None:
self._rate = cast_to_tensor(rate, dtype)
self._rate = cast_to_tensor(rate, self.parameter_type)
check_greater_zero(self._rate, "rate") check_greater_zero(self._rate, "rate")
else: else:
self._rate = rate self._rate = rate
@@ -141,16 +143,19 @@ class Exponential(Distribution):
.. math:: .. math::
MEAN(EXP) = \frac{1.0}{\lambda}. MEAN(EXP) = \frac{1.0}{\lambda}.
""" """
rate = self.rate if rate is None else rate
rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
if rate is None:
raise_none_error("rate")
return 1.0 / rate return 1.0 / rate



def _mode(self, rate=None): def _mode(self, rate=None):
r""" r"""
.. math:: .. math::
MODE(EXP) = 0. MODE(EXP) = 0.
""" """
rate = self.rate if rate is None else rate
rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
if rate is None:
raise_none_error("rate")
return self.fill(self.dtype, self.shape(rate), 0.) return self.fill(self.dtype, self.shape(rate), 0.)


def _sd(self, rate=None): def _sd(self, rate=None):
@@ -158,7 +163,9 @@ class Exponential(Distribution):
.. math:: .. math::
sd(EXP) = \frac{1.0}{\lambda}. sd(EXP) = \frac{1.0}{\lambda}.
""" """
rate = self.rate if rate is None else rate
rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
if rate is None:
raise_none_error("rate")
return 1.0 / rate return 1.0 / rate


def _entropy(self, rate=None): def _entropy(self, rate=None):
@@ -166,7 +173,9 @@ class Exponential(Distribution):
.. math:: .. math::
H(Exp) = 1 - \log(\lambda). H(Exp) = 1 - \log(\lambda).
""" """
rate = self.rate if rate is None else rate
rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
if rate is None:
raise_none_error("rate")
return 1.0 - self.log(rate) return 1.0 - self.log(rate)




@@ -179,9 +188,9 @@ class Exponential(Distribution):
rate_b (Tensor): rate of distribution b. rate_b (Tensor): rate of distribution b.
rate_a (Tensor): rate of distribution a. Default: self.rate. rate_a (Tensor): rate of distribution a. Default: self.rate.
""" """
if dist == 'Exponential':
return self._entropy(rate=rate_a) + self._kl_loss(dist, rate_b, rate_a)
return None
check_distribution_name(dist, 'Exponential')
return self._entropy(rate=rate_a) + self._kl_loss(dist, rate_b, rate_a)


def _prob(self, value, rate=None): def _prob(self, value, rate=None):
r""" r"""
@@ -198,7 +207,12 @@ class Exponential(Distribution):
.. math:: .. math::
pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0 pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0
""" """
rate = self.rate if rate is None else rate
if value is None:
raise_none_error("value")
value = self.cast(value, self.dtype)
rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
if rate is None:
raise_none_error("rate")
prob = self.exp(self.log(rate) - rate * value) prob = self.exp(self.log(rate) - rate * value)
zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0) zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
comp = self.less(value, zeros) comp = self.less(value, zeros)
@@ -218,7 +232,12 @@ class Exponential(Distribution):
.. math:: .. math::
cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0 cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0
""" """
rate = self.rate if rate is None else rate
if value is None:
raise_none_error("value")
value = self.cast(value, self.dtype)
rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
if rate is None:
raise_none_error("rate")
cdf = 1.0 - self.exp(-1. * rate * value) cdf = 1.0 - self.exp(-1. * rate * value)
zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0) zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
comp = self.less(value, zeros) comp = self.less(value, zeros)
@@ -234,10 +253,14 @@ class Exponential(Distribution):
rate_b (Tensor): rate of distribution b. rate_b (Tensor): rate of distribution b.
rate_a (Tensor): rate of distribution a. Default: self.rate. rate_a (Tensor): rate of distribution a. Default: self.rate.
""" """
if dist == 'Exponential':
rate_a = self.rate if rate_a is None else rate_a
return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0
return None
check_distribution_name(dist, 'Exponential')
if rate_b is None:
raise_none_error("rate_b")
rate_b = self.cast(rate_b, self.parameter_type)
rate_a = self.cast(rate_a, self.parameter_type) if rate_a is not None else self.rate
if rate_a is None:
raise_none_error("rate_a")
return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0


def _sample(self, shape=(), rate=None): def _sample(self, shape=(), rate=None):
""" """
@@ -250,7 +273,9 @@ class Exponential(Distribution):
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
rate = self.rate if rate is None else rate
rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
if rate is None:
raise_none_error("rate")
minval = self.const(self.minval) minval = self.const(self.minval)
maxval = self.const(1.0) maxval = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(rate), minval, maxval, self.seed) sample_uniform = self.uniform(shape + self.shape(rate), minval, maxval, self.seed)


+ 47
- 34
mindspore/nn/probability/distribution/geometric.py View File

@@ -18,7 +18,8 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
raise_none_error


class Geometric(Distribution): class Geometric(Distribution):
""" """
@@ -101,8 +102,9 @@ class Geometric(Distribution):
valid_dtype = mstype.int_type + mstype.uint_type valid_dtype = mstype.int_type + mstype.uint_type
check_type(dtype, valid_dtype, "Geometric") check_type(dtype, valid_dtype, "Geometric")
super(Geometric, self).__init__(seed, dtype, name, param) super(Geometric, self).__init__(seed, dtype, name, param)
self.parameter_type = mstype.float32
if probs is not None: if probs is not None:
self._probs = cast_to_tensor(probs, hint_dtype=mstype.float32)
self._probs = cast_to_tensor(probs, self.parameter_type)
check_prob(self._probs) check_prob(self._probs)
else: else:
self._probs = probs self._probs = probs
@@ -145,7 +147,9 @@ class Geometric(Distribution):
.. math:: .. math::
MEAN(Geo) = \fratc{1 - probs1}{probs1} MEAN(Geo) = \fratc{1 - probs1}{probs1}
""" """
probs1 = self.probs if probs1 is None else probs1
probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
if probs1 is None:
raise_none_error("probs1")
return (1. - probs1) / probs1 return (1. - probs1) / probs1


def _mode(self, probs1=None): def _mode(self, probs1=None):
@@ -153,7 +157,9 @@ class Geometric(Distribution):
.. math:: .. math::
MODE(Geo) = 0 MODE(Geo) = 0
""" """
probs1 = self.probs if probs1 is None else probs1
probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
if probs1 is None:
raise_none_error("probs1")
return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.) return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)


def _var(self, probs1=None): def _var(self, probs1=None):
@@ -161,7 +167,9 @@ class Geometric(Distribution):
.. math:: .. math::
VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}} VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}}
""" """
probs1 = self.probs if probs1 is None else probs1
probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
if probs1 is None:
raise_none_error("probs1")
return (1.0 - probs1) / self.sq(probs1) return (1.0 - probs1) / self.sq(probs1)


def _entropy(self, probs=None): def _entropy(self, probs=None):
@@ -169,7 +177,9 @@ class Geometric(Distribution):
.. math:: .. math::
H(Geo) = \frac{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1} H(Geo) = \frac{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
""" """
probs1 = self.probs if probs is None else probs
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
if probs1 is None:
raise_none_error("probs")
probs0 = 1.0 - probs1 probs0 = 1.0 - probs1
return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1 return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1


@@ -182,9 +192,8 @@ class Geometric(Distribution):
probs1_b (Tensor): probability of success of distribution b. probs1_b (Tensor): probability of success of distribution b.
probs1_a (Tensor): probability of success of distribution a. Default: self.probs. probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
""" """
if dist == 'Geometric':
return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
return None
check_distribution_name(dist, 'Geometric')
return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)


def _prob(self, value, probs=None): def _prob(self, value, probs=None):
r""" r"""
@@ -198,14 +207,13 @@ class Geometric(Distribution):
pmf(k) = probs0 ^k * probs1 if k >= 0; pmf(k) = probs0 ^k * probs1 if k >= 0;
pmf(k) = 0 if k < 0. pmf(k) = 0 if k < 0.
""" """
probs1 = self.probs if probs is None else probs
dtype = self.dtypeop(value)
if self.issubclass(dtype, mstype.int_):
pass
elif self.issubclass(dtype, mstype.float_):
value = self.floor(value)
else:
return None
if value is None:
raise_none_error("value")
value = self.cast(value, mstype.float32)
value = self.floor(value)
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
if probs1 is None:
raise_none_error("probs")
pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1)) pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0) zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
comp = self.less(value, zeros) comp = self.less(value, zeros)
@@ -224,15 +232,14 @@ class Geometric(Distribution):
cdf(k) = 0 if k < 0. cdf(k) = 0 if k < 0.


""" """
probs1 = self.probs if probs is None else probs
if value is None:
raise_none_error("value")
value = self.cast(value, mstype.float32)
value = self.floor(value)
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
if probs1 is None:
raise_none_error("probs")
probs0 = 1.0 - probs1 probs0 = 1.0 - probs1
dtype = self.dtypeop(value)
if self.issubclass(dtype, mstype.int_):
pass
elif self.issubclass(dtype, mstype.float_):
value = self.floor(value)
else:
return None
cdf = 1.0 - self.pow(probs0, value + 1.0) cdf = 1.0 - self.pow(probs0, value + 1.0)
zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0) zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
comp = self.less(value, zeros) comp = self.less(value, zeros)
@@ -251,12 +258,16 @@ class Geometric(Distribution):
.. math:: .. math::
KL(a||b) = \log(\frac{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac{probs0_a}{probs0_b}) KL(a||b) = \log(\frac{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac{probs0_a}{probs0_b})
""" """
if dist == 'Geometric':
probs1_a = self.probs if probs1_a is None else probs1_a
probs0_a = 1.0 - probs1_a
probs0_b = 1.0 - probs1_b
return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b)
return None
check_distribution_name(dist, 'Geometric')
if probs1_b is None:
raise_none_error("probs1_b")
probs1_b = self.cast(probs1_b, self.parameter_type)
probs1_a = self.cast(probs1_a, self.parameter_type) if probs1_a is not None else self.probs
if probs1_a is None:
raise_none_error("probs1_a")
probs0_a = 1.0 - probs1_a
probs0_b = 1.0 - probs1_b
return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b)


def _sample(self, shape=(), probs=None): def _sample(self, shape=(), probs=None):
""" """
@@ -269,9 +280,11 @@ class Geometric(Distribution):
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
probs = self.probs if probs is None else probs
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
if probs1 is None:
raise_none_error("probs")
minval = self.const(self.minval) minval = self.const(self.minval)
maxval = self.const(1.0) maxval = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval, self.seed)
sample = self.floor(self.log(sample_uniform) / self.log(1.0 - probs))
sample_uniform = self.uniform(shape + self.shape(probs1), minval, maxval, self.seed)
sample = self.floor(self.log(sample_uniform) / self.log(1.0 - probs1))
return self.cast(sample, self.dtype) return self.cast(sample, self.dtype)

+ 75
- 26
mindspore/nn/probability/distribution/normal.py View File

@@ -18,8 +18,8 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater_zero, check_type
from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
raise_none_error


class Normal(Distribution): class Normal(Distribution):
""" """
@@ -103,9 +103,10 @@ class Normal(Distribution):
valid_dtype = mstype.float_type valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, "Normal") check_type(dtype, valid_dtype, "Normal")
super(Normal, self).__init__(seed, dtype, name, param) super(Normal, self).__init__(seed, dtype, name, param)
self.parameter_type = dtype
if mean is not None and sd is not None: if mean is not None and sd is not None:
self._mean_value = convert_to_batch(mean, self.broadcast_shape, dtype)
self._sd_value = convert_to_batch(sd, self.broadcast_shape, dtype)
self._mean_value = convert_to_batch(mean, self.broadcast_shape, self.parameter_type)
self._sd_value = convert_to_batch(sd, self.broadcast_shape, self.parameter_type)
check_greater_zero(self._sd_value, "Standard deviation") check_greater_zero(self._sd_value, "Standard deviation")
else: else:
self._mean_value = mean self._mean_value = mean
@@ -113,6 +114,7 @@ class Normal(Distribution):




#ops needed for the class #ops needed for the class
self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
self.erf = P.Erf() self.erf = P.Erf()
self.exp = P.Exp() self.exp = P.Exp()
@@ -141,31 +143,51 @@ class Normal(Distribution):
""" """
Mean of the distribution. Mean of the distribution.
""" """
mean = self._mean_value if mean is None or sd is None else mean
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
if mean is None:
raise_none_error("mean")
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
if sd is None:
raise_none_error("sd")
return mean return mean


def _mode(self, mean=None, sd=None): def _mode(self, mean=None, sd=None):
""" """
Mode of the distribution. Mode of the distribution.
""" """
mean = self._mean_value if mean is None or sd is None else mean
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
if mean is None:
raise_none_error("mean")
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
if sd is None:
raise_none_error("sd")
return mean return mean


def _sd(self, mean=None, sd=None): def _sd(self, mean=None, sd=None):
""" """
Standard deviation of the distribution. Standard deviation of the distribution.
""" """
sd = self._sd_value if mean is None or sd is None else sd
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
if mean is None:
raise_none_error("mean")
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
if sd is None:
raise_none_error("sd")
return sd return sd


def _entropy(self, sd=None):
def _entropy(self, mean=None, sd=None):
r""" r"""
Evaluate entropy. Evaluate entropy.


.. math:: .. math::
H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma))) H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma)))
""" """
sd = self._sd_value if sd is None else sd
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
if mean is None:
raise_none_error("mean")
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
if sd is None:
raise_none_error("sd")
return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd) return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd)


def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None): def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None):
@@ -179,9 +201,8 @@ class Normal(Distribution):
mean_a (Tensor): mean of distribution a. Default: self._mean_value. mean_a (Tensor): mean of distribution a. Default: self._mean_value.
sd_a (Tensor): standard deviation distribution a. Default: self._sd_value. sd_a (Tensor): standard deviation distribution a. Default: self._sd_value.
""" """
if dist == 'Normal':
return self._entropy(sd=sd_a) + self._kl_loss(dist, mean_b, sd_b, mean_a, sd_a)
return None
check_distribution_name(dist, 'Normal')
return self._entropy(mean=mean_a, sd=sd_a) + self._kl_loss(dist, mean_b, sd_b, mean_a, sd_a)


def _log_prob(self, value, mean=None, sd=None): def _log_prob(self, value, mean=None, sd=None):
r""" r"""
@@ -195,10 +216,17 @@ class Normal(Distribution):
.. math:: .. math::
L(x) = -1* \frac{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2)) L(x) = -1* \frac{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
""" """
mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd
if value is None:
raise_none_error("value")
value = self.cast(value, self.dtype)
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
if mean is None:
raise_none_error("mean")
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
if sd is None:
raise_none_error("sd")
unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd)) unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
neg_normalization = -1. * self.log(self.sqrt(self.const(2. * np.pi))) - self.log(sd)
neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd)
return unnormalized_log_prob + neg_normalization return unnormalized_log_prob + neg_normalization


def _cdf(self, value, mean=None, sd=None): def _cdf(self, value, mean=None, sd=None):
@@ -213,8 +241,15 @@ class Normal(Distribution):
.. math:: .. math::
cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2)))) cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2))))
""" """
mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd
if value is None:
raise_none_error("value")
value = self.cast(value, self.dtype)
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
if mean is None:
raise_none_error("mean")
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
if sd is None:
raise_none_error("sd")
sqrt2 = self.sqrt(self.const(2.0)) sqrt2 = self.sqrt(self.const(2.0))
adjusted = (value - mean) / (sd * sqrt2) adjusted = (value - mean) / (sd * sqrt2)
return 0.5 * (1.0 + self.erf(adjusted)) return 0.5 * (1.0 + self.erf(adjusted))
@@ -234,13 +269,23 @@ class Normal(Distribution):
KL(a||b) = 0.5 * (\frac{MEAN(a)}{STD(b)} - \frac{MEAN(b)}{STD(b)}) ^ 2 + KL(a||b) = 0.5 * (\frac{MEAN(a)}{STD(b)} - \frac{MEAN(b)}{STD(b)}) ^ 2 +
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b))) 0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
""" """
if dist == 'Normal':
mean_a = self._mean_value if mean_a is None else mean_a
sd_a = self._sd_value if sd_a is None else sd_a
diff_log_scale = self.log(sd_a) - self.log(sd_b)
squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b)
return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale
return None
check_distribution_name(dist, 'Normal')
if mean_b is None:
raise_none_error("mean_b")
if sd_b is None:
raise_none_error("sd_b")
mean_b = self.cast(mean_b, self.parameter_type)
sd_b = self.cast(sd_b, self.parameter_type)
mean_a = self.cast(mean_a, self.parameter_type) if mean_a is not None else self._mean_value
sd_a = self.cast(sd_a, self.parameter_type) if sd_a is not None else self._sd_value
if mean_a is None:
raise_none_error("mean_a")
if sd_a is None:
raise_none_error("sd_a")
diff_log_scale = self.log(sd_a) - self.log(sd_b)
squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b)
return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale



def _sample(self, shape=(), mean=None, sd=None): def _sample(self, shape=(), mean=None, sd=None):
""" """
@@ -254,8 +299,12 @@ class Normal(Distribution):
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
if mean is None:
raise_none_error("mean")
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
if sd is None:
raise_none_error("sd")
batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd)) batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd))
sample_shape = shape + batch_shape sample_shape = shape + batch_shape
sample_norm = C.normal(sample_shape, mean, sd, self.seed) sample_norm = C.normal(sample_shape, mean, sd, self.seed)


+ 69
- 26
mindspore/nn/probability/distribution/uniform.py View File

@@ -17,7 +17,8 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater, check_type
from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
raise_none_error


class Uniform(Distribution): class Uniform(Distribution):
""" """
@@ -101,6 +102,7 @@ class Uniform(Distribution):
valid_dtype = mstype.float_type valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, "Uniform") check_type(dtype, valid_dtype, "Uniform")
super(Uniform, self).__init__(seed, dtype, name, param) super(Uniform, self).__init__(seed, dtype, name, param)
self.parameter_type = dtype
if low is not None and high is not None: if low is not None and high is not None:
self._low = convert_to_batch(low, self.broadcast_shape, dtype) self._low = convert_to_batch(low, self.broadcast_shape, dtype)
self._high = convert_to_batch(high, self.broadcast_shape, dtype) self._high = convert_to_batch(high, self.broadcast_shape, dtype)
@@ -153,8 +155,12 @@ class Uniform(Distribution):
.. math:: .. math::
range(U) = high -low range(U) = high -low
""" """
low = self.low if low is None else low
high = self.high if high is None else high
low = self.cast(low, self.parameter_type) if low is not None else self.low
if low is None:
raise_none_error("low")
high = self.cast(high, self.parameter_type) if high is not None else self.high
if high is None:
raise_none_error("high")
return high - low return high - low


def _mean(self, low=None, high=None): def _mean(self, low=None, high=None):
@@ -162,18 +168,25 @@ class Uniform(Distribution):
.. math:: .. math::
MEAN(U) = \frac{low + high}{2}. MEAN(U) = \frac{low + high}{2}.
""" """
low = self.low if low is None else low
high = self.high if high is None else high
low = self.cast(low, self.parameter_type) if low is not None else self.low
if low is None:
raise_none_error("low")
high = self.cast(high, self.parameter_type) if high is not None else self.high
if high is None:
raise_none_error("high")
return (low + high) / 2. return (low + high) / 2.



def _var(self, low=None, high=None): def _var(self, low=None, high=None):
r""" r"""
.. math:: .. math::
VAR(U) = \frac{(high -low) ^ 2}{12}. VAR(U) = \frac{(high -low) ^ 2}{12}.
""" """
low = self.low if low is None else low
high = self.high if high is None else high
low = self.cast(low, self.parameter_type) if low is not None else self.low
if low is None:
raise_none_error("low")
high = self.cast(high, self.parameter_type) if high is not None else self.high
if high is None:
raise_none_error("high")
return self.sq(high - low) / 12.0 return self.sq(high - low) / 12.0


def _entropy(self, low=None, high=None): def _entropy(self, low=None, high=None):
@@ -181,8 +194,12 @@ class Uniform(Distribution):
.. math:: .. math::
H(U) = \log(high - low). H(U) = \log(high - low).
""" """
low = self.low if low is None else low
high = self.high if high is None else high
low = self.cast(low, self.parameter_type) if low is not None else self.low
if low is None:
raise_none_error("low")
high = self.cast(high, self.parameter_type) if high is not None else self.high
if high is None:
raise_none_error("high")
return self.log(high - low) return self.log(high - low)


def _cross_entropy(self, dist, low_b, high_b, low_a=None, high_a=None): def _cross_entropy(self, dist, low_b, high_b, low_a=None, high_a=None):
@@ -196,9 +213,8 @@ class Uniform(Distribution):
low_a (Tensor): lower bound of distribution a. Default: self.low. low_a (Tensor): lower bound of distribution a. Default: self.low.
high_a (Tensor): upper bound of distribution a. Default: self.high. high_a (Tensor): upper bound of distribution a. Default: self.high.
""" """
if dist == 'Uniform':
return self._entropy(low=low_a, high=high_a) + self._kl_loss(dist, low_b, high_b, low_a, high_a)
return None
check_distribution_name(dist, 'Uniform')
return self._entropy(low=low_a, high=high_a) + self._kl_loss(dist, low_b, high_b, low_a, high_a)


def _prob(self, value, low=None, high=None): def _prob(self, value, low=None, high=None):
r""" r"""
@@ -214,8 +230,15 @@ class Uniform(Distribution):
pdf(x) = \frac{1.0}{high -low} if low <= x <= high; pdf(x) = \frac{1.0}{high -low} if low <= x <= high;
pdf(x) = 0 if x > high; pdf(x) = 0 if x > high;
""" """
low = self.low if low is None else low
high = self.high if high is None else high
if value is None:
raise_none_error("value")
value = self.cast(value, self.dtype)
low = self.cast(low, self.parameter_type) if low is not None else self.low
if low is None:
raise_none_error("low")
high = self.cast(high, self.parameter_type) if high is not None else self.high
if high is None:
raise_none_error("high")
neg_ones = self.fill(self.dtype, self.shape(value), -1.0) neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
prob = self.exp(neg_ones * self.log(high - low)) prob = self.exp(neg_ones * self.log(high - low))
broadcast_shape = self.shape(prob) broadcast_shape = self.shape(prob)
@@ -236,13 +259,22 @@ class Uniform(Distribution):
low_a (Tensor): lower bound of distribution a. Default: self.low. low_a (Tensor): lower bound of distribution a. Default: self.low.
high_a (Tensor): upper bound of distribution a. Default: self.high. high_a (Tensor): upper bound of distribution a. Default: self.high.
""" """
if dist == 'Uniform':
low_a = self.low if low_a is None else low_a
high_a = self.high if high_a is None else high_a
kl = self.log(high_b - low_b) / self.log(high_a - low_a)
comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b))
return self.select(comp, kl, self.log(self.zeroslike(kl)))
return None
check_distribution_name(dist, 'Uniform')
if low_b is None:
raise_none_error("low_b")
if high_b is None:
raise_none_error("high_b")
low_b = self.cast(low_b, self.parameter_type)
high_b = self.cast(high_b, self.parameter_type)
low_a = self.cast(low_a, self.parameter_type) if low_a is not None else self.low
if low_a is None:
raise_none_error("low_a")
high_a = self.cast(high_a, self.parameter_type) if high_a is not None else self.high
if high_a is None:
raise_none_error("high_a")
kl = self.log(high_b - low_b) / self.log(high_a - low_a)
comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b))
return self.select(comp, kl, self.log(self.zeroslike(kl)))


def _cdf(self, value, low=None, high=None): def _cdf(self, value, low=None, high=None):
r""" r"""
@@ -258,8 +290,15 @@ class Uniform(Distribution):
cdf(x) = \frac{x - low}{high -low} if low <= x <= high; cdf(x) = \frac{x - low}{high -low} if low <= x <= high;
cdf(x) = 1 if x > high; cdf(x) = 1 if x > high;
""" """
low = self.low if low is None else low
high = self.high if high is None else high
if value is None:
raise_none_error("value")
value = self.cast(value, self.dtype)
low = self.cast(low, self.parameter_type) if low is not None else self.low
if low is None:
raise_none_error("low")
high = self.cast(high, self.parameter_type) if high is not None else self.high
if high is None:
raise_none_error("high")
prob = (value - low) / (high - low) prob = (value - low) / (high - low)
broadcast_shape = self.shape(prob) broadcast_shape = self.shape(prob)
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
@@ -281,8 +320,12 @@ class Uniform(Distribution):
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
low = self.low if low is None else low
high = self.high if high is None else high
low = self.cast(low, self.parameter_type) if low is not None else self.low
if low is None:
raise_none_error("low")
high = self.cast(high, self.parameter_type) if high is not None else self.high
if high is None:
raise_none_error("high")
broadcast_shape = self.shape(low + high) broadcast_shape = self.shape(low + high)
l_zero = self.const(0.0) l_zero = self.const(0.0)
h_one = self.const(1.0) h_one = self.const(1.0)


Loading…
Cancel
Save