Browse Source

added raise_not_implemented_error in distribution

tags/v1.1.0
Xun Deng 5 years ago
parent
commit
ea57699ed1
3 changed files with 155 additions and 5 deletions
  1. +5
    -0
      mindspore/nn/probability/distribution/_utils/utils.py
  2. +48
    -5
      mindspore/nn/probability/distribution/distribution.py
  3. +102
    -0
      tests/ut/python/nn/probability/distribution/test_distribution.py

+ 5
- 0
mindspore/nn/probability/distribution/_utils/utils.py View File

@@ -218,6 +218,11 @@ def raise_not_impl_error(name):
raise ValueError(
f"{name} function should be implemented for non-linear transformation")

@constexpr
def raise_not_implemented_util(func_name, obj, *args, **kwargs):
raise NotImplementedError(
f"{func_name} is not implemented for {obj} distribution.")


@constexpr
def check_distribution_name(name, expected_name):


+ 48
- 5
mindspore/nn/probability/distribution/distribution.py View File

@@ -19,7 +19,8 @@ from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.common import get_seed
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\
raise_not_implemented_util
from ._utils.utils import CheckTuple, CheckTensor
from ._utils.custom_ops import broadcast_to, exp_generic, log_generic

@@ -245,6 +246,8 @@ class Distribution(Cell):
self._call_prob = self._prob
elif hasattr(self, '_log_prob'):
self._call_prob = self._calc_prob_from_log_prob
else:
self._call_prob = self._raise_not_implemented_error('prob')

def _set_sd(self):
"""
@@ -254,6 +257,8 @@ class Distribution(Cell):
self._call_sd = self._sd
elif hasattr(self, '_var'):
self._call_sd = self._calc_sd_from_var
else:
self._call_sd = self._raise_not_implemented_error('sd')

def _set_var(self):
"""
@@ -263,6 +268,8 @@ class Distribution(Cell):
self._call_var = self._var
elif hasattr(self, '_sd'):
self._call_var = self._calc_var_from_sd
else:
self._call_var = self._raise_not_implemented_error('var')

def _set_log_prob(self):
"""
@@ -272,6 +279,8 @@ class Distribution(Cell):
self._call_log_prob = self._log_prob
elif hasattr(self, '_prob'):
self._call_log_prob = self._calc_log_prob_from_prob
else:
self._call_log_prob = self._raise_not_implemented_error('log_prob')

def _set_cdf(self):
"""
@@ -286,13 +295,18 @@ class Distribution(Cell):
self._call_cdf = self._calc_cdf_from_survival
elif hasattr(self, '_log_survival'):
self._call_cdf = self._calc_cdf_from_log_survival
else:
self._call_cdf = self._raise_not_implemented_error('cdf')

