Merge pull request !7268 from XunDeng/gumbel_cdftags/v1.1.0
| @@ -21,6 +21,8 @@ from .power_transform import PowerTransform | |||||
| from .exp import Exp | from .exp import Exp | ||||
| from .scalar_affine import ScalarAffine | from .scalar_affine import ScalarAffine | ||||
| from .softplus import Softplus | from .softplus import Softplus | ||||
| from .gumbel_cdf import GumbelCDF | |||||
| from .invert import Invert | |||||
| __all__ = [ | __all__ = [ | ||||
| 'Bijector', | 'Bijector', | ||||
| @@ -28,4 +30,6 @@ __all__ = [ | |||||
| 'Exp', | 'Exp', | ||||
| 'ScalarAffine', | 'ScalarAffine', | ||||
| 'Softplus', | 'Softplus', | ||||
| 'GumbelCDF', | |||||
| 'Invert', | |||||
| ] | ] | ||||
| @@ -0,0 +1,107 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """GumbelCDF Bijector""" | |||||
| from mindspore.common import dtype as mstype | |||||
| from ..distribution._utils.utils import cast_to_tensor, check_greater_zero, set_param_type | |||||
| from ..distribution._utils.custom_ops import exp_generic, log_generic | |||||
| from .bijector import Bijector | |||||
| class GumbelCDF(Bijector): | |||||
| r""" | |||||
| GumbelCDF Bijector. | |||||
| This Bijector performs the operation: | |||||
| .. math:: | |||||
| Y = \exp(-\exp(\frac{-(X - loc)}{scale})) | |||||
| Note: | |||||
| For `reverse` and `reverse_log_jacobian`, input should be in range of (0, 1). | |||||
| Args: | |||||
| loc (int, float, list, numpy.ndarray, Tensor): The location. Default: 0.. | |||||
| scale (int, float, list, numpy.ndarray, Tensor): The scale. Default: 1.0. | |||||
| name (str): The name of the Bijector. Default: 'Gumbel_CDF'. | |||||
| Examples: | |||||
| >>> # To initialize a GumbelCDF bijector of loc 0.0, and scale 1.0. | |||||
| >>> import mindspore.nn.probability.bijector as msb | |||||
| >>> gum = msb.GumbelCDF(0.0, 1.0) | |||||
| >>> | |||||
| >>> # To use GumbelCDF bijector in a network. | |||||
| >>> class net(Cell): | |||||
| >>> def __init__(self): | |||||
| >>> super(net, self).__init__(): | |||||
| >>> self.gum = msb.GumbelCDF(0.0, 1.0) | |||||
| >>> | |||||
| >>> def construct(self, value): | |||||
| >>> # Similar calls can be made to other functions | |||||
| >>> # by replacing 'forward' by the name of the function. | |||||
| >>> ans1 = self.gum.forward(value) | |||||
| >>> ans2 = self.gum.inverse(value) | |||||
| >>> ans3 = self.gum.forward_log_jacobian(value) | |||||
| >>> ans4 = self.gum.inverse_log_jacobian(value) | |||||
| """ | |||||
| def __init__(self, | |||||
| loc=0.0, | |||||
| scale=1.0, | |||||
| name='GumbelCDF'): | |||||
| """ | |||||
| Constructor of GumbelCDF Bijector. | |||||
| """ | |||||
| param = dict(locals()) | |||||
| parameter_type = set_param_type({'loc': loc, "scale": scale}, mstype.float32) | |||||
| super(GumbelCDF, self).__init__(name=name, dtype=parameter_type, param=param) | |||||
| self._loc = cast_to_tensor(loc, parameter_type) | |||||
| self._scale = cast_to_tensor(scale, parameter_type) | |||||
| check_greater_zero(self._scale, "scale") | |||||
| self.exp = exp_generic | |||||
| self.log = log_generic | |||||
| @property | |||||
| def loc(self): | |||||
| return self._loc | |||||
| @property | |||||
| def scale(self): | |||||
| return self._scale | |||||
| def extend_repr(self): | |||||
| str_info = f'loc = {self.loc}, scale = {self.scale}' | |||||
| return str_info | |||||
| def shape_mapping(self, shape): | |||||
| return shape | |||||
| def _forward(self, x): | |||||
| x = self._check_value(x, 'value') | |||||
| z = (x - self.loc) / self.scale | |||||
| return self.exp(-self.exp(-z)) | |||||
| def _inverse(self, y): | |||||
| y = self._check_value(y, 'value') | |||||
| return self.loc - self.scale * self.log(-self.log(y)) | |||||
| def _forward_log_jacobian(self, x): | |||||
| x = self._check_value(x, 'value') | |||||
| z = (x - self.loc) / self.scale | |||||
| return -z - self.exp(-z) - self.log(self.scale) | |||||
| def _inverse_log_jacobian(self, y): | |||||
| y = self._check_value(y, 'value') | |||||
| return self.log(self.scale / (-y * self.log(y))) | |||||
| @@ -0,0 +1,75 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Invert Bijector""" | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from .bijector import Bijector | |||||
| class Invert(Bijector): | |||||
| r""" | |||||
| Invert Bijector. | |||||
| Args: | |||||
| bijector (Bijector): Base Bijector. | |||||
| name (str): The name of the Bijector. Default: Invert. | |||||
| Examples: | |||||
| >>> # To initialize an Invert bijector. | |||||
| >>> import mindspore.nn.probability.bijector as msb | |||||
| >>> n = msb.Invert() | |||||
| >>> | |||||
| >>> # To use an Invert bijector in a network. | |||||
| >>> class net(Cell): | |||||
| >>> def __init__(self): | |||||
| >>> super(net, self).__init__(): | |||||
| >>> self.inv = msb.Invert(msb.Exp()) | |||||
| >>> | |||||
| >>> def construct(self, value): | |||||
| >>> # Similar calls can be made to other functions | |||||
| >>> # by replacing `forward` by the name of the function. | |||||
| >>> ans1 = self.inv.forward(value) | |||||
| >>> ans2 = self.inv.inverse(value) | |||||
| >>> ans3 = self.inv.forward_log_jacobian(value) | |||||
| >>> ans4 = self.inv.inverse_log_jacobian(value) | |||||
| """ | |||||
| def __init__(self, | |||||
| bijector, | |||||
| name='Invert'): | |||||
| param = dict(locals()) | |||||
| validator.check_value_type('bijector', bijector, [Bijector], "Invert") | |||||
| name = (name + bijector.name) if name == 'Invert' else name | |||||
| super(Invert, self).__init__(is_constant_jacobian=bijector.is_constant_jacobian, | |||||
| is_injective=bijector.is_injective, | |||||
| dtype=bijector.dtype, | |||||
| name=name, | |||||
| param=param) | |||||
| self._bijector = bijector | |||||
| @property | |||||
| def bijector(self): | |||||
| return self._bijector | |||||
| def inverse(self, y): | |||||
| return self.bijector("forward", y) | |||||
| def forward(self, x): | |||||
| return self.bijector("inverse", x) | |||||
| def inverse_log_jacobian(self, y): | |||||
| return self.bijector("forward_log_jacobian", y) | |||||
| def forward_log_jacobian(self, x): | |||||
| return self.bijector("inverse_log_jacobian", x) | |||||
| @@ -0,0 +1,108 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """test cases for gumbel_cdf""" | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| import mindspore.nn.probability.bijector as msb | |||||
| from mindspore import Tensor | |||||
| from mindspore import dtype | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| """ | |||||
| Test class: forward pass of bijector. | |||||
| """ | |||||
| def __init__(self, loc, scale): | |||||
| super(Net, self).__init__() | |||||
| self.bijector = msb.GumbelCDF(loc, scale) | |||||
| def construct(self, x_): | |||||
| return self.bijector.forward(x_) | |||||
| def test_forward(): | |||||
| loc = np.array([0.0]) | |||||
| scale = np.array([[1.0], [2.0]]) | |||||
| forward = Net(loc, scale) | |||||
| x = np.array([-2., -1., 0., 1., 2.]).astype(np.float32) | |||||
| ans = forward(Tensor(x, dtype=dtype.float32)) | |||||
| tol = 1e-6 | |||||
| expected = np.exp(-np.exp(-(x - loc)/scale)) | |||||
| assert (np.abs(ans.asnumpy() - expected) < tol).all() | |||||
| class Net1(nn.Cell): | |||||
| """ | |||||
| Test class: backward pass of bijector. | |||||
| """ | |||||
| def __init__(self, loc, scale): | |||||
| super(Net1, self).__init__() | |||||
| self.bijector = msb.GumbelCDF(loc, scale) | |||||
| def construct(self, x_): | |||||
| return self.bijector.inverse(x_) | |||||
| def test_backward(): | |||||
| loc = np.array([0.0]) | |||||
| scale = np.array([[1.0], [2.0]]) | |||||
| backward = Net1(loc, scale) | |||||
| x = np.array([0.1, 0.25, 0.5, 0.75, 0.9]).astype(np.float32) | |||||
| ans = backward(Tensor(x, dtype=dtype.float32)) | |||||
| tol = 1e-6 | |||||
| expected = loc - scale * np.log(-np.log(x)) | |||||
| assert (np.abs(ans.asnumpy() - expected) < tol).all() | |||||
| class Net2(nn.Cell): | |||||
| """ | |||||
| Test class: Forward Jacobian. | |||||
| """ | |||||
| def __init__(self, loc, scale): | |||||
| super(Net2, self).__init__() | |||||
| self.bijector = msb.GumbelCDF(loc, scale) | |||||
| def construct(self, x_): | |||||
| return self.bijector.forward_log_jacobian(x_) | |||||
| def test_forward_jacobian(): | |||||
| loc = np.array([0.0]) | |||||
| scale = np.array([[1.0], [2.0]]) | |||||
| forward_jacobian = Net2(loc, scale) | |||||
| x = np.array([-2., -1., 0., 1., 2.]).astype(np.float32) | |||||
| ans = forward_jacobian(Tensor(x)) | |||||
| z = (x - loc) / scale | |||||
| expected = -z - np.exp(-z) - np.log(scale) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(ans.asnumpy() - expected) < tol).all() | |||||
| class Net3(nn.Cell): | |||||
| """ | |||||
| Test class: Backward Jacobian. | |||||
| """ | |||||
| def __init__(self, loc, scale): | |||||
| super(Net3, self).__init__() | |||||
| self.bijector = msb.GumbelCDF(loc, scale) | |||||
| def construct(self, x_): | |||||
| return self.bijector.inverse_log_jacobian(x_) | |||||
| def test_backward_jacobian(): | |||||
| loc = np.array([0.0]) | |||||
| scale = np.array([[1.0], [2.0]]) | |||||
| backward_jacobian = Net3(loc, scale) | |||||
| x = np.array([0.1, 0.2, 0.5, 0.75, 0.9]).astype(np.float32) | |||||
| ans = backward_jacobian(Tensor(x)) | |||||
| expected = np.log(scale / (-x * np.log(x))) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(ans.asnumpy() - expected) < tol).all() | |||||
| @@ -0,0 +1,101 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """test cases for invert""" | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| import mindspore.nn.probability.bijector as msb | |||||
| from mindspore import Tensor | |||||
| from mindspore import dtype | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| """ | |||||
| Test class: forward pass of bijector. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.origin = msb.ScalarAffine(scale=2.0, shift=1.0) | |||||
| self.invert = msb.Invert(self.origin) | |||||
| def construct(self, x_): | |||||
| return self.invert.forward(x_), self.origin.inverse(x_) | |||||
| def test_forward(): | |||||
| forward = Net() | |||||
| x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32) | |||||
| ans, ans2 = forward(Tensor(x, dtype=dtype.float32)) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(ans.asnumpy() - ans2.asnumpy()) < tol).all() | |||||
| class Net1(nn.Cell): | |||||
| """ | |||||
| Test class: backward pass of bijector. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Net1, self).__init__() | |||||
| self.origin = msb.ScalarAffine(scale=2.0, shift=1.0) | |||||
| self.invert = msb.Invert(self.origin) | |||||
| def construct(self, x_): | |||||
| return self.invert.inverse(x_), self.origin.forward(x_) | |||||
| def test_backward(): | |||||
| backward = Net1() | |||||
| x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32) | |||||
| ans, ans2 = backward(Tensor(x, dtype=dtype.float32)) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(ans.asnumpy() - ans2.asnumpy()) < tol).all() | |||||
| class Net2(nn.Cell): | |||||
| """ | |||||
| Test class: Forward Jacobian. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Net2, self).__init__() | |||||
| self.origin = msb.ScalarAffine(scale=2.0, shift=1.0) | |||||
| self.invert = msb.Invert(self.origin) | |||||
| def construct(self, x_): | |||||
| return self.invert.forward_log_jacobian(x_),\ | |||||
| self.origin.inverse_log_jacobian(x_) | |||||
| def test_forward_jacobian(): | |||||
| forward_jacobian = Net2() | |||||
| x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | |||||
| ans, ans2 = forward_jacobian(x) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(ans.asnumpy() - ans2.asnumpy()) < tol).all() | |||||
| class Net3(nn.Cell): | |||||
| """ | |||||
| Test class: Backward Jacobian. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Net3, self).__init__() | |||||
| self.origin = msb.ScalarAffine(scale=2.0, shift=1.0) | |||||
| self.invert = msb.Invert(self.origin) | |||||
| def construct(self, x_): | |||||
| return self.invert.inverse_log_jacobian(x_),\ | |||||
| self.origin.forward_log_jacobian(x_) | |||||
| def test_backward_jacobian(): | |||||
| backward_jacobian = Net3() | |||||
| x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | |||||
| ans, ans2 = backward_jacobian(x) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(ans.asnumpy() - ans2.asnumpy()) < tol).all() | |||||
| @@ -0,0 +1,148 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """test cases for gumbel_cdf""" | |||||
| import pytest | |||||
| import mindspore.nn as nn | |||||
| import mindspore.nn.probability.bijector as msb | |||||
| from mindspore import Tensor | |||||
| from mindspore import dtype | |||||
| def test_init(): | |||||
| """ | |||||
| Test initializations. | |||||
| """ | |||||
| b = msb.GumbelCDF() | |||||
| assert isinstance(b, msb.Bijector) | |||||
| b = msb.GumbelCDF(scale=1.0) | |||||
| assert isinstance(b, msb.Bijector) | |||||
| b = msb.GumbelCDF(loc=0.0) | |||||
| assert isinstance(b, msb.Bijector) | |||||
| b = msb.GumbelCDF(3.0, 4.0) | |||||
| assert isinstance(b, msb.Bijector) | |||||
| def test_type(): | |||||
| with pytest.raises(TypeError): | |||||
| msb.GumbelCDF(scale='scale') | |||||
| with pytest.raises(TypeError): | |||||
| msb.GumbelCDF(loc='loc') | |||||
| with pytest.raises(TypeError): | |||||
| msb.GumbelCDF(name=0.1) | |||||
| def test_invalid_scale(): | |||||
| """ | |||||
| Test invalid scale. | |||||
| """ | |||||
| with pytest.raises(ValueError): | |||||
| msb.GumbelCDF(scale=0.0) | |||||
| with pytest.raises(ValueError): | |||||
| msb.GumbelCDF(scale=-1.0) | |||||
| class ForwardBackward(nn.Cell): | |||||
| """ | |||||
| Test class: forward and backward pass. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(ForwardBackward, self).__init__() | |||||
| self.b1 = msb.GumbelCDF(1.0, 2.0) | |||||
| self.b2 = msb.GumbelCDF() | |||||
| def construct(self, x_): | |||||
| ans1 = self.b1.inverse(self.b1.forward(x_)) | |||||
| ans2 = self.b2.inverse(self.b2.forward(x_)) | |||||
| return ans1 + ans2 | |||||
| def test_forward_and_backward_pass(): | |||||
| """ | |||||
| Test forward and backward pass of ScalarAffine bijector. | |||||
| """ | |||||
| net = ForwardBackward() | |||||
| x = Tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype.float32) | |||||
| ans = net(x) | |||||
| assert isinstance(ans, Tensor) | |||||
| class ForwardJacobian(nn.Cell): | |||||
| """ | |||||
| Test class: Forward log Jacobian. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(ForwardJacobian, self).__init__() | |||||
| self.b1 = msb.GumbelCDF(1.0, 2.0) | |||||
| self.b2 = msb.GumbelCDF() | |||||
| def construct(self, x_): | |||||
| ans1 = self.b1.forward_log_jacobian(x_) | |||||
| ans2 = self.b2.forward_log_jacobian(x_) | |||||
| return ans1 + ans2 | |||||
| def test_forward_jacobian(): | |||||
| """ | |||||
| Test forward log jacobian of ScalarAffine bijector. | |||||
| """ | |||||
| net = ForwardJacobian() | |||||
| x = Tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype.float32) | |||||
| ans = net(x) | |||||
| assert isinstance(ans, Tensor) | |||||
| class BackwardJacobian(nn.Cell): | |||||
| """ | |||||
| Test class: Backward log Jacobian. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(BackwardJacobian, self).__init__() | |||||
| self.b1 = msb.GumbelCDF(1.0, 2.0) | |||||
| self.b2 = msb.GumbelCDF() | |||||
| def construct(self, x_): | |||||
| ans1 = self.b1.inverse_log_jacobian(x_) | |||||
| ans2 = self.b2.inverse_log_jacobian(x_) | |||||
| return ans1 + ans2 | |||||
| def test_backward_jacobian(): | |||||
| """ | |||||
| Test backward log jacobian of ScalarAffine bijector. | |||||
| """ | |||||
| net = BackwardJacobian() | |||||
| x = Tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype.float32) | |||||
| ans = net(x) | |||||
| assert isinstance(ans, Tensor) | |||||
| class Net(nn.Cell): | |||||
| """ | |||||
| Test class: function calls going through construct. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.b1 = msb.GumbelCDF(1.0, 2.0) | |||||
| self.b2 = msb.GumbelCDF() | |||||
| def construct(self, x_): | |||||
| ans1 = self.b1('inverse', self.b1('forward', x_)) | |||||
| ans2 = self.b2('inverse', self.b2('forward', x_)) | |||||
| ans3 = self.b1('forward_log_jacobian', x_) | |||||
| ans4 = self.b2('forward_log_jacobian', x_) | |||||
| ans5 = self.b1('inverse_log_jacobian', x_) | |||||
| ans6 = self.b2('inverse_log_jacobian', x_) | |||||
| return ans1 - ans2 + ans3 -ans4 + ans5 - ans6 | |||||
| def test_old_api(): | |||||
| """ | |||||
| Test old api which goes through construct. | |||||
| """ | |||||
| net = Net() | |||||
| x = Tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype.float32) | |||||
| ans = net(x) | |||||
| assert isinstance(ans, Tensor) | |||||
| @@ -0,0 +1,136 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """test cases for invert""" | |||||
| import pytest | |||||
| import mindspore.nn as nn | |||||
| import mindspore.nn.probability.bijector as msb | |||||
| from mindspore import Tensor | |||||
| from mindspore import dtype | |||||
| def test_init(): | |||||
| """ | |||||
| Test initializations. | |||||
| """ | |||||
| b = msb.Invert(msb.ScalarAffine(scale=1.0)) | |||||
| assert isinstance(b, msb.Bijector) | |||||
| b = msb.Invert(msb.Exp()) | |||||
| assert isinstance(b, msb.Bijector) | |||||
| def test_type(): | |||||
| with pytest.raises(TypeError): | |||||
| msb.Invert(msb.Exp(), name=0.1) | |||||
| with pytest.raises(TypeError): | |||||
| msb.Invert(0.1) | |||||
| def test_name(): | |||||
| b = msb.Invert(msb.ScalarAffine(scale=1.0)) | |||||
| assert b.name == 'InvertScalarAffine' | |||||
| class ForwardBackward(nn.Cell): | |||||
| """ | |||||
| Test class: forward and backward pass. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(ForwardBackward, self).__init__() | |||||
| self.inv1 = msb.Invert(msb.Exp()) | |||||
| self.inv2 = msb.Invert(msb.ScalarAffine()) | |||||
| def construct(self, x_): | |||||
| ans1 = self.inv1.inverse(x_) + self.inv1.inverse(x_) | |||||
| ans2 = self.inv2.inverse(x_) + self.inv2.forward(x_) | |||||
| return ans1 + ans2 | |||||
| def test_forward_and_backward_pass(): | |||||
| """ | |||||
| Test forward and backward pass of ScalarAffine bijector. | |||||
| """ | |||||
| net = ForwardBackward() | |||||
| x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | |||||
| ans = net(x) | |||||
| assert isinstance(ans, Tensor) | |||||
| class ForwardJacobian(nn.Cell): | |||||
| """ | |||||
| Test class: Forward log Jacobian. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(ForwardJacobian, self).__init__() | |||||
| self.inv1 = msb.Invert(msb.Exp()) | |||||
| self.inv2 = msb.Invert(msb.ScalarAffine()) | |||||
| def construct(self, x_): | |||||
| ans1 = self.inv1.forward_log_jacobian(x_) | |||||
| ans2 = self.inv2.forward_log_jacobian(x_) | |||||
| return ans1 + ans2 | |||||
| def test_forward_jacobian(): | |||||
| """ | |||||
| Test forward log jacobian of ScalarAffine bijector. | |||||
| """ | |||||
| net = ForwardJacobian() | |||||
| x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | |||||
| ans = net(x) | |||||
| assert isinstance(ans, Tensor) | |||||
| class BackwardJacobian(nn.Cell): | |||||
| """ | |||||
| Test class: Backward log Jacobian. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(BackwardJacobian, self).__init__() | |||||
| self.inv1 = msb.Invert(msb.Exp()) | |||||
| self.inv2 = msb.Invert(msb.ScalarAffine()) | |||||
| def construct(self, x_): | |||||
| ans1 = self.inv1.inverse_log_jacobian(x_) | |||||
| ans2 = self.inv2.inverse_log_jacobian(x_) | |||||
| return ans1 + ans2 | |||||
| def test_backward_jacobian(): | |||||
| """ | |||||
| Test backward log jacobian of ScalarAffine bijector. | |||||
| """ | |||||
| net = BackwardJacobian() | |||||
| x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | |||||
| ans = net(x) | |||||
| assert isinstance(ans, Tensor) | |||||
| class Net(nn.Cell): | |||||
| """ | |||||
| Test class: function calls going through construct. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.b1 = msb.ScalarAffine(2.0, 1.0) | |||||
| self.inv = msb.Invert(self.b1) | |||||
| def construct(self, x_): | |||||
| ans1 = self.inv('inverse', self.inv('forward', x_)) | |||||
| ans2 = self.inv('forward_log_jacobian', x_) | |||||
| ans3 = self.inv('inverse_log_jacobian', x_) | |||||
| return ans1 + ans2 + ans3 | |||||
| def test_old_api(): | |||||
| """ | |||||
| Test old api which goes through construct. | |||||
| """ | |||||
| net = Net() | |||||
| x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | |||||
| ans = net(x) | |||||
| assert isinstance(ans, Tensor) | |||||