|
|
|
@@ -27,7 +27,7 @@ class Distribution(Cell): |
|
|
|
|
|
|
|
Note: |
|
|
|
Derived class should override operations such as ,_mean, _prob, |
|
|
|
and _log_prob. Arguments should be passed in through *args. |
|
|
|
and _log_prob. Arguments should be passed in through *args or **kwargs. |
|
|
|
|
|
|
|
Dist_spec_args are unique for each type of distribution. For example, mean and sd |
|
|
|
are the dist_spec_args for a Normal distribution. |
|
|
|
@@ -171,7 +171,7 @@ class Distribution(Cell): |
|
|
|
if hasattr(self, '_cross_entropy'): |
|
|
|
self._call_cross_entropy = self._cross_entropy |
|
|
|
|
|
|
|
def log_prob(self, *args): |
|
|
|
def log_prob(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Evaluate the log probability(pdf or pmf) at the given value. |
|
|
|
|
|
|
|
@@ -179,18 +179,18 @@ class Distribution(Cell): |
|
|
|
Args must include value. |
|
|
|
Dist_spec_args are optional. |
|
|
|
""" |
|
|
|
return self._call_log_prob(*args) |
|
|
|
return self._call_log_prob(*args, **kwargs) |
|
|
|
|
|
|
|
def _calc_prob_from_log_prob(self, *args): |
|
|
|
def _calc_prob_from_log_prob(self, *args, **kwargs): |
|
|
|
r""" |
|
|
|
Evaluate prob from log probability. |
|
|
|
|
|
|
|
.. math:: |
|
|
|
probability(x) = \exp(log_likehood(x)) |
|
|
|
""" |
|
|
|
return self.exp(self._log_prob(*args)) |
|
|
|
return self.exp(self._log_prob(*args, **kwargs)) |
|
|
|
|
|
|
|
def prob(self, *args): |
|
|
|
def prob(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Evaluate the probability (pdf or pmf) at given value. |
|
|
|
|
|
|
|
@@ -198,18 +198,18 @@ class Distribution(Cell): |
|
|
|
Args must include value. |
|
|
|
Dist_spec_args are optional. |
|
|
|
""" |
|
|
|
return self._call_prob(*args) |
|
|
|
return self._call_prob(*args, **kwargs) |
|
|
|
|
|
|
|
def _calc_log_prob_from_prob(self, *args): |
|
|
|
def _calc_log_prob_from_prob(self, *args, **kwargs): |
|
|
|
r""" |
|
|
|
Evaluate log probability from probability. |
|
|
|
|
|
|
|
.. math:: |
|
|
|
log_prob(x) = \log(prob(x)) |
|
|
|
""" |
|
|
|
return self.log(self._prob(*args)) |
|
|
|
return self.log(self._prob(*args, **kwargs)) |
|
|
|
|
|
|
|
def cdf(self, *args): |
|
|
|
def cdf(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Evaluate the cdf at given value. |
|
|
|
|
|
|
|
@@ -217,36 +217,36 @@ class Distribution(Cell): |
|
|
|
Args must include value. |
|
|
|
Dist_spec_args are optional. |
|
|
|
""" |
|
|
|
return self._call_cdf(*args) |
|
|
|
return self._call_cdf(*args, **kwargs) |
|
|
|
|
|
|
|
def _calc_cdf_from_log_cdf(self, *args): |
|
|
|
def _calc_cdf_from_log_cdf(self, *args, **kwargs): |
|
|
|
r""" |
|
|
|
Evaluate cdf from log_cdf. |
|
|
|
|
|
|
|
.. math:: |
|
|
|
cdf(x) = \exp(log_cdf(x)) |
|
|
|
""" |
|
|
|
return self.exp(self._log_cdf(*args)) |
|
|
|
return self.exp(self._log_cdf(*args, **kwargs)) |
|
|
|
|
|
|
|
def _calc_cdf_from_survival(self, *args): |
|
|
|
def _calc_cdf_from_survival(self, *args, **kwargs): |
|
|
|
r""" |
|
|
|
Evaluate cdf from survival function. |
|
|
|
|
|
|
|
.. math:: |
|
|
|
cdf(x) = 1 - (survival_function(x)) |
|
|
|
""" |
|
|
|
return 1.0 - self._survival_function(*args) |
|
|
|
return 1.0 - self._survival_function(*args, **kwargs) |
|
|
|
|
|
|
|
def _calc_cdf_from_log_survival(self, *args): |
|
|
|
def _calc_cdf_from_log_survival(self, *args, **kwargs): |
|
|
|
r""" |
|
|
|
Evaluate cdf from log survival function. |
|
|
|
|
|
|
|
.. math:: |
|
|
|
cdf(x) = 1 - (\exp(log_survival(x))) |
|
|
|
""" |
|
|
|
return 1.0 - self.exp(self._log_survival(*args)) |
|
|
|
return 1.0 - self.exp(self._log_survival(*args, **kwargs)) |
|
|
|
|
|
|
|
def log_cdf(self, *args): |
|
|
|
def log_cdf(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Evaluate the log cdf at given value. |
|
|
|
|
|
|
|
@@ -254,18 +254,18 @@ class Distribution(Cell): |
|
|
|
Args must include value. |
|
|
|
Dist_spec_args are optional. |
|
|
|
""" |
|
|
|
return self._call_log_cdf(*args) |
|
|
|
return self._call_log_cdf(*args, **kwargs) |
|
|
|
|
|
|
|
def _calc_log_cdf_from_call_cdf(self, *args): |
|
|
|
def _calc_log_cdf_from_call_cdf(self, *args, **kwargs): |
|
|
|
r""" |
|
|
|
Evaluate log cdf from cdf. |
|
|
|
|
|
|
|
.. math:: |
|
|
|
log_cdf(x) = \log(cdf(x)) |
|
|
|
""" |
|
|
|
return self.log(self._call_cdf(*args)) |
|
|
|
return self.log(self._call_cdf(*args, **kwargs)) |
|
|
|
|
|
|
|
def survival_function(self, *args): |
|
|
|
def survival_function(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Evaluate the survival function at given value. |
|
|
|
|
|
|
|
@@ -273,27 +273,27 @@ class Distribution(Cell): |
|
|
|
Args must include value. |
|
|
|
Dist_spec_args are optional. |
|
|
|
""" |
|
|
|
return self._call_survival(*args) |
|
|
|
return self._call_survival(*args, **kwargs) |
|
|
|
|
|
|
|
def _calc_survival_from_call_cdf(self, *args): |
|
|
|
def _calc_survival_from_call_cdf(self, *args, **kwargs): |
|
|
|
r""" |
|
|
|
Evaluate survival function from cdf. |
|
|
|
|
|
|
|
.. math:: |
|
|
|
survival_function(x) = 1 - (cdf(x)) |
|
|
|
""" |
|
|
|
return 1.0 - self._call_cdf(*args) |
|
|
|
return 1.0 - self._call_cdf(*args, **kwargs) |
|
|
|
|
|
|
|
def _calc_survival_from_log_survival(self, *args): |
|
|
|
def _calc_survival_from_log_survival(self, *args, **kwargs): |
|
|
|
r""" |
|
|
|
Evaluate survival function from log survival function. |
|
|
|
|
|
|
|
.. math:: |
|
|
|
survival(x) = \exp(survival_function(x)) |
|
|
|
""" |
|
|
|
return self.exp(self._log_survival(*args)) |
|
|
|
return self.exp(self._log_survival(*args, **kwargs)) |
|
|
|
|
|
|
|
def log_survival(self, *args): |
|
|
|
def log_survival(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Evaluate the log survival function at given value. |
|
|
|
|
|
|
|
@@ -301,18 +301,18 @@ class Distribution(Cell): |
|
|
|
Args must include value. |
|
|
|
Dist_spec_args are optional. |
|
|
|
""" |
|
|
|
return self._call_log_survival(*args) |
|
|
|
return self._call_log_survival(*args, **kwargs) |
|
|
|
|
|
|
|
def _calc_log_survival_from_call_survival(self, *args): |
|
|
|
def _calc_log_survival_from_call_survival(self, *args, **kwargs): |
|
|
|
r""" |
|
|
|
Evaluate log survival function from survival function. |
|
|
|
|
|
|
|
.. math:: |
|
|
|
log_survival(x) = \log(survival_function(x)) |
|
|
|
""" |
|
|
|
return self.log(self._call_survival(*args)) |
|
|
|
return self.log(self._call_survival(*args, **kwargs)) |
|
|
|
|
|
|
|
def kl_loss(self, *args): |
|
|
|
def kl_loss(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Evaluate the KL divergence, i.e. KL(a||b). |
|
|
|
|
|
|
|
@@ -320,72 +320,72 @@ class Distribution(Cell): |
|
|
|
Args must include type of the distribution, parameters of distribution b. |
|
|
|
Parameters for distribution a are optional. |
|
|
|
""" |
|
|
|
return self._kl_loss(*args) |
|
|
|
return self._kl_loss(*args, **kwargs) |
|
|
|
|
|
|
|
def mean(self, *args): |
|
|
|
def mean(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Evaluate the mean. |
|
|
|
|
|
|
|
Note: |
|
|
|
Dist_spec_args are optional. |
|
|
|
""" |
|
|
|
return self._mean(*args) |
|
|
|
return self._mean(*args, **kwargs) |
|
|
|
|
|
|
|
def mode(self, *args): |
|
|
|
def mode(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Evaluate the mode. |
|
|
|
|
|
|
|
Note: |
|
|
|
Dist_spec_args are optional. |
|
|
|
""" |
|
|
|
return self._mode(*args) |
|
|
|
return self._mode(*args, **kwargs) |
|
|
|
|
|
|
|
def sd(self, *args): |
|
|
|
def sd(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Evaluate the standard deviation. |
|
|
|
|
|
|
|
Note: |
|
|
|
Dist_spec_args are optional. |
|
|
|
""" |
|
|
|
return self._call_sd(*args) |
|
|
|
return self._call_sd(*args, **kwargs) |
|
|
|
|
|
|
|
def var(self, *args): |
|
|
|
def var(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Evaluate the variance. |
|
|
|
|
|
|
|
Note: |
|
|
|
Dist_spec_args are optional. |
|
|
|
""" |
|
|
|
return self._call_var(*args) |
|
|
|
return self._call_var(*args, **kwargs) |
|
|
|
|
|
|
|
def _calc_sd_from_var(self, *args): |
|
|
|
def _calc_sd_from_var(self, *args, **kwargs): |
|
|
|
r""" |
|
|
|
Evaluate log probability from probability. |
|
|
|
|
|
|
|
.. math:: |
|
|
|
STD(x) = \sqrt(VAR(x)) |
|
|
|
""" |
|
|
|
return self.sqrt(self._var(*args)) |
|
|
|
return self.sqrt(self._var(*args, **kwargs)) |
|
|
|
|
|
|
|
def _calc_var_from_sd(self, *args): |
|
|
|
def _calc_var_from_sd(self, *args, **kwargs): |
|
|
|
r""" |
|
|
|
Evaluate log probability from probability. |
|
|
|
|
|
|
|
.. math:: |
|
|
|
VAR(x) = STD(x) ^ 2 |
|
|
|
""" |
|
|
|
return self.sq(self._sd(*args)) |
|
|
|
return self.sq(self._sd(*args, **kwargs)) |
|
|
|
|
|
|
|
def entropy(self, *args): |
|
|
|
def entropy(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Evaluate the entropy. |
|
|
|
|
|
|
|
Note: |
|
|
|
Dist_spec_args are optional. |
|
|
|
""" |
|
|
|
return self._entropy(*args) |
|
|
|
return self._entropy(*args, **kwargs) |
|
|
|
|
|
|
|
def cross_entropy(self, *args): |
|
|
|
def cross_entropy(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Evaluate the cross_entropy between distribution a and b. |
|
|
|
|
|
|
|
@@ -393,32 +393,29 @@ class Distribution(Cell): |
|
|
|
Args must include type of the distribution, parameters of distribution b. |
|
|
|
Parameters for distribution a are optional. |
|
|
|
""" |
|
|
|
return self._call_cross_entropy(*args) |
|
|
|
return self._call_cross_entropy(*args, **kwargs) |
|
|
|
|
|
|
|
def _calc_cross_entropy(self, *args): |
|
|
|
def _calc_cross_entropy(self, *args, **kwargs): |
|
|
|
r""" |
|
|
|
Evaluate cross_entropy from entropy and kl divergence. |
|
|
|
|
|
|
|
.. math:: |
|
|
|
H(X, Y) = H(X) + KL(X||Y) |
|
|
|
""" |
|
|
|
return self._entropy(*args) + self._kl_loss(*args) |
|
|
|
return self._entropy(*args, **kwargs) + self._kl_loss(*args, **kwargs) |
|
|
|
|
|
|
|
def sample(self, *args): |
|
|
|
def sample(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Sampling function. |
|
|
|
|
|
|
|
Args: |
|
|
|
*args (list): arguments passed in through construct. |
|
|
|
|
|
|
|
Note: |
|
|
|
Shape of the sample is default to (). |
|
|
|
Dist_spec_args are optional. |
|
|
|
""" |
|
|
|
return self._sample(*args) |
|
|
|
return self._sample(*args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
def construct(self, name, *args): |
|
|
|
def construct(self, name, *args, **kwargs): |
|
|
|
""" |
|
|
|
Override construct in Cell. |
|
|
|
|
|
|
|
@@ -433,31 +430,31 @@ class Distribution(Cell): |
|
|
|
""" |
|
|
|
|
|
|
|
if name == 'log_prob': |
|
|
|
return self._call_log_prob(*args) |
|
|
|
return self._call_log_prob(*args, **kwargs) |
|
|
|
if name == 'prob': |
|
|
|
return self._call_prob(*args) |
|
|
|
return self._call_prob(*args, **kwargs) |
|
|
|
if name == 'cdf': |
|
|
|
return self._call_cdf(*args) |
|
|
|
return self._call_cdf(*args, **kwargs) |
|
|
|
if name == 'log_cdf': |
|
|
|
return self._call_log_cdf(*args) |
|
|
|
return self._call_log_cdf(*args, **kwargs) |
|
|
|
if name == 'survival_function': |
|
|
|
return self._call_survival(*args) |
|
|
|
return self._call_survival(*args, **kwargs) |
|
|
|
if name == 'log_survival': |
|
|
|
return self._call_log_survival(*args) |
|
|
|
return self._call_log_survival(*args, **kwargs) |
|
|
|
if name == 'kl_loss': |
|
|
|
return self._kl_loss(*args) |
|
|
|
return self._kl_loss(*args, **kwargs) |
|
|
|
if name == 'mean': |
|
|
|
return self._mean(*args) |
|
|
|
return self._mean(*args, **kwargs) |
|
|
|
if name == 'mode': |
|
|
|
return self._mode(*args) |
|
|
|
return self._mode(*args, **kwargs) |
|
|
|
if name == 'sd': |
|
|
|
return self._call_sd(*args) |
|
|
|
return self._call_sd(*args, **kwargs) |
|
|
|
if name == 'var': |
|
|
|
return self._call_var(*args) |
|
|
|
return self._call_var(*args, **kwargs) |
|
|
|
if name == 'entropy': |
|
|
|
return self._entropy(*args) |
|
|
|
return self._entropy(*args, **kwargs) |
|
|
|
if name == 'cross_entropy': |
|
|
|
return self._call_cross_entropy(*args) |
|
|
|
return self._call_cross_entropy(*args, **kwargs) |
|
|
|
if name == 'sample': |
|
|
|
return self._sample(*args) |
|
|
|
return self._sample(*args, **kwargs) |
|
|
|
return None |