def _set_survival(self):
"""
Set survival function based on the availability of _survival function and `_log_survival`
and `_call_cdf`.
"""
if hasattr(self, '_survival_function'):
if not (hasattr(self, '_survival_function') or hasattr(self, '_log_survival') or \
hasattr(self, '_cdf') or hasattr(self, '_log_cdf')):
self._call_survival = self._raise_not_implemented_error('survival_function')
elif hasattr(self, '_survival_function'):
self._call_survival = self._survival_function
elif hasattr(self, '_log_survival'):
self._call_survival = self._calc_survival_from_log_survival
@@ -303,7 +317,10 @@ class Distribution(Cell):
"""
Set log cdf based on the availability of `_log_cdf` and `_call_cdf`.
"""
if hasattr(self, '_log_cdf'):
if not (hasattr(self, '_log_cdf') or hasattr(self, '_cdf') or \
hasattr(self, '_survival_function') or hasattr(self, '_log_survival')):
self._call_log_cdf = self._raise_not_implemented_error('log_cdf')
elif hasattr(self, '_log_cdf'):
self._call_log_cdf = self._log_cdf
elif hasattr(self, '_call_cdf'):
self._call_log_cdf = self._calc_log_cdf_from_call_cdf
@@ -312,7 +329,10 @@ class Distribution(Cell):
"""
Set log survival based on the availability of `_log_survival` and `_call_survival`.
"""
if hasattr(self, '_log_survival'):
if not (hasattr(self, '_log_survival') or hasattr(self, '_survival_function') or \
hasattr(self, '_log_cdf') or hasattr(self, '_cdf')):
self._call_log_survival = self._raise_not_implemented_error('log_cdf')
elif hasattr(self, '_log_survival'):
self._call_log_survival = self._log_survival
elif hasattr(self, '_call_survival'):
self._call_log_survival = self._calc_log_survival_from_call_survival
@@ -323,6 +343,14 @@ class Distribution(Cell):
"""
if hasattr(self, '_cross_entropy'):
self._call_cross_entropy = self._cross_entropy
else:
self._call_cross_entropy = self._raise_not_implemented_error('cross_entropy')

def _raise_not_implemented_error(self, func_name):
name = self.name
def raise_error(*args, **kwargs):
return raise_not_implemented_util(func_name, name, *args, **kwargs)
return raise_error

def log_prob(self, value, *args, **kwargs):
"""
@@ -495,6 +523,9 @@ class Distribution(Cell):
"""
return self.log_base(self._call_survival(value, *args, **kwargs))

def _kl_loss(self, *args, **kwargs):
return raise_not_implemented_util('kl_loss', self.name, *args, **kwargs)

def kl_loss(self, dist, *args, **kwargs):
"""
Evaluate the KL divergence, i.e. KL(a||b).
@@ -510,6 +541,9 @@ class Distribution(Cell):
"""
return self._kl_loss(dist, *args, **kwargs)

def _mean(self, *args, **kwargs):
return raise_not_implemented_util('mean', self.name, *args, **kwargs)

def mean(self, *args, **kwargs):
"""
Evaluate the mean.
@@ -524,6 +558,9 @@ class Distribution(Cell):
"""
return self._mean(*args, **kwargs)

def _mode(self, *args, **kwargs):
return raise_not_implemented_util('mode', self.name, *args, **kwargs)

def mode(self, *args, **kwargs):
"""
Evaluate the mode.
@@ -584,6 +621,9 @@ class Distribution(Cell):
"""
return self.sq_base(self._sd(*args, **kwargs))

def _entropy(self, *args, **kwargs):
return raise_not_implemented_util('entropy', self.name, *args, **kwargs)

def entropy(self, *args, **kwargs):
"""
Evaluate the entropy.
@@ -622,6 +662,9 @@ class Distribution(Cell):
"""
return self._entropy(*args, **kwargs) + self._kl_loss(dist, *args, **kwargs)

def _sample(self, *args, **kwargs):
return raise_not_implemented_util('sample', self.name, *args, **kwargs)

def sample(self, *args, **kwargs):
"""
Sampling function.
@@ -680,4 +723,4 @@ class Distribution(Cell):
return self._call_cross_entropy(*args, **kwargs)
if name == 'sample':
return self._sample(*args, **kwargs)
return None
return raise_not_implemented_util(name, self.name, *args, **kwargs)

+ 102
- 0
tests/ut/python/nn/probability/distribution/test_distribution.py View File

@@ -0,0 +1,102 @@
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test nn.probability.distribution.
"""
import pytest

import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype as mstype
from mindspore import Tensor
from mindspore import context

func_name_list = ['prob', 'log_prob', 'cdf', 'log_cdf',
'survival_function', 'log_survival',
'sd', 'var', 'mode', 'mean',
'entropy', 'kl_loss', 'cross_entropy',
'sample']

class MyExponential(msd.Distribution):
"""
Test distirbution class: no function is implemented.
"""
def __init__(self, rate=None, seed=None, dtype=mstype.float32, name="MyExponential"):
param = dict(locals())
param['param_dict'] = {'rate': rate}
super(MyExponential, self).__init__(seed, dtype, name, param)

class Net(nn.Cell):
"""
Test Net: function called through construct.
"""
def __init__(self, func_name):
super(Net, self).__init__()
self.dist = MyExponential()
self.name = func_name

def construct(self, *args, **kwargs):
return self.dist(self.name, *args, **kwargs)


def test_raise_not_implemented_error_construct():
"""
test raise not implemented error in pynative mode.
"""
value = Tensor([0.2], dtype=mstype.float32)
for func_name in func_name_list:
with pytest.raises(NotImplementedError):
net = Net(func_name)
net(value)

def test_raise_not_implemented_error_construct_graph_mode():
"""
test raise not implemented error in graph mode.
"""
context.set_context(mode=context.GRAPH_MODE)
value = Tensor([0.2], dtype=mstype.float32)
for func_name in func_name_list:
with pytest.raises(NotImplementedError):
net = Net(func_name)
net(value)

class Net1(nn.Cell):
"""
Test Net: function called directly.
"""
def __init__(self, func_name):
super(Net1, self).__init__()
self.dist = MyExponential()
self.func = getattr(self.dist, func_name)

def construct(self, *args, **kwargs):
return self.func(*args, **kwargs)

def test_raise_not_implemented_error():
"""
test raise not implemented error in pynative mode.
"""
value = Tensor([0.2], dtype=mstype.float32)
for func_name in func_name_list:
with pytest.raises(NotImplementedError):
net = Net1(func_name)
net(value)

def test_raise_not_implemented_error_graph_mode():
"""
test raise not implemented error in graph mode.
"""
context.set_context(mode=context.GRAPH_MODE)
value = Tensor([0.2], dtype=mstype.float32)
for func_name in func_name_list:
with pytest.raises(NotImplementedError):
net = Net1(func_name)
net(value)

Loading…
Cancel
Save