Browse Source

Complement the arg passing conventions in distribution and bijector base classes

tags/v0.7.0-beta
peixu_ren 5 years ago
parent
commit
60bb6bebda
3 changed files with 85 additions and 91 deletions
  1. +15
    -15
      mindspore/nn/probability/bijector/bijector.py
  2. +68
    -71
      mindspore/nn/probability/distribution/distribution.py
  3. +2
    -5
      mindspore/nn/probability/distribution/normal.py

+ 15
- 15
mindspore/nn/probability/bijector/bijector.py View File

@@ -69,31 +69,31 @@ class Bijector(Cell):
def is_injective(self): def is_injective(self):
return self._is_injective return self._is_injective


def forward(self, *args):
def forward(self, *args, **kwargs):
""" """
Forward transformation: transform the input value to another distribution. Forward transformation: transform the input value to another distribution.
""" """
return self._forward(*args)
return self._forward(*args, **kwargs)


def inverse(self, *args):
def inverse(self, *args, **kwargs):
""" """
Inverse transformation: transform the input value back to the original distribution. Inverse transformation: transform the input value back to the original distribution.
""" """
return self._inverse(*args)
return self._inverse(*args, **kwargs)


def forward_log_jacobian(self, *args):
def forward_log_jacobian(self, *args, **kwargs):
""" """
Logarithm of the derivative of forward transformation. Logarithm of the derivative of forward transformation.
""" """
return self._forward_log_jacobian(*args)
return self._forward_log_jacobian(*args, **kwargs)


def inverse_log_jacobian(self, *args):
def inverse_log_jacobian(self, *args, **kwargs):
""" """
Logarithm of the derivative of forward transformation. Logarithm of the derivative of forward transformation.
""" """
return self._inverse_log_jacobian(*args)
return self._inverse_log_jacobian(*args, **kwargs)


def __call__(self, *args):
def __call__(self, *args, **kwargs):
""" """
Call Bijector directly. Call Bijector directly.
This __call__ may go into two directions: This __call__ may go into two directions:
@@ -107,9 +107,9 @@ class Bijector(Cell):
""" """
if isinstance(args[0], Distribution): if isinstance(args[0], Distribution):
return TransformedDistribution(self, args[0]) return TransformedDistribution(self, args[0])
return super(Bijector, self).__call__(*args)
return super(Bijector, self).__call__(*args, **kwargs)


def construct(self, name, *args):
def construct(self, name, *args, **kwargs):
""" """
Override construct in Cell. Override construct in Cell.


@@ -120,11 +120,11 @@ class Bijector(Cell):
Always raise RuntimeError as Distribution should not be called directly. Always raise RuntimeError as Distribution should not be called directly.
""" """
if name == 'forward': if name == 'forward':
return self.forward(*args)
return self.forward(*args, **kwargs)
if name == 'inverse': if name == 'inverse':
return self.inverse(*args)
return self.inverse(*args, **kwargs)
if name == 'forward_log_jacobian': if name == 'forward_log_jacobian':
return self.forward_log_jacobian(*args)
return self.forward_log_jacobian(*args, **kwargs)
if name == 'inverse_log_jacobian': if name == 'inverse_log_jacobian':
return self.inverse_log_jacobian(*args)
return self.inverse_log_jacobian(*args, **kwargs)
return None return None

+ 68
- 71
mindspore/nn/probability/distribution/distribution.py View File

@@ -27,7 +27,7 @@ class Distribution(Cell):


Note: Note:
Derived class should override operations such as ,_mean, _prob, Derived class should override operations such as ,_mean, _prob,
and _log_prob. Arguments should be passed in through *args.
and _log_prob. Arguments should be passed in through *args or **kwargs.


Dist_spec_args are unique for each type of distribution. For example, mean and sd Dist_spec_args are unique for each type of distribution. For example, mean and sd
are the dist_spec_args for a Normal distribution. are the dist_spec_args for a Normal distribution.
@@ -171,7 +171,7 @@ class Distribution(Cell):
if hasattr(self, '_cross_entropy'): if hasattr(self, '_cross_entropy'):
self._call_cross_entropy = self._cross_entropy self._call_cross_entropy = self._cross_entropy


