| @@ -21,7 +21,13 @@ The high-level components(Bijectors) used to construct the probabilistic network | |||
| from .bijector import Bijector | |||
| from .power_transform import PowerTransform | |||
| from .exp import Exp | |||
| from .scalar_affine import ScalarAffine | |||
| from .softplus import Softplus | |||
| __all__ = ['Bijector', | |||
| 'PowerTransform', | |||
| 'Exp'] | |||
| __all__ = [ | |||
| 'Bijector', | |||
| 'PowerTransform', | |||
| 'Exp', | |||
| 'ScalarAffine', | |||
| 'Softplus', | |||
| ] | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================ | |||
| """Bijector""" | |||
| from mindspore.nn.cell import Cell | |||
| from mindspore._checkparam import Validator as validator | |||
| from ..distribution import Distribution | |||
| from ..distribution import TransformedDistribution | |||
| @@ -39,6 +40,9 @@ class Bijector(Cell): | |||
| Constructor of bijector class. | |||
| """ | |||
| super(Bijector, self).__init__() | |||
| validator.check_value_type('name', name, [str], 'Bijector') | |||
| validator.check_value_type('is_constant_jacobian', is_constant_jacobian, [bool], name) | |||
| validator.check_value_type('is_injective', is_injective, [bool], name) | |||
| self._name = name | |||
| self._dtype = dtype | |||
| self._parameters = {} | |||
| @@ -0,0 +1,116 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Scalar Affine Bijector""" | |||
| from mindspore.ops import operations as P | |||
| from mindspore._checkparam import Validator as validator | |||
| from ..distribution._utils.utils import cast_to_tensor | |||
| from .bijector import Bijector | |||
| class ScalarAffine(Bijector): | |||
| """ | |||
| Scalar Affine Bijector. | |||
| This Bijector performs the operation: Y = a * X + b, where a is the scale | |||
| factor and b is the shift factor. | |||
| Args: | |||
| scale (float): scale factor. Default: 1.0. | |||
| shift (float): shift factor. Default: 0.0. | |||
| Examples: | |||
| >>> # To initialize a ScalarAffine bijector of scale 1 and shift 2 | |||
| >>> scalaraffine = nn.probability.bijector.ScalarAffine(1, 2) | |||
| >>> | |||
| >>> # To use ScalarAffine bijector in a network | |||
| >>> class net(Cell): | |||
| >>> def __init__(self): | |||
| >>> super(net, self).__init__(): | |||
| >>> self.s1 = nn.probability.bijector.ScalarAffine(1, 2) | |||
| >>> | |||
| >>> def construct(self, value): | |||
| >>> # Similar calls can be made to other probability functions | |||
| >>> # by replacing 'forward' with the name of the function | |||
| >>> ans = self.s1.forward(value) | |||
| >>> ans = self.s1.inverse(value) | |||
| >>> ans = self.s1.forward_log_jacobian(value) | |||
| >>> ans = self.s1.inverse_log_jacobian(value) | |||
| """ | |||
| def __init__(self, | |||
| scale=1.0, | |||
| shift=0.0, | |||
| name='ScalarAffine'): | |||
| """ | |||
| Constructor of scalar affine bijector. | |||
| """ | |||
| param = dict(locals()) | |||
| validator.check_value_type('scale', scale, [float], name) | |||
| validator.check_value_type('shift', shift, [float], name) | |||
| self._scale = cast_to_tensor(scale) | |||
| self._shift = cast_to_tensor(shift) | |||
| super(ScalarAffine, self).__init__( | |||
| is_constant_jacobian=True, | |||
| is_injective=True, | |||
| name=name, | |||
| dtype=None, | |||
| param=param) | |||
| self.log = P.Log() | |||
| self.oneslike = P.OnesLike() | |||
| @property | |||
| def scale(self): | |||
| return self._scale | |||
| @property | |||
| def shift(self): | |||
| return self._shift | |||
| def extend_repr(self): | |||
| str_info = f'scale = {self.scale}, shift = {self.shift}' | |||
| return str_info | |||
| def shape_mapping(self, shape): | |||
| return shape | |||
| def _forward(self, x): | |||
| r""" | |||
| .. math:: | |||
| f(x) = a * x + b | |||
| """ | |||
| return self.scale * x + self.shift | |||
| def _inverse(self, y): | |||
| r""" | |||
| .. math:: | |||
| f(y) = \frac{y - b}{a} | |||
| """ | |||
| return (y - self.shift) / self.scale | |||
| def _forward_log_jacobian(self, value): | |||
| r""" | |||
| .. math:: | |||
| f(x) = a * x + b | |||
| f'(x) = a | |||
| \log(f'(x)) = \log(a) | |||
| """ | |||
| return self.log(self.scale) * self.oneslike(value) | |||
| def _inverse_log_jacobian(self, value): | |||
| r""" | |||
| .. math:: | |||
| f(y) = \frac{(y - b)}{a} | |||
| f'(x) = \frac{1.0}{a} | |||
| \log(f'(x)) = - \log(a) | |||
| """ | |||
| return -1. * self.log(self.scale) * self.oneslike(value) | |||
| @@ -0,0 +1,124 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Softplus Bijector""" | |||
| from mindspore.ops import operations as P | |||
| from mindspore.nn.layer.activation import LogSigmoid | |||
| from mindspore._checkparam import Validator as validator | |||
| from ..distribution._utils.utils import cast_to_tensor | |||
| from .bijector import Bijector | |||
| class Softplus(Bijector): | |||
| r""" | |||
| Softplus Bijector. | |||
| This Bijector performs the operation: Y = \frac{\log(1 + e ^ {kX})}{k}, where k is the sharpness factor. | |||
| Args: | |||
| sharpness (float): scale factor. Default: 1.0. | |||
| Examples: | |||
| >>> # To initialize a Softplus bijector of sharpness 2 | |||
| >>> softplus = nn.probability.bijector.Softfplus(2) | |||
| >>> | |||
| >>> # To use ScalarAffine bijector in a network | |||
| >>> class net(Cell): | |||
| >>> def __init__(self): | |||
| >>> super(net, self).__init__(): | |||
| >>> self.sp1 = nn.probability.bijector.Softflus(2) | |||
| >>> | |||
| >>> def construct(self, value): | |||
| >>> # Similar calls can be made to other probability functions | |||
| >>> # by replacing 'forward' with the name of the function | |||
| >>> ans = self.sp1.forward(value) | |||
| >>> ans = self.sp1.inverse(value) | |||
| >>> ans = self.sp1.forward_log_jacobian(value) | |||
| >>> ans = self.sp1.inverse_log_jacobian(value) | |||
| """ | |||
| def __init__(self, | |||
| sharpness=1.0, | |||
| name='Softplus'): | |||
| param = dict(locals()) | |||
| validator.check_value_type('sharpness', sharpness, [float], name) | |||
| super(Softplus, self).__init__(name=name, param=param) | |||
| self._sharpness = cast_to_tensor(sharpness) | |||
| self.exp = P.Exp() | |||
| self.expm1 = self._expm1_by_step | |||
| self.log_sigmoid = LogSigmoid() | |||
| self.log = P.Log() | |||
| self.sigmoid = P.Sigmoid() | |||
| self.softplus = self._softplus | |||
| self.inverse_softplus = self._inverse_softplus | |||
| def _expm1_by_step(self, x): | |||
| """ | |||
| Expm1 ops under GPU context. | |||
| """ | |||
| return self.exp(x) - 1.0 | |||
| def _softplus(self, x): | |||
| return self.log(self.exp(x) + 1.0) | |||
| def _inverse_softplus(self, x): | |||
| r""" | |||
| .. math:: | |||
| f(x) = \frac{\log(1 + e^{x}))} | |||
| f^{-1}(y) = \frac{\log(e^{y} - 1)} | |||
| """ | |||
| return self.log(self.expm1(x)) | |||
| @property | |||
| def sharpness(self): | |||
| return self._sharpness | |||
| def extend_repr(self): | |||
| str_info = f'sharpness = {self.sharpness}' | |||
| return str_info | |||
| def shape_mapping(self, shape): | |||
| return shape | |||
| def _forward(self, x): | |||
| scaled_value = self.sharpness * x | |||
| return self.softplus(scaled_value) / self.sharpness | |||
| def _inverse(self, y): | |||
| r""" | |||
| .. math:: | |||
| f(x) = \frac{\log(1 + e^{kx}))}{k} | |||
| f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k} | |||
| """ | |||
| scaled_value = self.sharpness * y | |||
| return self.inverse_softplus(scaled_value) / self.sharpness | |||
| def _forward_log_jacobian(self, x): | |||
| r""" | |||
| .. math: | |||
| f(x) = \log(1 + e^{kx}) / k | |||
| f'(x) = \frac{e^{kx}}{ 1 + e^{kx}} | |||
| \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx) | |||
| """ | |||
| scaled_value = self.sharpness * x | |||
| return self.log_sigmoid(scaled_value) | |||
| def _inverse_log_jacobian(self, y): | |||
| r""" | |||
| .. math: | |||
| f(y) = \frac{\log(e^{ky} - 1)}{k} | |||
| f'(y) = \frac{e^{ky}}{e^{ky} - 1} | |||
| \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky) | |||
| """ | |||
| scaled_value = self.sharpness * y | |||
| return scaled_value - self.inverse_softplus(scaled_value) | |||
| @@ -0,0 +1,99 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test cases for scalar affine""" | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| import mindspore.nn.probability.bijector as msb | |||
| from mindspore import Tensor | |||
| from mindspore import dtype | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| class Net(nn.Cell): | |||
| """ | |||
| Test class: forward pass of bijector. | |||
| """ | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.bijector = msb.ScalarAffine(scale=2.0, shift=1.0) | |||
| def construct(self, x_): | |||
| return self.bijector.forward(x_) | |||
| def test_forward(): | |||
| forward = Net() | |||
| x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32) | |||
| ans = forward(Tensor(x, dtype=dtype.float32)) | |||
| tol = 1e-6 | |||
| expected = 2 * x + 1 | |||
| assert (np.abs(ans.asnumpy() - expected) < tol).all() | |||
| class Net1(nn.Cell): | |||
| """ | |||
| Test class: backward pass of bijector. | |||
| """ | |||
| def __init__(self): | |||
| super(Net1, self).__init__() | |||
| self.bijector = msb.ScalarAffine(shift=1.0, scale=2.0) | |||
| def construct(self, x_): | |||
| return self.bijector.inverse(x_) | |||
| def test_backward(): | |||
| backward = Net1() | |||
| x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32) | |||
| ans = backward(Tensor(x, dtype=dtype.float32)) | |||
| tol = 1e-6 | |||
| expected = 0.5 * (x - 1.0) | |||
| assert (np.abs(ans.asnumpy() - expected) < tol).all() | |||
| class Net2(nn.Cell): | |||
| """ | |||
| Test class: Forward Jacobian. | |||
| """ | |||
| def __init__(self): | |||
| super(Net2, self).__init__() | |||
| self.bijector = msb.ScalarAffine(shift=1.0, scale=2.0) | |||
| def construct(self, x_): | |||
| return self.bijector.forward_log_jacobian(x_) | |||
| def test_forward_jacobian(): | |||
| forward_jacobian = Net2() | |||
| x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | |||
| ans = forward_jacobian(x) | |||
| expected = np.log([2.0, 2.0, 2.0, 2.0]) | |||
| tol = 1e-6 | |||
| assert (np.abs(ans.asnumpy() - expected) < tol).all() | |||
| class Net3(nn.Cell): | |||
| """ | |||
| Test class: Backward Jacobian. | |||
| """ | |||
| def __init__(self): | |||
| super(Net3, self).__init__() | |||
| self.bijector = msb.ScalarAffine(shift=1.0, scale=2.0) | |||
| def construct(self, x_): | |||
| return self.bijector.inverse_log_jacobian(x_) | |||
| def test_backward_jacobian(): | |||
| backward_jacobian = Net3() | |||
| x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | |||
| ans = backward_jacobian(x) | |||
| expected = np.log([0.5, 0.5, 0.5, 0.5]) | |||
| tol = 1e-6 | |||
| assert (np.abs(ans.asnumpy() - expected) < tol).all() | |||
| @@ -0,0 +1,99 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test cases for scalar affine""" | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| import mindspore.nn.probability.bijector as msb | |||
| from mindspore import Tensor | |||
| from mindspore import dtype | |||
| context.set_context(device_target="Ascend") | |||
| class Net(nn.Cell): | |||
| """ | |||
| Test class: forward pass of bijector. | |||
| """ | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.bijector = msb.Softplus(sharpness=2.0) | |||
| def construct(self, x_): | |||
| return self.bijector.forward(x_) | |||
| def test_forward(): | |||
| forward = Net() | |||
| x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32) | |||
| ans = forward(Tensor(x, dtype=dtype.float32)) | |||
| expected = np.log(1 + np.exp(2 * x)) * 0.5 | |||
| tol = 1e-6 | |||
| assert (np.abs(ans.asnumpy() - expected) < tol).all() | |||
| class Net1(nn.Cell): | |||
| """ | |||
| Test class: backward pass of bijector. | |||
| """ | |||
| def __init__(self): | |||
| super(Net1, self).__init__() | |||
| self.bijector = msb.Softplus(sharpness=2.0) | |||
| def construct(self, x_): | |||
| return self.bijector.inverse(x_) | |||
| def test_backward(): | |||
| backward = Net1() | |||
| x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32) | |||
| ans = backward(Tensor(x, dtype=dtype.float32)) | |||
| expected = np.log(np.exp(2 * x) - 1) * 0.5 | |||
| tol = 1e-6 | |||
| assert (np.abs(ans.asnumpy() - expected) < tol).all() | |||
| class Net2(nn.Cell): | |||
| """ | |||
| Test class: Forward Jacobian. | |||
| """ | |||
| def __init__(self): | |||
| super(Net2, self).__init__() | |||
| self.bijector = msb.Softplus(sharpness=2.0) | |||
| def construct(self, x_): | |||
| return self.bijector.forward_log_jacobian(x_) | |||
| def test_forward_jacobian(): | |||
| forward_jacobian = Net2() | |||
| x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32) | |||
| ans = forward_jacobian(Tensor(x, dtype=dtype.float32)) | |||
| expected = np.log(np.exp(2 * x) / (1 + np.exp(2.0 * x))) | |||
| tol = 1e-6 | |||
| assert (np.abs(ans.asnumpy() - expected) < tol).all() | |||
| class Net3(nn.Cell): | |||
| """ | |||
| Test class: Backward Jacobian. | |||
| """ | |||
| def __init__(self): | |||
| super(Net3, self).__init__() | |||
| self.bijector = msb.Softplus(sharpness=2.0) | |||
| def construct(self, x_): | |||
| return self.bijector.inverse_log_jacobian(x_) | |||
| def test_backward_jacobian(): | |||
| backward_jacobian = Net3() | |||
| x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32) | |||
| ans = backward_jacobian(Tensor(x, dtype=dtype.float32)) | |||
| expected = np.log(np.exp(2.0 * x) / np.expm1(2.0 * x)) | |||
| tol = 1e-6 | |||
| assert (np.abs(ans.asnumpy() - expected) < tol).all() | |||
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test cases for exp""" | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| import mindspore.nn.probability.bijector as msb | |||
| from mindspore import Tensor | |||
| @@ -21,8 +22,10 @@ from mindspore import dtype | |||
| def test_init(): | |||
| b = msb.Exp() | |||
| assert isinstance(b, msb.Bijector) | |||
| b = msb.Exp(1.0) | |||
| assert isinstance(b, msb.Bijector) | |||
| def test_type(): | |||
| with pytest.raises(TypeError): | |||
| msb.Exp(name=0.1) | |||
| class Net(nn.Cell): | |||
| """ | |||
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test cases for powertransform""" | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| import mindspore.nn.probability.bijector as msb | |||
| from mindspore import Tensor | |||
| @@ -24,6 +25,12 @@ def test_init(): | |||
| b = msb.PowerTransform(1) | |||
| assert isinstance(b, msb.Bijector) | |||
| def test_type(): | |||
| with pytest.raises(TypeError): | |||
| msb.PowerTransform(power='power') | |||
| with pytest.raises(TypeError): | |||
| msb.PowerTransform(name=0.1) | |||
| class Net(nn.Cell): | |||
| """ | |||
| Test class: forward and inverse pass of bijector. | |||
| @@ -0,0 +1,139 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test cases for scalar affine""" | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| import mindspore.nn.probability.bijector as msb | |||
| from mindspore import Tensor | |||
| from mindspore import dtype | |||
| def test_init(): | |||
| """ | |||
| Test initializations. | |||
| """ | |||
| b = msb.ScalarAffine() | |||
| assert isinstance(b, msb.Bijector) | |||
| b = msb.ScalarAffine(scale=1.0) | |||
| assert isinstance(b, msb.Bijector) | |||
| b = msb.ScalarAffine(shift=2.0) | |||
| assert isinstance(b, msb.Bijector) | |||
| b = msb.ScalarAffine(3.0, 4.0) | |||
| assert isinstance(b, msb.Bijector) | |||
| def test_type(): | |||
| with pytest.raises(TypeError): | |||
| msb.ScalarAffine(scale='scale') | |||
| with pytest.raises(TypeError): | |||
| msb.ScalarAffine(shift='shift') | |||
| with pytest.raises(TypeError): | |||
| msb.ScalarAffine(name=0.1) | |||
| class ForwardBackward(nn.Cell): | |||
| """ | |||
| Test class: forward and backward pass. | |||
| """ | |||
| def __init__(self): | |||
| super(ForwardBackward, self).__init__() | |||
| self.b1 = msb.ScalarAffine(2.0, 1.0) | |||
| self.b2 = msb.ScalarAffine() | |||
| def construct(self, x_): | |||
| ans1 = self.b1.inverse(self.b1.forward(x_)) | |||
| ans2 = self.b2.inverse(self.b2.forward(x_)) | |||
| return ans1 + ans2 | |||
| def test_forward_and_backward_pass(): | |||
| """ | |||
| Test forward and backward pass of ScalarAffine bijector. | |||
| """ | |||
| net = ForwardBackward() | |||
| x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | |||
| ans = net(x) | |||
| assert isinstance(ans, Tensor) | |||
| class ForwardJacobian(nn.Cell): | |||
| """ | |||
| Test class: Forward log Jacobian. | |||
| """ | |||
| def __init__(self): | |||
| super(ForwardJacobian, self).__init__() | |||
| self.b1 = msb.ScalarAffine(2.0, 1.0) | |||
| self.b2 = msb.ScalarAffine() | |||
| def construct(self, x_): | |||
| ans1 = self.b1.forward_log_jacobian(x_) | |||
| ans2 = self.b2.forward_log_jacobian(x_) | |||
| return ans1 + ans2 | |||
| def test_forward_jacobian(): | |||
| """ | |||
| Test forward log jacobian of ScalarAffine bijector. | |||
| """ | |||
| net = ForwardJacobian() | |||
| x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | |||
| ans = net(x) | |||
| assert isinstance(ans, Tensor) | |||
| class BackwardJacobian(nn.Cell): | |||
| """ | |||
| Test class: Backward log Jacobian. | |||
| """ | |||
| def __init__(self): | |||
| super(BackwardJacobian, self).__init__() | |||
| self.b1 = msb.ScalarAffine(2.0, 1.0) | |||
| self.b2 = msb.ScalarAffine() | |||
| def construct(self, x_): | |||
| ans1 = self.b1.inverse_log_jacobian(x_) | |||
| ans2 = self.b2.inverse_log_jacobian(x_) | |||
| return ans1 + ans2 | |||
| def test_backward_jacobian(): | |||
| """ | |||
| Test backward log jacobian of ScalarAffine bijector. | |||
| """ | |||
| net = BackwardJacobian() | |||
| x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | |||
| ans = net(x) | |||
| assert isinstance(ans, Tensor) | |||
| class Net(nn.Cell): | |||
| """ | |||
| Test class: function calls going through construct. | |||
| """ | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.b1 = msb.ScalarAffine(1.0, 0.0) | |||
| self.b2 = msb.ScalarAffine() | |||
| def construct(self, x_): | |||
| ans1 = self.b1('inverse', self.b1('forward', x_)) | |||
| ans2 = self.b2('inverse', self.b2('forward', x_)) | |||
| ans3 = self.b1('forward_log_jacobian', x_) | |||
| ans4 = self.b2('forward_log_jacobian', x_) | |||
| ans5 = self.b1('inverse_log_jacobian', x_) | |||
| ans6 = self.b2('inverse_log_jacobian', x_) | |||
| return ans1 - ans2 + ans3 -ans4 + ans5 - ans6 | |||
| def test_old_api(): | |||
| """ | |||
| Test old api which goes through construct. | |||
| """ | |||
| net = Net() | |||
| x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | |||
| ans = net(x) | |||
| assert isinstance(ans, Tensor) | |||
| @@ -0,0 +1,133 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test cases for scalar affine""" | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| import mindspore.nn.probability.bijector as msb | |||
| from mindspore import Tensor | |||
| from mindspore import dtype | |||
| def test_init(): | |||
| """ | |||
| Test initializations. | |||
| """ | |||
| b = msb.Softplus() | |||
| assert isinstance(b, msb.Bijector) | |||
| b = msb.Softplus(1.0) | |||
| assert isinstance(b, msb.Bijector) | |||
| def test_type(): | |||
| with pytest.raises(TypeError): | |||
| msb.Softplus(sharpness='sharpness') | |||
| with pytest.raises(TypeError): | |||
| msb.Softplus(name=0.1) | |||
| class ForwardBackward(nn.Cell): | |||
| """ | |||
| Test class: forward and backward pass. | |||
| """ | |||
| def __init__(self): | |||
| super(ForwardBackward, self).__init__() | |||
| self.b1 = msb.Softplus(2.0) | |||
| self.b2 = msb.Softplus() | |||
| def construct(self, x_): | |||
| ans1 = self.b1.inverse(self.b1.forward(x_)) | |||
| ans2 = self.b2.inverse(self.b2.forward(x_)) | |||
| return ans1 + ans2 | |||
| def test_forward_and_backward_pass(): | |||
| """ | |||
| Test forward and backward pass of Softplus bijector. | |||
| """ | |||
| net = ForwardBackward() | |||
| x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | |||
| ans = net(x) | |||
| assert isinstance(ans, Tensor) | |||
| class ForwardJacobian(nn.Cell): | |||
| """ | |||
| Test class: Forward log Jacobian. | |||
| """ | |||
| def __init__(self): | |||
| super(ForwardJacobian, self).__init__() | |||
| self.b1 = msb.Softplus(2.0) | |||
| self.b2 = msb.Softplus() | |||
| def construct(self, x_): | |||
| ans1 = self.b1.forward_log_jacobian(x_) | |||
| ans2 = self.b2.forward_log_jacobian(x_) | |||
| return ans1 + ans2 | |||
| def test_forward_jacobian(): | |||
| """ | |||
| Test forward log jacobian of Softplus bijector. | |||
| """ | |||
| net = ForwardJacobian() | |||
| x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | |||
| ans = net(x) | |||
| assert isinstance(ans, Tensor) | |||
| class BackwardJacobian(nn.Cell): | |||
| """ | |||
| Test class: Backward log Jacobian. | |||
| """ | |||
| def __init__(self): | |||
| super(BackwardJacobian, self).__init__() | |||
| self.b1 = msb.Softplus(2.0) | |||
| self.b2 = msb.Softplus() | |||
| def construct(self, x_): | |||
| ans1 = self.b1.inverse_log_jacobian(x_) | |||
| ans2 = self.b2.inverse_log_jacobian(x_) | |||
| return ans1 + ans2 | |||
| def test_backward_jacobian(): | |||
| """ | |||
| Test backward log jacobian of Softplus bijector. | |||
| """ | |||
| net = BackwardJacobian() | |||
| x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | |||
| ans = net(x) | |||
| assert isinstance(ans, Tensor) | |||
| class Net(nn.Cell): | |||
| """ | |||
| Test class: function calls going through construct. | |||
| """ | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.b1 = msb.Softplus(1.0) | |||
| self.b2 = msb.Softplus() | |||
| def construct(self, x_): | |||
| ans1 = self.b1('inverse', self.b1('forward', x_)) | |||
| ans2 = self.b2('inverse', self.b2('forward', x_)) | |||
| ans3 = self.b1('forward_log_jacobian', x_) | |||
| ans4 = self.b2('forward_log_jacobian', x_) | |||
| ans5 = self.b1('inverse_log_jacobian', x_) | |||
| ans6 = self.b2('inverse_log_jacobian', x_) | |||
| return ans1 - ans2 + ans3 -ans4 + ans5 - ans6 | |||
| def test_old_api(): | |||
| """ | |||
| Test old api which goes through construct. | |||
| """ | |||
| net = Net() | |||
| x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) | |||
| ans = net(x) | |||
| assert isinstance(ans, Tensor) | |||