|
|
|
@@ -19,7 +19,8 @@ from mindspore.nn.cell import Cell |
|
|
|
from mindspore._checkparam import Validator as validator |
|
|
|
from mindspore._checkparam import Rel |
|
|
|
from mindspore.common import get_seed |
|
|
|
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device |
|
|
|
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\ |
|
|
|
raise_not_implemented_util |
|
|
|
from ._utils.utils import CheckTuple, CheckTensor |
|
|
|
from ._utils.custom_ops import broadcast_to, exp_generic, log_generic |
|
|
|
|
|
|
|
@@ -245,6 +246,8 @@ class Distribution(Cell): |
|
|
|
self._call_prob = self._prob |
|
|
|
elif hasattr(self, '_log_prob'): |
|
|
|
self._call_prob = self._calc_prob_from_log_prob |
|
|
|
else: |
|
|
|
self._call_prob = self._raise_not_implemented_error('prob') |
|
|
|
|
|
|
|
def _set_sd(self): |
|
|
|
""" |
|
|
|
@@ -254,6 +257,8 @@ class Distribution(Cell): |
|
|
|
self._call_sd = self._sd |
|
|
|
elif hasattr(self, '_var'): |
|
|
|
self._call_sd = self._calc_sd_from_var |
|
|
|
else: |
|
|
|
self._call_sd = self._raise_not_implemented_error('sd') |
|
|
|
|
|
|
|
def _set_var(self): |
|
|
|
""" |
|
|
|
@@ -263,6 +268,8 @@ class Distribution(Cell): |
|
|
|
self._call_var = self._var |
|
|
|
elif hasattr(self, '_sd'): |
|
|
|
self._call_var = self._calc_var_from_sd |
|
|
|
else: |
|
|
|
self._call_var = self._raise_not_implemented_error('var') |
|
|
|
|
|
|
|
def _set_log_prob(self): |
|
|
|
""" |
|
|
|
@@ -272,6 +279,8 @@ class Distribution(Cell): |
|
|
|
self._call_log_prob = self._log_prob |
|
|
|
elif hasattr(self, '_prob'): |
|
|
|
self._call_log_prob = self._calc_log_prob_from_prob |
|
|
|
else: |
|
|
|
self._call_log_prob = self._raise_not_implemented_error('log_prob') |
|
|
|
|
|
|
|
def _set_cdf(self): |
|
|
|
""" |
|
|
|
@@ -286,13 +295,18 @@ class Distribution(Cell): |
|
|
|
self._call_cdf = self._calc_cdf_from_survival |
|
|
|
elif hasattr(self, '_log_survival'): |
|
|
|
self._call_cdf = self._calc_cdf_from_log_survival |
|
|
|
else: |
|
|
|
self._call_cdf = self._raise_not_implemented_error('cdf') |
|
|
|
|
|
|
|
def _set_survival(self): |
|
|
|
""" |
|
|
|
Set survival function based on the availability of _survival function and `_log_survival` |
|
|
|
and `_call_cdf`. |
|
|
|
""" |
|
|
|
if hasattr(self, '_survival_function'): |
|
|
|
if not (hasattr(self, '_survival_function') or hasattr(self, '_log_survival') or \ |
|
|
|
hasattr(self, '_cdf') or hasattr(self, '_log_cdf')): |
|
|
|
self._call_survival = self._raise_not_implemented_error('survival_function') |
|
|
|
elif hasattr(self, '_survival_function'): |
|
|
|
self._call_survival = self._survival_function |
|
|
|
elif hasattr(self, '_log_survival'): |
|
|
|
self._call_survival = self._calc_survival_from_log_survival |
|
|
|
@@ -303,7 +317,10 @@ class Distribution(Cell): |
|
|
|
""" |
|
|
|
Set log cdf based on the availability of `_log_cdf` and `_call_cdf`. |
|
|
|
""" |
|
|
|
if hasattr(self, '_log_cdf'): |
|
|
|
if not (hasattr(self, '_log_cdf') or hasattr(self, '_cdf') or \ |
|
|
|
hasattr(self, '_survival_function') or hasattr(self, '_log_survival')): |
|
|
|
self._call_log_cdf = self._raise_not_implemented_error('log_cdf') |
|
|
|
elif hasattr(self, '_log_cdf'): |
|
|
|
self._call_log_cdf = self._log_cdf |
|
|
|
elif hasattr(self, '_call_cdf'): |
|
|
|
self._call_log_cdf = self._calc_log_cdf_from_call_cdf |
|
|
|
@@ -312,7 +329,10 @@ class Distribution(Cell): |
|
|
|
""" |
|
|
|
Set log survival based on the availability of `_log_survival` and `_call_survival`. |
|
|
|
""" |
|
|
|
if hasattr(self, '_log_survival'): |
|
|
|
if not (hasattr(self, '_log_survival') or hasattr(self, '_survival_function') or \ |
|
|
|
hasattr(self, '_log_cdf') or hasattr(self, '_cdf')): |
|
|
|
self._call_log_survival = self._raise_not_implemented_error('log_cdf') |
|
|
|
elif hasattr(self, '_log_survival'): |
|
|
|
self._call_log_survival = self._log_survival |
|
|
|
elif hasattr(self, '_call_survival'): |
|
|
|
self._call_log_survival = self._calc_log_survival_from_call_survival |
|
|
|
@@ -323,6 +343,14 @@ class Distribution(Cell): |
|
|
|
""" |
|
|
|
if hasattr(self, '_cross_entropy'): |
|
|
|
self._call_cross_entropy = self._cross_entropy |
|
|
|
else: |
|
|
|
self._call_cross_entropy = self._raise_not_implemented_error('cross_entropy') |
|
|
|
|
|
|
|
def _raise_not_implemented_error(self, func_name): |
|
|
|
name = self.name |
|
|
|
def raise_error(*args, **kwargs): |
|
|
|
return raise_not_implemented_util(func_name, name, *args, **kwargs) |
|
|
|
return raise_error |
|
|
|
|
|
|
|
def log_prob(self, value, *args, **kwargs): |
|
|
|
""" |
|
|
|
@@ -495,6 +523,9 @@ class Distribution(Cell): |
|
|
|
""" |
|
|
|
return self.log_base(self._call_survival(value, *args, **kwargs)) |
|
|
|
|
|
|
|
def _kl_loss(self, *args, **kwargs): |
|
|
|
return raise_not_implemented_util('kl_loss', self.name, *args, **kwargs) |
|
|
|
|
|
|
|
def kl_loss(self, dist, *args, **kwargs): |
|
|
|
""" |
|
|
|
Evaluate the KL divergence, i.e. KL(a||b). |
|
|
|
@@ -510,6 +541,9 @@ class Distribution(Cell): |
|
|
|
""" |
|
|
|
return self._kl_loss(dist, *args, **kwargs) |
|
|
|
|
|
|
|
def _mean(self, *args, **kwargs): |
|
|
|
return raise_not_implemented_util('mean', self.name, *args, **kwargs) |
|
|
|
|
|
|
|
def mean(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Evaluate the mean. |
|
|
|
@@ -524,6 +558,9 @@ class Distribution(Cell): |
|
|
|
""" |
|
|
|
return self._mean(*args, **kwargs) |
|
|
|
|
|
|
|
def _mode(self, *args, **kwargs): |
|
|
|
return raise_not_implemented_util('mode', self.name, *args, **kwargs) |
|
|
|
|
|
|
|
def mode(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Evaluate the mode. |
|
|
|
@@ -584,6 +621,9 @@ class Distribution(Cell): |
|
|
|
""" |
|
|
|
return self.sq_base(self._sd(*args, **kwargs)) |
|
|
|
|
|
|
|
def _entropy(self, *args, **kwargs): |
|
|
|
return raise_not_implemented_util('entropy', self.name, *args, **kwargs) |
|
|
|
|
|
|
|
def entropy(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Evaluate the entropy. |
|
|
|
@@ -622,6 +662,9 @@ class Distribution(Cell): |
|
|
|
""" |
|
|
|
return self._entropy(*args, **kwargs) + self._kl_loss(dist, *args, **kwargs) |
|
|
|
|
|
|
|
def _sample(self, *args, **kwargs): |
|
|
|
return raise_not_implemented_util('sample', self.name, *args, **kwargs) |
|
|
|
|
|
|
|
def sample(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Sampling function. |
|
|
|
@@ -680,4 +723,4 @@ class Distribution(Cell): |
|
|
|
return self._call_cross_entropy(*args, **kwargs) |
|
|
|
if name == 'sample': |
|
|
|
return self._sample(*args, **kwargs) |
|
|
|
return None |
|
|
|
return raise_not_implemented_util(name, self.name, *args, **kwargs) |