def log_prob(self, *args):
def log_prob(self, *args, **kwargs):
""" """
Evaluate the log probability(pdf or pmf) at the given value. Evaluate the log probability(pdf or pmf) at the given value.


@@ -179,18 +179,18 @@ class Distribution(Cell):
Args must include value. Args must include value.
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_log_prob(*args)
return self._call_log_prob(*args, **kwargs)


def _calc_prob_from_log_prob(self, *args):
def _calc_prob_from_log_prob(self, *args, **kwargs):
r""" r"""
Evaluate prob from log probability. Evaluate prob from log probability.


.. math:: .. math::
probability(x) = \exp(log_likehood(x)) probability(x) = \exp(log_likehood(x))
""" """
return self.exp(self._log_prob(*args))
return self.exp(self._log_prob(*args, **kwargs))


def prob(self, *args):
def prob(self, *args, **kwargs):
""" """
Evaluate the probability (pdf or pmf) at given value. Evaluate the probability (pdf or pmf) at given value.


@@ -198,18 +198,18 @@ class Distribution(Cell):
Args must include value. Args must include value.
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_prob(*args)
return self._call_prob(*args, **kwargs)


def _calc_log_prob_from_prob(self, *args):
def _calc_log_prob_from_prob(self, *args, **kwargs):
r""" r"""
Evaluate log probability from probability. Evaluate log probability from probability.


.. math:: .. math::
log_prob(x) = \log(prob(x)) log_prob(x) = \log(prob(x))
""" """
return self.log(self._prob(*args))
return self.log(self._prob(*args, **kwargs))


def cdf(self, *args):
def cdf(self, *args, **kwargs):
""" """
Evaluate the cdf at given value. Evaluate the cdf at given value.


@@ -217,36 +217,36 @@ class Distribution(Cell):
Args must include value. Args must include value.
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_cdf(*args)
return self._call_cdf(*args, **kwargs)


def _calc_cdf_from_log_cdf(self, *args):
def _calc_cdf_from_log_cdf(self, *args, **kwargs):
r""" r"""
Evaluate cdf from log_cdf. Evaluate cdf from log_cdf.


.. math:: .. math::
cdf(x) = \exp(log_cdf(x)) cdf(x) = \exp(log_cdf(x))
""" """
return self.exp(self._log_cdf(*args))
return self.exp(self._log_cdf(*args, **kwargs))


def _calc_cdf_from_survival(self, *args):
def _calc_cdf_from_survival(self, *args, **kwargs):
r""" r"""
Evaluate cdf from survival function. Evaluate cdf from survival function.


.. math:: .. math::
cdf(x) = 1 - (survival_function(x)) cdf(x) = 1 - (survival_function(x))
""" """
return 1.0 - self._survival_function(*args)
return 1.0 - self._survival_function(*args, **kwargs)


def _calc_cdf_from_log_survival(self, *args):
def _calc_cdf_from_log_survival(self, *args, **kwargs):
r""" r"""
Evaluate cdf from log survival function. Evaluate cdf from log survival function.


.. math:: .. math::
cdf(x) = 1 - (\exp(log_survival(x))) cdf(x) = 1 - (\exp(log_survival(x)))
""" """
return 1.0 - self.exp(self._log_survival(*args))
return 1.0 - self.exp(self._log_survival(*args, **kwargs))


def log_cdf(self, *args):
def log_cdf(self, *args, **kwargs):
""" """
Evaluate the log cdf at given value. Evaluate the log cdf at given value.


@@ -254,18 +254,18 @@ class Distribution(Cell):
Args must include value. Args must include value.
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_log_cdf(*args)
return self._call_log_cdf(*args, **kwargs)


def _calc_log_cdf_from_call_cdf(self, *args):
def _calc_log_cdf_from_call_cdf(self, *args, **kwargs):
r""" r"""
Evaluate log cdf from cdf. Evaluate log cdf from cdf.


.. math:: .. math::
log_cdf(x) = \log(cdf(x)) log_cdf(x) = \log(cdf(x))
""" """
return self.log(self._call_cdf(*args))
return self.log(self._call_cdf(*args, **kwargs))


