You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_softplus.py 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Test cases for the Softplus bijector."""
  16. import pytest
  17. import mindspore.nn as nn
  18. import mindspore.nn.probability.bijector as msb
  19. from mindspore import Tensor
  20. from mindspore import dtype
  21. def test_init():
  22. """
  23. Test initializations.
  24. """
  25. b = msb.Softplus()
  26. assert isinstance(b, msb.Bijector)
  27. b = msb.Softplus(1.0)
  28. assert isinstance(b, msb.Bijector)
  29. def test_type():
  30. with pytest.raises(TypeError):
  31. msb.Softplus(sharpness='sharpness')
  32. with pytest.raises(TypeError):
  33. msb.Softplus(name=0.1)
  34. class ForwardBackward(nn.Cell):
  35. """
  36. Test class: forward and backward pass.
  37. """
  38. def __init__(self):
  39. super(ForwardBackward, self).__init__()
  40. self.b1 = msb.Softplus(2.0)
  41. self.b2 = msb.Softplus()
  42. def construct(self, x_):
  43. ans1 = self.b1.inverse(self.b1.forward(x_))
  44. ans2 = self.b2.inverse(self.b2.forward(x_))
  45. return ans1 + ans2
  46. def test_forward_and_backward_pass():
  47. """
  48. Test forward and backward pass of Softplus bijector.
  49. """
  50. net = ForwardBackward()
  51. x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
  52. ans = net(x)
  53. assert isinstance(ans, Tensor)
  54. class ForwardJacobian(nn.Cell):
  55. """
  56. Test class: Forward log Jacobian.
  57. """
  58. def __init__(self):
  59. super(ForwardJacobian, self).__init__()
  60. self.b1 = msb.Softplus(2.0)
  61. self.b2 = msb.Softplus()
  62. def construct(self, x_):
  63. ans1 = self.b1.forward_log_jacobian(x_)
  64. ans2 = self.b2.forward_log_jacobian(x_)
  65. return ans1 + ans2
  66. def test_forward_jacobian():
  67. """
  68. Test forward log jacobian of Softplus bijector.
  69. """
  70. net = ForwardJacobian()
  71. x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
  72. ans = net(x)
  73. assert isinstance(ans, Tensor)
  74. class BackwardJacobian(nn.Cell):
  75. """
  76. Test class: Backward log Jacobian.
  77. """
  78. def __init__(self):
  79. super(BackwardJacobian, self).__init__()
  80. self.b1 = msb.Softplus(2.0)
  81. self.b2 = msb.Softplus()
  82. def construct(self, x_):
  83. ans1 = self.b1.inverse_log_jacobian(x_)
  84. ans2 = self.b2.inverse_log_jacobian(x_)
  85. return ans1 + ans2
  86. def test_backward_jacobian():
  87. """
  88. Test backward log jacobian of Softplus bijector.
  89. """
  90. net = BackwardJacobian()
  91. x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
  92. ans = net(x)
  93. assert isinstance(ans, Tensor)
  94. class Net(nn.Cell):
  95. """
  96. Test class: function calls going through construct.
  97. """
  98. def __init__(self):
  99. super(Net, self).__init__()
  100. self.b1 = msb.Softplus(1.0)
  101. self.b2 = msb.Softplus()
  102. def construct(self, x_):
  103. ans1 = self.b1('inverse', self.b1('forward', x_))
  104. ans2 = self.b2('inverse', self.b2('forward', x_))
  105. ans3 = self.b1('forward_log_jacobian', x_)
  106. ans4 = self.b2('forward_log_jacobian', x_)
  107. ans5 = self.b1('inverse_log_jacobian', x_)
  108. ans6 = self.b2('inverse_log_jacobian', x_)
  109. return ans1 - ans2 + ans3 -ans4 + ans5 - ans6
  110. def test_old_api():
  111. """
  112. Test old api which goes through construct.
  113. """
  114. net = Net()
  115. x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
  116. ans = net(x)
  117. assert isinstance(ans, Tensor)