def survival_function(self, *args):
def survival_function(self, *args, **kwargs):
""" """
Evaluate the survival function at given value. Evaluate the survival function at given value.


@@ -273,27 +273,27 @@ class Distribution(Cell):
Args must include value. Args must include value.
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_survival(*args)
return self._call_survival(*args, **kwargs)


def _calc_survival_from_call_cdf(self, *args):
def _calc_survival_from_call_cdf(self, *args, **kwargs):
r""" r"""
Evaluate survival function from cdf. Evaluate survival function from cdf.


.. math:: .. math::
survival_function(x) = 1 - (cdf(x)) survival_function(x) = 1 - (cdf(x))
""" """
return 1.0 - self._call_cdf(*args)
return 1.0 - self._call_cdf(*args, **kwargs)


def _calc_survival_from_log_survival(self, *args):
def _calc_survival_from_log_survival(self, *args, **kwargs):
r""" r"""
Evaluate survival function from log survival function. Evaluate survival function from log survival function.


.. math:: .. math::
survival(x) = \exp(survival_function(x)) survival(x) = \exp(survival_function(x))
""" """
return self.exp(self._log_survival(*args))
return self.exp(self._log_survival(*args, **kwargs))


def log_survival(self, *args):
def log_survival(self, *args, **kwargs):
""" """
Evaluate the log survival function at given value. Evaluate the log survival function at given value.


@@ -301,18 +301,18 @@ class Distribution(Cell):
Args must include value. Args must include value.
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_log_survival(*args)
return self._call_log_survival(*args, **kwargs)


def _calc_log_survival_from_call_survival(self, *args):
def _calc_log_survival_from_call_survival(self, *args, **kwargs):
r""" r"""
Evaluate log survival function from survival function. Evaluate log survival function from survival function.


.. math:: .. math::
log_survival(x) = \log(survival_function(x)) log_survival(x) = \log(survival_function(x))
""" """
return self.log(self._call_survival(*args))
return self.log(self._call_survival(*args, **kwargs))


def kl_loss(self, *args):
def kl_loss(self, *args, **kwargs):
""" """
Evaluate the KL divergence, i.e. KL(a||b). Evaluate the KL divergence, i.e. KL(a||b).


@@ -320,72 +320,72 @@ class Distribution(Cell):
Args must include type of the distribution, parameters of distribution b. Args must include type of the distribution, parameters of distribution b.
Parameters for distribution a are optional. Parameters for distribution a are optional.
""" """
return self._kl_loss(*args)
return self._kl_loss(*args, **kwargs)


def mean(self, *args):
def mean(self, *args, **kwargs):
""" """
Evaluate the mean. Evaluate the mean.


Note: Note:
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._mean(*args)
return self._mean(*args, **kwargs)


def mode(self, *args):
def mode(self, *args, **kwargs):
""" """
Evaluate the mode. Evaluate the mode.


Note: Note:
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._mode(*args)
return self._mode(*args, **kwargs)


def sd(self, *args):
def sd(self, *args, **kwargs):
""" """
Evaluate the standard deviation. Evaluate the standard deviation.


Note: Note:
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_sd(*args)
return self._call_sd(*args, **kwargs)


def var(self, *args):
def var(self, *args, **kwargs):
""" """
Evaluate the variance. Evaluate the variance.


Note: Note:
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._call_var(*args)
return self._call_var(*args, **kwargs)


def _calc_sd_from_var(self, *args):
def _calc_sd_from_var(self, *args, **kwargs):
r""" r"""
Evaluate log probability from probability. Evaluate log probability from probability.


.. math:: .. math::
STD(x) = \sqrt(VAR(x)) STD(x) = \sqrt(VAR(x))
""" """
return self.sqrt(self._var(*args))
return self.sqrt(self._var(*args, **kwargs))


def _calc_var_from_sd(self, *args):
def _calc_var_from_sd(self, *args, **kwargs):
r""" r"""
Evaluate log probability from probability. Evaluate log probability from probability.


.. math:: .. math::
VAR(x) = STD(x) ^ 2 VAR(x) = STD(x) ^ 2
""" """
return self.sq(self._sd(*args))
return self.sq(self._sd(*args, **kwargs))


def entropy(self, *args):
def entropy(self, *args, **kwargs):
""" """
Evaluate the entropy. Evaluate the entropy.


Note: Note:
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._entropy(*args)
return self._entropy(*args, **kwargs)


def cross_entropy(self, *args):
def cross_entropy(self, *args, **kwargs):
""" """
Evaluate the cross_entropy between distribution a and b. Evaluate the cross_entropy between distribution a and b.


@@ -393,32 +393,29 @@ class Distribution(Cell):
Args must include type of the distribution, parameters of distribution b. Args must include type of the distribution, parameters of distribution b.
Parameters for distribution a are optional. Parameters for distribution a are optional.
""" """
return self._call_cross_entropy(*args)
return self._call_cross_entropy(*args, **kwargs)


def _calc_cross_entropy(self, *args):
def _calc_cross_entropy(self, *args, **kwargs):
r""" r"""
Evaluate cross_entropy from entropy and kl divergence. Evaluate cross_entropy from entropy and kl divergence.


.. math:: .. math::
H(X, Y) = H(X) + KL(X||Y) H(X, Y) = H(X) + KL(X||Y)
""" """
return self._entropy(*args) + self._kl_loss(*args)
return self._entropy(*args, **kwargs) + self._kl_loss(*args, **kwargs)


def sample(self, *args):
def sample(self, *args, **kwargs):
""" """
Sampling function. Sampling function.


Args:
*args (list): arguments passed in through construct.

Note: Note:
Shape of the sample is default to (). Shape of the sample is default to ().
Dist_spec_args are optional. Dist_spec_args are optional.
""" """
return self._sample(*args)
return self._sample(*args, **kwargs)




def construct(self, name, *args):
def construct(self, name, *args, **kwargs):
""" """
Override construct in Cell. Override construct in Cell.


@@ -433,31 +430,31 @@ class Distribution(Cell):
""" """


if name == 'log_prob': if name == 'log_prob':
return self._call_log_prob(*args)
return self._call_log_prob(*args, **kwargs)
if name == 'prob': if name == 'prob':
return self._call_prob(*args)
return self._call_prob(*args, **kwargs)
if name == 'cdf': if name == 'cdf':
return self._call_cdf(*args)
return self._call_cdf(*args, **kwargs)
if name == 'log_cdf': if name == 'log_cdf':
return self._call_log_cdf(*args)
return self._call_log_cdf(*args, **kwargs)
if name == 'survival_function': if name == 'survival_function':
return self._call_survival(*args)
return self._call_survival(*args, **kwargs)
if name == 'log_survival': if name == 'log_survival':
return self._call_log_survival(*args)
return self._call_log_survival(*args, **kwargs)
if name == 'kl_loss': if name == 'kl_loss':
return self._kl_loss(*args)
return self._kl_loss(*args, **kwargs)
if name == 'mean': if name == 'mean':
return self._mean(*args)
return self._mean(*args, **kwargs)
if name == 'mode': if name == 'mode':
return self._mode(*args)
return self._mode(*args, **kwargs)
if name == 'sd': if name == 'sd':
return self._call_sd(*args)
return self._call_sd(*args, **kwargs)
if name == 'var': if name == 'var':
return self._call_var(*args)
return self._call_var(*args, **kwargs)
if name == 'entropy': if name == 'entropy':
return self._entropy(*args)
return self._entropy(*args, **kwargs)
if name == 'cross_entropy': if name == 'cross_entropy':
return self._call_cross_entropy(*args)
return self._call_cross_entropy(*args, **kwargs)
if name == 'sample': if name == 'sample':
return self._sample(*args)
return self._sample(*args, **kwargs)
return None return None

+ 2
- 5
mindspore/nn/probability/distribution/normal.py View File

@@ -256,8 +256,5 @@ class Normal(Distribution):
sd = self._sd_value if sd is None else sd sd = self._sd_value if sd is None else sd
batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd)) batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd))
sample_shape = shape + batch_shape sample_shape = shape + batch_shape
mean_zero = self.const(0.0)
sd_one = self.const(1.0)
sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed)
sample = mean + sample_norm * sd
return sample
sample_norm = C.normal(sample_shape, mean, sd, self.seed)
return sample_norm

Loading…
Cancel
Save